diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 1a0ef43..2a670ab 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -569,10 +569,8 @@ class Batch: else: # scalar value warnings.warn('You are calling Batch.empty on a NumPy scalar, ' 'which may cause undefined behaviors.') - if isinstance(v, (np.generic, Number)): - self.__dict__[k] *= 0 - if np.isnan(self.__dict__[k]): - self.__dict__[k] = 0 + if isinstance(v, (np.number, np.bool_, Number)): + self.__dict__[k] = v.__class__(0) else: self.__dict__[k] = None return self