From 49f43e9f1f151b76690edf6f568ce5fc0bab1002 Mon Sep 17 00:00:00 2001 From: Alexis DUBURCQ Date: Wed, 24 Jun 2020 15:43:48 +0200 Subject: [PATCH] Fix Batch to numpy compatibility (#92) * Fix Batch to numpy compatibility. * Fix Batch unit tests. * Fix linter * Add Batch shape method. * Remove shape and add size. Enable to reserve keys using empty batch/list. * Fix linter and unit tests. * Batch init using list of Batch. * Add unit tests. * Fix Batch __len__. * Fix unit tests. * Fix slicing * Add missing slicing unit tests. Co-authored-by: Alexis Duburcq --- test/base/test_batch.py | 40 ++++++++++++-- tianshou/data/batch.py | 118 ++++++++++++++++++++++++++++++++++------ tianshou/data/buffer.py | 18 +++--- 3 files changed, 144 insertions(+), 32 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 0503cd5..93d8dc9 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -39,13 +39,32 @@ def test_batch(): 'c': np.zeros(1), 'd': Batch(e=np.array(3.0))}]) assert len(batch2) == 1 - assert list(batch2[1].keys()) == ['a'] - assert len(batch2[-2].a.d.keys()) == 0 - assert len(batch2[-1].keys()) > 0 - assert batch2[0][0].a.c == 0.0 + with pytest.raises(IndexError): + batch2[-2] + with pytest.raises(IndexError): + batch2[1] + with pytest.raises(TypeError): + batch2[0][0] assert isinstance(batch2[0].a.c, np.ndarray) assert isinstance(batch2[0].a.b, np.float64) assert isinstance(batch2[0].a.d.e, np.float64) + batch2_from_list = Batch(list(batch2)) + batch2_from_comp = Batch([e for e in batch2]) + assert batch2_from_list.a.b == batch2.a.b + assert batch2_from_list.a.c == batch2.a.c + assert batch2_from_list.a.d.e == batch2.a.d.e + assert batch2_from_comp.a.b == batch2.a.b + assert batch2_from_comp.a.c == batch2.a.c + assert batch2_from_comp.a.d.e == batch2.a.d.e + for batch_slice in [ + batch2[slice(0, 1)], batch2[:1], batch2[0:]]: + assert batch_slice.a.b == batch2.a.b + assert batch_slice.a.c == batch2.a.c + assert batch_slice.a.d.e == batch2.a.d.e + batch2_sum = (batch2 + 1.0) * 2 + assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2 + assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2 + assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2 def test_batch_over_batch(): @@ -146,6 +165,19 @@ def test_batch_from_to_numpy_without_copy(): assert c_mem_addr_new == c_mem_addr_orig +def test_batch_numpy_compatibility(): + batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), + b=Batch(), + c=np.array([5.0, 6.0])) + batch_mean = np.mean(batch) + assert isinstance(batch_mean, Batch) + assert sorted(batch_mean.keys()) == ['a', 'b', 'c'] + with pytest.raises(TypeError): + len(batch_mean) + assert np.all(batch_mean.a == np.mean(batch.a, axis=0)) + assert batch_mean.c == np.mean(batch.c, axis=0) + + if __name__ == '__main__': test_batch() test_batch_over_batch() diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index a255b8d..935c332 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -3,6 +3,7 @@ import copy import pprint import warnings import numpy as np +from numbers import Number from typing import Any, List, Tuple, Union, Iterator, Optional # Disable pickle warning related to torch, since it has been removed @@ -75,15 +76,16 @@ class Batch: """ def __init__(self, - batch_dict: Optional[ - Union[dict, Tuple[dict], List[dict], np.ndarray]] = None, + batch_dict: Optional[Union[ + dict, 'Batch', Tuple[Union[dict, 'Batch']], + List[Union[dict, 'Batch']], np.ndarray]] = None, **kwargs) -> None: def _is_batch_set(data: Any) -> bool: if isinstance(data, (list, 
tuple)): - if len(data) > 0 and isinstance(data[0], dict): + if len(data) > 0 and isinstance(data[0], (dict, Batch)): return True elif isinstance(data, np.ndarray): - if isinstance(data.item(0), dict): + if isinstance(data.item(0), (dict, Batch)): return True return False @@ -102,7 +104,7 @@ class Batch: self.__dict__[k] = Batch.stack(v) else: self.__dict__[k] = list(v) - elif isinstance(batch_dict, dict): + elif isinstance(batch_dict, (dict, Batch)): for k, v in batch_dict.items(): if isinstance(v, dict) or _is_batch_set(v): self.__dict__[k] = Batch(v) @@ -141,22 +143,82 @@ class Batch: return _valid_bounds(length, min(index)) and \ _valid_bounds(length, max(index)) elif isinstance(index, slice): - return _valid_bounds(length, index.start) and \ - _valid_bounds(length, index.stop - 1) + if index.start is not None: + start_valid = _valid_bounds(length, index.start) + else: + start_valid = True + if index.stop is not None: + stop_valid = _valid_bounds(length, index.stop - 1) + else: + stop_valid = True + return start_valid and stop_valid if isinstance(index, str): return self.__getattr__(index) + + if not _valid_bounds(len(self), index): + raise IndexError( + f"Index {index} out of bounds for Batch of len {len(self)}.") else: b = Batch() for k, v in self.__dict__.items(): - if isinstance(v, Batch): - b.__dict__[k] = v[index] + if isinstance(v, Batch) and v.size == 0: + b.__dict__[k] = Batch() + elif isinstance(v, list) and len(v) == 0: + b.__dict__[k] = [] elif hasattr(v, '__len__') and (not isinstance( v, (np.ndarray, torch.Tensor)) or v.ndim > 0): if _valid_bounds(len(v), index): b.__dict__[k] = v[index] + else: + raise IndexError( + f"Index {index} out of bounds for {type(v)} of " + f"len {len(self)}.") return b + def __iadd__(self, val: Union['Batch', Number]): + if isinstance(val, Batch): + for k, r, v in zip(self.__dict__.keys(), + self.__dict__.values(), + val.__dict__.values()): + if r is None: + self.__dict__[k] = r + elif isinstance(r, list): + self.__dict__[k] = [r_ + v_ for r_, v_ in zip(r, v)] + else: + self.__dict__[k] = r + v + return self + elif isinstance(val, Number): + for k, r in zip(self.__dict__.keys(), self.__dict__.values()): + if r is None: + self.__dict__[k] = r + elif isinstance(r, list): + self.__dict__[k] = [r_ + val for r_ in r] + else: + self.__dict__[k] = r + val + return self + else: + raise TypeError("Only addition of Batch or number is supported.") + + def __add__(self, val: Union['Batch', Number]): + return copy.deepcopy(self).__iadd__(val) + + def __mul__(self, val: Number): + assert isinstance(val, Number), \ + "Only multiplication by a number is supported." + result = Batch() + for k, r in zip(self.__dict__.keys(), self.__dict__.values()): + result.__dict__[k] = r * val + return result + + def __truediv__(self, val: Number): + assert isinstance(val, Number), \ + "Only division by a number is supported." 
+ result = Batch() + for k, r in zip(self.__dict__.keys(), self.__dict__.values()): + result.__dict__[k] = r / val + return result + def __getattr__(self, key: str) -> Union['Batch', Any]: """Return self.key""" if key not in self.__dict__: @@ -167,12 +229,11 @@ class Batch: """Return str(self).""" s = self.__class__.__name__ + '(\n' flag = False - for k in sorted(self.__dict__.keys()): - if self.__dict__.get(k, None) is not None: - rpl = '\n' + ' ' * (6 + len(k)) - obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl) - s += f' {k}: {obj},\n' - flag = True + for k, v in self.__dict__.items(): + rpl = '\n' + ' ' * (6 + len(k)) + obj = pprint.pformat(v).replace('\n', rpl) + s += f' {k}: {obj},\n' + flag = True if flag: s += ')' else: @@ -296,10 +357,33 @@ class Batch: """Return len(self).""" r = [] for v in self.__dict__.values(): - if hasattr(v, '__len__') and (not isinstance( + if isinstance(v, Batch) and v.size == 0: + continue + elif isinstance(v, list) and len(v) == 0: + continue + elif hasattr(v, '__len__') and (not isinstance( v, (np.ndarray, torch.Tensor)) or v.ndim > 0): r.append(len(v)) - return max(r) if len(r) > 0 else 0 + else: + raise TypeError("Object of type 'Batch' has no len()") + if len(r) == 0: + raise TypeError("Object of type 'Batch' has no len()") + return min(r) + + @property + def size(self) -> int: + """Return self.size.""" + if len(self.__dict__) == 0: + return 0 + else: + r = [] + for v in self.__dict__.values(): + if isinstance(v, Batch): + r.append(v.size) + elif hasattr(v, '__len__') and (not isinstance( + v, (np.ndarray, torch.Tensor)) or v.ndim > 0): + r.append(len(v)) + return max(1, min(r) if len(r) > 0 else 0) def split(self, size: Optional[int] = None, shuffle: bool = True) -> Iterator['Batch']: diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index bd77d45..cd05393 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -5,7 +5,7 @@ from typing import Any, Tuple, Union, Optional from tianshou.data.batch import Batch -class ReplayBuffer(object): +class ReplayBuffer: """:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. 
It stores basically 7 types of data, as mentioned in :class:`~tianshou.data.Batch`, based on @@ -96,7 +96,6 @@ class ReplayBuffer(object): def __init__(self, size: int, stack_num: Optional[int] = 0, ignore_obs_next: bool = False, **kwargs) -> None: - super().__init__() self._maxsize = size self._stack = stack_num self._save_s_ = not ignore_obs_next @@ -137,7 +136,7 @@ class ReplayBuffer(object): d[k_] = self.__dict__[k__] else: d[k_] = self.__getattr__(k__) - return Batch(**d) + return Batch(d) def _add_to_buffer(self, name: str, inst: Any) -> None: if inst is None: @@ -177,10 +176,7 @@ class ReplayBuffer(object): """Move the data from the given buffer to self.""" i = begin = buffer._index % len(buffer) while True: - self.add( - buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i], - buffer.obs_next[i] if self._save_s_ else None, - buffer.info[i], buffer.policy[i]) + self.add(**buffer[i]) i = (i + 1) % len(buffer) if i == begin: break @@ -272,7 +268,7 @@ class ReplayBuffer(object): else: stack = [] many_keys = None - for i in range(stack_num): + for _ in range(stack_num): if many_keys is not None: for k_ in many_keys: k__ = '_' + key + '@' + k_ @@ -287,7 +283,7 @@ class ReplayBuffer(object): if many_keys is not None: for k in stack: stack[k] = np.stack(stack[k], axis=1) - stack = Batch(**stack) + stack = Batch(stack) else: stack = np.stack(stack, axis=1) return stack @@ -303,7 +299,7 @@ class ReplayBuffer(object): rew=self.rew[index], done=self.done[index], obs_next=self.get(index, 'obs_next'), - info=self.info[index], + info=self.get(index, 'info'), policy=self.get(index, 'policy'), ) @@ -440,7 +436,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): rew=self.rew[index], done=self.done[index], obs_next=self.get(index, 'obs_next'), - info=self.info[index], + info=self.get(index, 'info'), weight=self.weight[index], policy=self.get(index, 'policy'), )
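
The new behaviours are easiest to see end to end. The following is a minimal usage sketch that mirrors the new unit tests above (test_batch and test_batch_numpy_compatibility); it assumes the patched tianshou.data.Batch from this branch and numpy are installed:

    import numpy as np
    from tianshou.data import Batch

    # An empty Batch() (or an empty list) reserves a key without
    # contributing to len(), as per the "reserve keys" change.
    batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]),
                  b=Batch(),
                  c=np.array([5.0, 6.0]))
    assert len(batch) == 2  # minimum length over the non-empty entries

    # Scalar arithmetic is applied leaf-wise via the new __add__/__mul__.
    shifted = (batch + 1.0) * 2
    assert np.all(shifted.a == (batch.a + 1.0) * 2)

    # Together with __truediv__, this lets numpy reductions work on a Batch:
    # np.mean returns a Batch whose leaves are the per-key means over axis 0.
    batch_mean = np.mean(batch)
    assert isinstance(batch_mean, Batch)
    assert sorted(batch_mean.keys()) == ['a', 'b', 'c']
    assert np.all(batch_mean.a == np.mean(batch.a, axis=0))
    assert batch_mean.c == np.mean(batch.c, axis=0)

Note that, as asserted in test_batch_numpy_compatibility, len(batch_mean) raises TypeError: the reduced c entry is now a scalar, and the reworked __len__ raises instead of silently ignoring values that have no length.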