Infer the right dtype for replay buffers. (#64)

This commit is contained in:
Alexis DUBURCQ 2020-05-29 16:27:03 +02:00 committed by GitHub
parent 8af7196a9a
commit dd3e2130bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -150,7 +150,8 @@ class ReplayBuffer(object):
return return
if self.__dict__.get(name, None) is None: if self.__dict__.get(name, None) is None:
if isinstance(inst, np.ndarray): if isinstance(inst, np.ndarray):
self.__dict__[name] = np.zeros([self._maxsize, *inst.shape]) self.__dict__[name] = np.zeros(
(self._maxsize, *inst.shape), dtype=inst.dtype)
elif isinstance(inst, (dict, Batch)): elif isinstance(inst, (dict, Batch)):
if name == 'info': if name == 'info':
self.__dict__[name] = np.array( self.__dict__[name] = np.array(
@ -162,7 +163,8 @@ class ReplayBuffer(object):
k_ = '_' + name + '@' + k k_ = '_' + name + '@' + k
self._add_to_buffer(k_, inst[k]) self._add_to_buffer(k_, inst[k])
else: # assume `inst` is a number else: # assume `inst` is a number
self.__dict__[name] = np.zeros([self._maxsize]) self.__dict__[name] = np.zeros(
(self._maxsize,), dtype=np.asarray(inst).dtype)
if isinstance(inst, np.ndarray) and \ if isinstance(inst, np.ndarray) and \
self.__dict__[name].shape[1:] != inst.shape: self.__dict__[name].shape[1:] != inst.shape:
raise ValueError( raise ValueError(