diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index f60a472..335d7e8 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -125,7 +125,7 @@ class BasePolicy(ABC, nn.Module): :return: a Batch. The result will be stored in batch.returns. """ if v_s_ is None: - v_s_ = np.zeros_like(batch.rew) + v_s_ = batch.rew * 0. else: if not isinstance(v_s_, np.ndarray): v_s_ = np.array(v_s_, np.float)