From 3086b5c31d0e9b63d97f82bf5e3165f8e78e4895 Mon Sep 17 00:00:00 2001
From: Alexis DUBURCQ
Date: Thu, 25 Jun 2020 14:39:30 +0200
Subject: [PATCH] Buffer refactoring to support batch over batch reliably (#93)

* Fix support of batch over batch for Buffer.
* Do not use internal __dict__ attribute to store batch data since it breaks inheritance.
* Various fixes.
* Improve robustness of Batch/Buffer by avoiding direct attribute assignment. Buffer refactoring.
* Add axis optional argument to Batch stack method.
* Add item assignment to Batch class.
* Fix list support for Buffer.
* Convert list to np.array by default for efficiency.
* Add missing unit test for Batch. Fix unit tests.
* Batch item assignment is now robust to key order.
* Do not use getattr/setattr explicitly for simplicity.
* More flexible __setitem__.
* Fixes
* Remove broadcasting at Batch level since it is unreliable.
* Forbid item assignment for inconsistent batches.
* Implement broadcasting at Buffer level.
* Add more unit tests for Batch item assignment.

Co-authored-by: Alexis Duburcq
---
 test/base/test_batch.py |  33 +++++--
 tianshou/data/batch.py  | 184 +++++++++++++++++++++++------------
 tianshou/data/buffer.py | 211 +++++++++++++++++----------------------
 3 files changed, 233 insertions(+), 195 deletions(-)

diff --git a/test/base/test_batch.py b/test/base/test_batch.py
index 93d8dc9..dc08746 100644
--- a/test/base/test_batch.py
+++ b/test/base/test_batch.py
@@ -65,6 +65,15 @@ def test_batch():
     assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
     assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
     assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
+    batch3 = Batch(a={
+        'c': np.zeros(1),
+        'd': Batch(e=np.array([0.0]), f=np.array([3.0]))})
+    batch3.a.d[0] = {'e': 4.0}
+    assert batch3.a.d.e[0] == 4.0
+    batch3.a.d[0] = Batch(f=5.0)
+    assert batch3.a.d.f[0] == 5.0
+    with pytest.raises(ValueError):
+        batch3.a.d[0] = Batch(f=5.0, g=0.0)
 
 
 def test_batch_over_batch():
@@ -93,16 +102,20 @@ def test_batch_over_batch():
 def test_batch_cat_and_stack():
     b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
     b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
-    b_cat_out = Batch.cat((b1, b2))
-    b_cat_in = copy.deepcopy(b1)
-    b_cat_in.cat_(b2)
-    assert np.all(b_cat_in.a.d.e == b_cat_out.a.d.e)
-    assert np.all(b_cat_in.a.d.e == b_cat_out.a.d.e)
-    assert isinstance(b_cat_in.a.d.e, np.ndarray)
-    assert b_cat_in.a.d.e.ndim == 1
-    b_stack = Batch.stack((b1, b2))
-    assert isinstance(b_stack.a.d.e, np.ndarray)
-    assert b_stack.a.d.e.ndim == 2
+    b12_cat_out = Batch.cat((b1, b2))
+    b12_cat_in = copy.deepcopy(b1)
+    b12_cat_in.cat_(b2)
+    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
+    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
+    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
+    assert b12_cat_in.a.d.e.ndim == 1
+    b12_stack = Batch.stack((b1, b2))
+    assert isinstance(b12_stack.a.d.e, np.ndarray)
+    assert b12_stack.a.d.e.ndim == 2
+    b3 = Batch(a=np.zeros((3, 4)))
+    b4 = Batch(a=np.ones((3, 4)))
+    b34_stack = Batch.stack((b3, b4), axis=1)
+    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
 
 
 def test_batch_over_batch_to_torch():
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index 935c332..21df4a4 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -75,6 +75,11 @@ class Batch:
         [11 22] [6 6]
     """
+    def __new__(cls, *args, **kwargs) -> 'Batch':
+        self = super().__new__(cls)
+        self.__dict__['_data'] = {}
+        return self
+
     def __init__(self,
                  batch_dict: Optional[Union[
                      dict, 'Batch',
Tuple[Union[dict, 'Batch']], @@ -95,21 +100,21 @@ class Batch: for k, v in zip(batch_dict[0].keys(), zip(*[e.values() for e in batch_dict])): if isinstance(v[0], dict) or _is_batch_set(v[0]): - self.__dict__[k] = Batch(v) + self[k] = Batch(v) elif isinstance(v[0], (np.generic, np.ndarray)): - self.__dict__[k] = np.stack(v, axis=0) + self[k] = np.stack(v, axis=0) elif isinstance(v[0], torch.Tensor): - self.__dict__[k] = torch.stack(v, dim=0) + self[k] = torch.stack(v, dim=0) elif isinstance(v[0], Batch): - self.__dict__[k] = Batch.stack(v) + self[k] = Batch.stack(v) else: - self.__dict__[k] = list(v) + self[k] = np.array(v) # fall back to np.object elif isinstance(batch_dict, (dict, Batch)): for k, v in batch_dict.items(): if isinstance(v, dict) or _is_batch_set(v): - self.__dict__[k] = Batch(v) + self[k] = Batch(v) else: - self.__dict__[k] = v + self[k] = v if len(kwargs) > 0: self.__init__(kwargs) @@ -140,8 +145,8 @@ class Batch: if isinstance(index, (int, np.integer)): return -length <= index and index < length elif isinstance(index, (list, np.ndarray)): - return _valid_bounds(length, min(index)) and \ - _valid_bounds(length, max(index)) + return _valid_bounds(length, np.min(index)) and \ + _valid_bounds(length, np.max(index)) elif isinstance(index, slice): if index.start is not None: start_valid = _valid_bounds(length, index.start) @@ -154,48 +159,75 @@ class Batch: return start_valid and stop_valid if isinstance(index, str): - return self.__getattr__(index) + return getattr(self, index) if not _valid_bounds(len(self), index): raise IndexError( f"Index {index} out of bounds for Batch of len {len(self)}.") else: b = Batch() - for k, v in self.__dict__.items(): + for k, v in self.items(): if isinstance(v, Batch) and v.size == 0: - b.__dict__[k] = Batch() - elif isinstance(v, list) and len(v) == 0: - b.__dict__[k] = [] + b[k] = Batch() elif hasattr(v, '__len__') and (not isinstance( v, (np.ndarray, torch.Tensor)) or v.ndim > 0): if _valid_bounds(len(v), index): - b.__dict__[k] = v[index] + if isinstance(index, (int, np.integer)) or \ + (isinstance(index, np.ndarray) and + index.ndim == 0) or \ + not isinstance(v, list): + b[k] = v[index] + else: + b[k] = [v[i] for i in index] else: raise IndexError( f"Index {index} out of bounds for {type(v)} of " f"len {len(self)}.") return b + def __setitem__(self, index: Union[ + str, slice, int, np.integer, np.ndarray, List[int]], + value: Any) -> None: + if isinstance(index, str): + return setattr(self, index, value) + if value is None: + value = Batch() + if not isinstance(value, (dict, Batch)): + raise TypeError("Batch does not supported value type " + f"{type(value)} for item assignment.") + if not set(value.keys()).issubset(self.keys()): + raise ValueError( + "Creating keys is not supported by item assignment.") + for key in self.keys(): + if isinstance(self[key], Batch): + default = Batch() + elif isinstance(self[key], np.ndarray) and \ + self[key].dtype == np.integer: + # Fallback for np.array of integer, + # since neither None or nan is supported. 
+ default = 0 + else: + default = None + self[key][index] = value.get(key, default) + def __iadd__(self, val: Union['Batch', Number]): if isinstance(val, Batch): - for k, r, v in zip(self.__dict__.keys(), - self.__dict__.values(), - val.__dict__.values()): + for k, r, v in zip(self.keys(), self.values(), val.values()): if r is None: - self.__dict__[k] = r + self[k] = r elif isinstance(r, list): - self.__dict__[k] = [r_ + v_ for r_, v_ in zip(r, v)] + self[k] = [r_ + v_ for r_, v_ in zip(r, v)] else: - self.__dict__[k] = r + v + self[k] = r + v return self elif isinstance(val, Number): - for k, r in zip(self.__dict__.keys(), self.__dict__.values()): + for k, r in zip(self.keys(), self.values()): if r is None: - self.__dict__[k] = r + self[k] = r elif isinstance(r, list): - self.__dict__[k] = [r_ + val for r_ in r] + self[k] = [r_ + val for r_ in r] else: - self.__dict__[k] = r + val + self[k] = r + val return self else: raise TypeError("Only addition of Batch or number is supported.") @@ -206,30 +238,40 @@ class Batch: def __mul__(self, val: Number): assert isinstance(val, Number), \ "Only multiplication by a number is supported." - result = Batch() - for k, r in zip(self.__dict__.keys(), self.__dict__.values()): - result.__dict__[k] = r * val + result = self.__class__() + for k, r in zip(self.keys(), self.values()): + result[k] = r * val return result def __truediv__(self, val: Number): assert isinstance(val, Number), \ "Only division by a number is supported." - result = Batch() - for k, r in zip(self.__dict__.keys(), self.__dict__.values()): - result.__dict__[k] = r / val + result = self.__class__() + for k, r in zip(self.keys(), self.values()): + result[k] = r / val return result def __getattr__(self, key: str) -> Union['Batch', Any]: """Return self.key""" - if key not in self.__dict__: - raise AttributeError(key) - return self.__dict__[key] + if key in self.__dict__.keys(): + return self.__dict__[key] + elif key in self._data.keys(): + return self._data[key] + raise AttributeError(key) + + def __setattr__(self, key, value): + if key in self._data.keys(): + self._data[key] = value + elif key in self.__dict__.keys(): + self.__dict__[key] = value + else: + self._data[key] = value def __repr__(self) -> str: """Return str(self).""" s = self.__class__.__name__ + '(\n' flag = False - for k, v in self.__dict__.items(): + for k, v in self.items(): rpl = '\n' + ' ' * (6 + len(k)) obj = pprint.pformat(v).replace('\n', rpl) s += f' {k}: {obj},\n' @@ -242,29 +284,29 @@ class Batch: def keys(self) -> List[str]: """Return self.keys().""" - return self.__dict__.keys() + return self._data.keys() def values(self) -> List[Any]: """Return self.values().""" - return self.__dict__.values() + return self._data.values() def items(self) -> Any: """Return self.items().""" - return self.__dict__.items() + return self._data.items() def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]: """Return self[k] if k in self else d. d defaults to None.""" - if k in self.__dict__: - return self.__getattr__(k) + if k in self.keys(): + return self[k] return d def to_numpy(self) -> None: - """Change all torch.Tensor to numpy.ndarray. This is an inplace + """Change all torch.Tensor to numpy.ndarray. This is an in-place operation. 
""" - for k, v in self.__dict__.items(): + for k, v in self.items(): if isinstance(v, torch.Tensor): - self.__dict__[k] = v.detach().cpu().numpy() + self[k] = v.detach().cpu().numpy() elif isinstance(v, Batch): v.to_numpy() @@ -272,18 +314,18 @@ class Batch: dtype: Optional[torch.dtype] = None, device: Union[str, int, torch.device] = 'cpu' ) -> None: - """Change all numpy.ndarray to torch.Tensor. This is an inplace + """Change all numpy.ndarray to torch.Tensor. This is an in-place operation. """ if not isinstance(device, torch.device): device = torch.device(device) - for k, v in self.__dict__.items(): + for k, v in self.items(): if isinstance(v, (np.generic, np.ndarray)): v = torch.from_numpy(v).to(device) if dtype is not None: v = v.type(dtype) - self.__dict__[k] = v + self[k] = v if isinstance(v, torch.Tensor): if dtype is not None and v.dtype != dtype: must_update_tensor = True @@ -297,7 +339,7 @@ class Batch: if must_update_tensor: if dtype is not None: v = v.type(dtype) - self.__dict__[k] = v.to(device) + self[k] = v.to(device) elif isinstance(v, Batch): v.to_torch(dtype, device) @@ -312,51 +354,67 @@ class Batch: """ assert isinstance(batch, Batch), \ 'Only Batch is allowed to be concatenated in-place!' - for k, v in batch.__dict__.items(): + for k, v in batch.items(): if v is None: continue - if not hasattr(self, k) or self.__dict__[k] is None: - self.__dict__[k] = copy.deepcopy(v) + if not hasattr(self, k) or self[k] is None: + self[k] = copy.deepcopy(v) elif isinstance(v, np.ndarray) and v.ndim > 0: - self.__dict__[k] = np.concatenate([self.__dict__[k], v]) + self[k] = np.concatenate([self[k], v]) elif isinstance(v, torch.Tensor): - self.__dict__[k] = torch.cat([self.__dict__[k], v]) + self[k] = torch.cat([self[k], v]) elif isinstance(v, list): - self.__dict__[k] += copy.deepcopy(v) + self[k] = self[k] + copy.deepcopy(v) elif isinstance(v, Batch): - self.__dict__[k].cat_(v) + self[k].cat_(v) else: s = 'No support for method "cat" with type '\ f'{type(v)} in class Batch.' raise TypeError(s) - @staticmethod - def cat(batches: List['Batch']) -> None: + @classmethod + def cat(cls, batches: List['Batch']) -> 'Batch': """Concatenate a :class:`~tianshou.data.Batch` object into a single new batch. """ assert isinstance(batches, (tuple, list)), \ 'Only list of Batch instances is allowed to be '\ 'concatenated out-of-place!' - batch = Batch() + batch = cls() for batch_ in batches: batch.cat_(batch_) return batch - @staticmethod - def stack(batches: List['Batch']): + @classmethod + def stack(cls, batches: List['Batch'], axis: int = 0) -> 'Batch': """Stack a :class:`~tianshou.data.Batch` object into a single new batch. """ assert isinstance(batches, (tuple, list)), \ 'Only list of Batch instances is allowed to be '\ 'stacked out-of-place!' - return Batch(np.array([batch.__dict__ for batch in batches])) + if axis == 0: + return cls(batches) + else: + batch = Batch() + for k, v in zip(batches[0].keys(), + zip(*[e.values() for e in batches])): + if isinstance(v[0], (np.generic, np.ndarray, list)): + batch[k] = np.stack(v, axis) + elif isinstance(v[0], torch.Tensor): + batch[k] = torch.stack(v, axis) + elif isinstance(v[0], Batch): + batch[k] = Batch.stack(v, axis) + else: + s = 'No support for method "stack" with type '\ + f'{type(v[0])} in class Batch and axis != 0.' 
+ raise TypeError(s) + return batch def __len__(self) -> int: """Return len(self).""" r = [] - for v in self.__dict__.values(): + for v in self.values(): if isinstance(v, Batch) and v.size == 0: continue elif isinstance(v, list) and len(v) == 0: @@ -373,11 +431,11 @@ class Batch: @property def size(self) -> int: """Return self.size.""" - if len(self.__dict__) == 0: + if len(self.keys()) == 0: return 0 else: r = [] - for v in self.__dict__.values(): + for v in self.values(): if isinstance(v, Batch): r.append(v.size) elif hasattr(v, '__len__') and (not isinstance( diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 962dc7e..d1ac36d 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -1,11 +1,11 @@ -import pprint import numpy as np +from numbers import Number from typing import Any, Tuple, Union, Optional -from tianshou.data.batch import Batch +from .batch import Batch -class ReplayBuffer: +class ReplayBuffer(Batch): """:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. It stores basically 7 types of data, as mentioned in :class:`~tianshou.data.Batch`, based on @@ -96,81 +96,47 @@ class ReplayBuffer: def __init__(self, size: int, stack_num: Optional[int] = 0, ignore_obs_next: bool = False, **kwargs) -> None: - self._maxsize = size - self._stack = stack_num - self._save_s_ = not ignore_obs_next - self._meta = {} + super().__init__() + self.__dict__['_maxsize'] = size + self.__dict__['_stack'] = stack_num + self.__dict__['_save_s_'] = not ignore_obs_next + self.__dict__['_index'] = 0 + self.__dict__['_size'] = 0 self.reset() def __len__(self) -> int: """Return len(self).""" return self._size - def __repr__(self) -> str: - """Return str(self).""" - s = self.__class__.__name__ + '(\n' - flag = False - for k in sorted(list(self.__dict__) + list(self._meta)): - if k[0] != '_' and (self.__dict__.get(k, None) is not None or - k in self._meta): - rpl = '\n' + ' ' * (6 + len(k)) - obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl) - s += f' {k}: {obj},\n' - flag = True - if flag: - s += ')' - else: - s = self.__class__.__name__ + '()' - return s - - def __getattr__(self, key: str) -> Union[Batch, np.ndarray]: - """Return self.key""" - if key not in self._meta: - if key not in self.__dict__: - raise AttributeError(key) - return self.__dict__[key] - d = {} - for k_ in self._meta[key]: - k__ = '_' + key + '@' + k_ - if k__ in self.__dict__: - d[k_] = self.__dict__[k__] - else: - d[k_] = self.__getattr__(k__) - return Batch(d) - def _add_to_buffer(self, name: str, inst: Any) -> None: - if inst is None: - if getattr(self, name, None) is None: - self.__dict__[name] = None - return - if name in self._meta: - for k in inst.keys(): - self._add_to_buffer('_' + name + '@' + k, inst[k]) - return - if self.__dict__.get(name, None) is None: + def _create_value(inst: Any) -> Union['Batch', np.ndarray]: if isinstance(inst, np.ndarray): - self.__dict__[name] = np.zeros( + return np.zeros( (self._maxsize, *inst.shape), dtype=inst.dtype) elif isinstance(inst, (dict, Batch)): - if self._meta.get(name, None) is None: - self._meta[name] = list(inst.keys()) - for k in inst.keys(): - k_ = '_' + name + '@' + k - self._add_to_buffer(k_, inst[k]) - elif np.isscalar(inst): - self.__dict__[name] = np.zeros( + return Batch([Batch(inst) for _ in range(self._maxsize)]) + elif isinstance(inst, (np.generic, Number)): + return np.zeros( (self._maxsize,), dtype=np.asarray(inst).dtype) else: # fall back to np.object - 
self.__dict__[name] = np.array( - [None for _ in range(self._maxsize)]) + return np.array([None for _ in range(self._maxsize)]) + + if inst is None: + inst = Batch() + if name not in self.keys(): + self[name] = _create_value(inst) if isinstance(inst, np.ndarray) and \ - self.__dict__[name].shape[1:] != inst.shape: + self[name].shape[1:] != inst.shape: raise ValueError( "Cannot add data to a buffer with different shape, " - f"key: {name}, expect shape: {self.__dict__[name].shape[1:]}, " - f"given shape: {inst.shape}.") - if name not in self._meta: - self.__dict__[name][self._index] = inst + f"key: {name}, expect shape: {self[name].shape[1:]}" + f", given shape: {inst.shape}.") + if isinstance(self[name], Batch): + field_keys = self[name].keys() + for key, val in inst.items(): + if key not in field_keys: + self[name][key] = _create_value(val) + self[name][self._index] = inst def update(self, buffer: 'ReplayBuffer') -> None: """Move the data from the given buffer to self.""" @@ -209,7 +175,8 @@ class ReplayBuffer: def reset(self) -> None: """Clear all the data in replay buffer.""" - self._index = self._size = 0 + self._index = 0 + self._size = 0 def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: """Get a random sample from buffer with size equal to batch_size. \ @@ -226,7 +193,7 @@ class ReplayBuffer: ]) return self[indice], indice - def get(self, indice: Union[slice, np.ndarray], key: str, + def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str, stack_num: Optional[int] = None) -> Union[Batch, np.ndarray]: """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is indice. The stack_num (here equals to 4) is @@ -234,20 +201,16 @@ class ReplayBuffer: """ if stack_num is None: stack_num = self._stack - if not isinstance(indice, np.ndarray): - if np.isscalar(indice): - indice = np.array(indice) - elif isinstance(indice, slice): - indice = np.arange( - 0 if indice.start is None - else self._size - indice.start if indice.start < 0 - else indice.start, - self._size if indice.stop is None - else self._size - indice.stop if indice.stop < 0 - else indice.stop, - 1 if indice.step is None else indice.step) - else: - indice = np.array(indice) + if isinstance(indice, slice): + indice = np.arange( + 0 if indice.start is None + else self._size - indice.start if indice.start < 0 + else indice.start, + self._size if indice.stop is None + else self._size - indice.stop if indice.stop < 0 + else indice.stop, + 1 if indice.step is None else indice.step) + indice = np.array(indice, copy=True) # set last frame done to True last_index = (self._index - 1 + self._size) % self._size last_done, self.done[last_index] = self.done[last_index], True @@ -257,49 +220,51 @@ class ReplayBuffer: key = 'obs' if stack_num == 0: self.done[last_index] = last_done - if key in self._meta: - return {k: self.__dict__['_' + key + '@' + k][indice] - for k in self._meta[key]} + val = self[key] + if isinstance(val, Batch) and val.size == 0: + return val else: - return self.__dict__[key][indice] - if key in self._meta: - many_keys = self._meta[key] - stack = {k: [] for k in self._meta[key]} + if isinstance(indice, (int, np.integer)) or \ + (isinstance(indice, np.ndarray) and + indice.ndim == 0) or not isinstance(val, list): + return val[indice] + else: + return [val[i] for i in indice] else: - stack = [] - many_keys = None - for _ in range(stack_num): - if many_keys is not None: - for k_ in many_keys: - k__ = '_' + key + '@' + k_ - stack[k_] = [self.__dict__[k__][indice]] + 
stack[k_] + val = self[key] + if not isinstance(val, Batch) or val.size > 0: + stack = [] + for _ in range(stack_num): + stack = [val[indice]] + stack + pre_indice = np.asarray(indice - 1) + pre_indice[pre_indice == -1] = self._size - 1 + indice = np.asarray( + pre_indice + self.done[pre_indice].astype(np.int)) + indice[indice == self._size] = 0 + if isinstance(stack[0], Batch): + stack = Batch.stack(stack, axis=indice.ndim) + else: + stack = np.stack(stack, axis=indice.ndim) else: - stack = [self.__dict__[key][indice]] + stack - pre_indice = indice - 1 - pre_indice[pre_indice == -1] = self._size - 1 - indice = pre_indice + self.done[pre_indice].astype(np.int) - indice[indice == self._size] = 0 - self.done[last_index] = last_done - if many_keys is not None: - for k in stack: - stack[k] = np.stack(stack[k], axis=1) - stack = Batch(stack) - else: - stack = np.stack(stack, axis=1) - return stack + stack = Batch() + self.done[last_index] = last_done + return stack - def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch: + def __getitem__(self, index: Union[ + slice, int, np.integer, np.ndarray]) -> Batch: """Return a data batch: self[index]. If stack_num is set to be > 0, return the stacked obs and obs_next with shape [batch, len, ...]. """ + if isinstance(index, str): + return getattr(self, index) return Batch( obs=self.get(index, 'obs'), - act=self.act[index], + act=self.get(index, 'act', stack_num=0), # act_=self.get(index, 'act'), # stacked action, for RNN - rew=self.rew[index], - done=self.done[index], + rew=self.get(index, 'rew', stack_num=0), + done=self.get(index, 'done', stack_num=0), obs_next=self.get(index, 'obs_next'), - info=self.get(index, 'info'), + info=self.get(index, 'info', stack_num=0), policy=self.get(index, 'policy'), ) @@ -323,15 +288,15 @@ class ListReplayBuffer(ReplayBuffer): inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None: if inst is None: return - if self.__dict__.get(name, None) is None: - self.__dict__[name] = [] - self.__dict__[name].append(inst) + if self._data.get(name, None) is None: + self._data[name] = [] + self._data[name].append(inst) def reset(self) -> None: self._index = self._size = 0 - for k in list(self.__dict__): - if isinstance(self.__dict__[k], list): - self.__dict__[k] = [] + for k in list(self._data): + if isinstance(self._data[k], list): + self._data[k] = [] class PrioritizedReplayBuffer(ReplayBuffer): @@ -449,16 +414,18 @@ class PrioritizedReplayBuffer(ReplayBuffer): - self.weight[indice].sum() self.weight[indice] = np.power(np.abs(new_weight), self._alpha) - def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch: + def __getitem__(self, index: Union[str, slice, np.ndarray]) -> Batch: + if isinstance(index, str): + return getattr(self, index) return Batch( obs=self.get(index, 'obs'), - act=self.act[index], + act=self.get(index, 'act', stack_num=0), # act_=self.get(index, 'act'), # stacked action, for RNN - rew=self.rew[index], - done=self.done[index], + rew=self.get(index, 'rew', stack_num=0), + done=self.get(index, 'done', stack_num=0), obs_next=self.get(index, 'obs_next'), info=self.get(index, 'info'), - weight=self.weight[index], + weight=self.get(index, 'weight', stack_num=0), policy=self.get(index, 'policy'), )
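
The new unit tests above double as a specification of the user-facing behavior this patch introduces: row-wise item assignment on (possibly nested) Batch objects, an axis argument for Batch.stack, and a ReplayBuffer that stores dict/Batch data through the same Batch machinery. The following is a minimal usage sketch, not part of the patch, assuming the tianshou.data package at this revision; ReplayBuffer.add is not shown in this diff and is assumed to keep its pre-existing keyword signature, and the 'pos' observation field is made up for illustration.

import numpy as np
from tianshou.data import Batch, ReplayBuffer

# Row-wise item assignment on a nested Batch, mirroring the new test_batch case.
b = Batch(a={'c': np.zeros(1),
             'd': Batch(e=np.array([0.0]), f=np.array([3.0]))})
b.a.d[0] = {'e': 4.0}               # dict values are accepted
assert b.a.d.e[0] == 4.0
b.a.d[0] = Batch(f=5.0)             # Batch values are accepted as well
assert b.a.d.f[0] == 5.0
try:
    b.a.d[0] = Batch(f=5.0, g=0.0)  # 'g' is not an existing key
except ValueError:
    pass                            # creating keys via item assignment is forbidden

# Stacking along a non-leading axis, new in Batch.stack.
b3, b4 = Batch(a=np.zeros((3, 4))), Batch(a=np.ones((3, 4)))
assert Batch.stack((b3, b4), axis=1).a.shape == (3, 2, 4)

# ReplayBuffer now stores nested ("batch over batch") observations through
# Batch itself, so dict observations round-trip through add/sample.
buf = ReplayBuffer(size=20)
for i in range(3):
    buf.add(obs={'pos': np.full(3, i)}, act=i, rew=0.0, done=False,
            obs_next={'pos': np.full(3, i + 1)})
sampled, indice = buf.sample(batch_size=2)
assert isinstance(sampled.obs, Batch) and sampled.obs.pos.shape == (2, 3)

Note that item assignment deliberately refuses to create keys that do not already exist, so an inconsistent Batch cannot be built up row by row; per the commit message, broadcasting of partially specified rows is handled at the Buffer level rather than inside Batch.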