From dd3e2130bb8fce5ac636a6d686a20b31584ea336 Mon Sep 17 00:00:00 2001 From: Alexis DUBURCQ Date: Fri, 29 May 2020 16:27:03 +0200 Subject: [PATCH] Infer the right dtype for replay buffers. (#64) --- tianshou/data/buffer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 181a6b7..0b03582 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -150,7 +150,8 @@ class ReplayBuffer(object): return if self.__dict__.get(name, None) is None: if isinstance(inst, np.ndarray): - self.__dict__[name] = np.zeros([self._maxsize, *inst.shape]) + self.__dict__[name] = np.zeros( + (self._maxsize, *inst.shape), dtype=inst.dtype) elif isinstance(inst, (dict, Batch)): if name == 'info': self.__dict__[name] = np.array( @@ -162,7 +163,8 @@ class ReplayBuffer(object): k_ = '_' + name + '@' + k self._add_to_buffer(k_, inst[k]) else: # assume `inst` is a number - self.__dict__[name] = np.zeros([self._maxsize]) + self.__dict__[name] = np.zeros( + (self._maxsize,), dtype=np.asarray(inst).dtype) if isinstance(inst, np.ndarray) and \ self.__dict__[name].shape[1:] != inst.shape: raise ValueError(