Infer the right dtype for replay buffers. (#64)
This commit is contained in:
parent
8af7196a9a
commit
dd3e2130bb
@ -150,7 +150,8 @@ class ReplayBuffer(object):
|
|||||||
return
|
return
|
||||||
if self.__dict__.get(name, None) is None:
|
if self.__dict__.get(name, None) is None:
|
||||||
if isinstance(inst, np.ndarray):
|
if isinstance(inst, np.ndarray):
|
||||||
self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
|
self.__dict__[name] = np.zeros(
|
||||||
|
(self._maxsize, *inst.shape), dtype=inst.dtype)
|
||||||
elif isinstance(inst, (dict, Batch)):
|
elif isinstance(inst, (dict, Batch)):
|
||||||
if name == 'info':
|
if name == 'info':
|
||||||
self.__dict__[name] = np.array(
|
self.__dict__[name] = np.array(
|
||||||
@ -162,7 +163,8 @@ class ReplayBuffer(object):
|
|||||||
k_ = '_' + name + '@' + k
|
k_ = '_' + name + '@' + k
|
||||||
self._add_to_buffer(k_, inst[k])
|
self._add_to_buffer(k_, inst[k])
|
||||||
else: # assume `inst` is a number
|
else: # assume `inst` is a number
|
||||||
self.__dict__[name] = np.zeros([self._maxsize])
|
self.__dict__[name] = np.zeros(
|
||||||
|
(self._maxsize,), dtype=np.asarray(inst).dtype)
|
||||||
if isinstance(inst, np.ndarray) and \
|
if isinstance(inst, np.ndarray) and \
|
||||||
self.__dict__[name].shape[1:] != inst.shape:
|
self.__dict__[name].shape[1:] != inst.shape:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user