Infer the right dtype for replay buffers. (#64)
This commit is contained in:
parent
8af7196a9a
commit
dd3e2130bb
@ -150,7 +150,8 @@ class ReplayBuffer(object):
|
||||
return
|
||||
if self.__dict__.get(name, None) is None:
|
||||
if isinstance(inst, np.ndarray):
|
||||
self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
|
||||
self.__dict__[name] = np.zeros(
|
||||
(self._maxsize, *inst.shape), dtype=inst.dtype)
|
||||
elif isinstance(inst, (dict, Batch)):
|
||||
if name == 'info':
|
||||
self.__dict__[name] = np.array(
|
||||
@ -162,7 +163,8 @@ class ReplayBuffer(object):
|
||||
k_ = '_' + name + '@' + k
|
||||
self._add_to_buffer(k_, inst[k])
|
||||
else: # assume `inst` is a number
|
||||
self.__dict__[name] = np.zeros([self._maxsize])
|
||||
self.__dict__[name] = np.zeros(
|
||||
(self._maxsize,), dtype=np.asarray(inst).dtype)
|
||||
if isinstance(inst, np.ndarray) and \
|
||||
self.__dict__[name].shape[1:] != inst.shape:
|
||||
raise ValueError(
|
||||
|
Loading…
x
Reference in New Issue
Block a user