Infer the right dtype for replay buffers. (#64)

This commit is contained in:
Alexis DUBURCQ 2020-05-29 16:27:03 +02:00 committed by GitHub
parent 8af7196a9a
commit dd3e2130bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -150,7 +150,8 @@ class ReplayBuffer(object):
return
if self.__dict__.get(name, None) is None:
if isinstance(inst, np.ndarray):
self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
self.__dict__[name] = np.zeros(
(self._maxsize, *inst.shape), dtype=inst.dtype)
elif isinstance(inst, (dict, Batch)):
if name == 'info':
self.__dict__[name] = np.array(
@ -162,7 +163,8 @@ class ReplayBuffer(object):
k_ = '_' + name + '@' + k
self._add_to_buffer(k_, inst[k])
else: # assume `inst` is a number
self.__dict__[name] = np.zeros([self._maxsize])
self.__dict__[name] = np.zeros(
(self._maxsize,), dtype=np.asarray(inst).dtype)
if isinstance(inst, np.ndarray) and \
self.__dict__[name].shape[1:] != inst.shape:
raise ValueError(