From ec270759ab3a060946f7c0ab3c482409eda980b4 Mon Sep 17 00:00:00 2001 From: Alexis DUBURCQ Date: Tue, 23 Jun 2020 16:50:59 +0200 Subject: [PATCH] Batch refactoring (#87) * Enable to stack Batch instances. Add Batch cat static method. Rename cat in cat_ since inplace. * Properly handle Batch init using np.array of dict. * WIP * Get rid of metadata. * Update UT. Replace cat by cat_ everywhere. * Do not sort Batch keys anymore for efficiency. Add items method. * Fix cat copy issue. * Add unit test to chack cat and stack methods. * Remove used import. * Fix linter issues. * Fix unit tests. Co-authored-by: Alexis Duburcq --- test/base/test_batch.py | 32 +++++++-- test/base/test_collector.py | 4 +- tianshou/data/batch.py | 128 +++++++++++++++++++++--------------- tianshou/data/buffer.py | 2 +- tianshou/data/collector.py | 2 +- 5 files changed, 107 insertions(+), 61 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 0d30a0a..be88a5a 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -1,4 +1,5 @@ import torch +import copy import pickle import pytest import numpy as np @@ -11,7 +12,7 @@ def test_batch(): assert batch.obs == batch["obs"] batch.obs = [1] assert batch.obs == [1] - batch.cat(batch) + batch.cat_(batch) assert batch.obs == [1, 1] assert batch.np.shape == (6, 4) assert batch[0].obs == batch[1].obs @@ -25,27 +26,48 @@ def test_batch(): with pytest.raises(AttributeError): b.obs print(batch) + batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])} + batch_item = Batch({'a': [batch_dict]})[0] + assert isinstance(batch_item.a.b, np.ndarray) + assert batch_item.a.b == batch_dict['b'] + assert isinstance(batch_item.a.c, float) + assert batch_item.a.c == batch_dict['c'] + assert isinstance(batch_item.a.d, torch.Tensor) + assert batch_item.a.d == batch_dict['d'] def test_batch_over_batch(): batch = Batch(a=[3, 4, 5], b=[4, 5, 6]) - batch2 = Batch(c=[6, 7, 8], b=batch) + batch2 = Batch({'c': [6, 7, 8], 'b': batch}) batch2.b.b[-1] = 0 print(batch2) - assert batch2.values()[-1] == batch2.c + for k, v in batch2.items(): + assert batch2[k] == v assert batch2[-1].b.b == 0 - batch2.cat(Batch(c=[6, 7, 8], b=batch)) + batch2.cat_(Batch(c=[6, 7, 8], b=batch)) assert batch2.c == [6, 7, 8, 6, 7, 8] assert batch2.b.a == [3, 4, 5, 3, 4, 5] assert batch2.b.b == [4, 5, 0, 4, 5, 0] d = {'a': [3, 4, 5], 'b': [4, 5, 6]} batch3 = Batch(c=[6, 7, 8], b=d) - batch3.cat(Batch(c=[6, 7, 8], b=d)) + batch3.cat_(Batch(c=[6, 7, 8], b=d)) assert batch3.c == [6, 7, 8, 6, 7, 8] assert batch3.b.a == [3, 4, 5, 3, 4, 5] assert batch3.b.b == [4, 5, 6, 4, 5, 6] +def test_batch_cat_and_stack(): + b1 = Batch(a=[{'b': np.array([1.0]), 'd': Batch(e=np.array([3.0]))}]) + b2 = Batch(a=[{'b': np.array([4.0]), 'd': Batch(e=np.array([6.0]))}]) + b_cat_out = Batch.cat((b1, b2)) + b_cat_in = copy.deepcopy(b1) + b_cat_in.cat_(b2) + assert np.all(b_cat_in.a.d.e == b_cat_out.a.d.e) + assert b_cat_in.a.d.e.ndim == 2 + b_stack = Batch.stack((b1, b2)) + assert b_stack.a.d.e.ndim == 3 + + def test_batch_over_batch_to_torch(): batch = Batch( a=np.ones((1,), dtype=np.float64), diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 05973f4..16fbdda 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -18,8 +18,8 @@ class MyPolicy(BasePolicy): def forward(self, batch, state=None): if self.dict_state: - return Batch(act=np.ones(batch.obs['index'].shape[0])) - return Batch(act=np.ones(batch.obs.shape[0])) + return Batch(act=np.ones(len(batch.obs['index']))) + return Batch(act=np.ones(len(batch.obs))) def learn(self): pass diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index f6c1df6..9572db4 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -1,4 +1,5 @@ import torch +import copy import pprint import warnings import numpy as np @@ -73,29 +74,37 @@ class Batch: [11 22] [6 6] """ - def __new__(cls, **kwargs) -> None: - self = super().__new__(cls) - self._meta = {} - return self - - def __init__(self, **kwargs) -> None: - super().__init__() - for k, v in kwargs.items(): - if isinstance(v, (list, np.ndarray)) \ - and len(v) > 0 and isinstance(v[0], dict) and k != 'info': - self._meta[k] = list(v[0].keys()) - for k_ in v[0].keys(): - k__ = '_' + k + '@' + k_ - self.__dict__[k__] = np.array([ - v[i][k_] for i in range(len(v)) - ]) - elif isinstance(v, dict): - self._meta[k] = list(v.keys()) - for k_, v_ in v.items(): - k__ = '_' + k + '@' + k_ - self.__dict__[k__] = v_ - else: - self.__dict__[k] = v + def __init__(self, + batch_dict: Optional[ + Union[dict, List[dict], np.ndarray]] = None, + **kwargs) -> None: + if isinstance(batch_dict, (list, np.ndarray)) \ + and len(batch_dict) > 0 and isinstance(batch_dict[0], dict): + for k, v in zip(batch_dict[0].keys(), + zip(*[e.values() for e in batch_dict])): + if isinstance(v, (list, np.ndarray)) \ + and len(v) > 0 and isinstance(v[0], dict): + self.__dict__[k] = Batch.stack([Batch(v_) for v_ in v]) + elif isinstance(v[0], np.ndarray): + self.__dict__[k] = np.stack(v, axis=0) + elif isinstance(v[0], torch.Tensor): + self.__dict__[k] = torch.stack(v, dim=0) + elif isinstance(v[0], Batch): + self.__dict__[k] = Batch.stack(v) + elif isinstance(v[0], dict): + self.__dict__[k] = Batch(v) + else: + self.__dict__[k] = list(v) + elif isinstance(batch_dict, dict): + for k, v in batch_dict.items(): + if isinstance(v, dict) \ + or (isinstance(v, (list, np.ndarray)) + and len(v) > 0 and isinstance(v[0], dict)): + self.__dict__[k] = Batch(v) + else: + self.__dict__[k] = v + if len(kwargs) > 0: + self.__init__(kwargs) def __getstate__(self): """Pickling interface. Only the actual data are serialized @@ -122,33 +131,25 @@ class Batch: return self.__getattr__(index) b = Batch() for k, v in self.__dict__.items(): - if k != '_meta' and hasattr(v, '__len__'): + if hasattr(v, '__len__'): try: b.__dict__.update(**{k: v[index]}) except IndexError: continue - b._meta = self._meta return b def __getattr__(self, key: str) -> Union['Batch', Any]: """Return self.key""" - if key not in self._meta.keys(): - if key not in self.__dict__: - raise AttributeError(key) - return self.__dict__[key] - d = {} - for k_ in self._meta[key]: - k__ = '_' + key + '@' + k_ - d[k_] = self.__dict__[k__] - return Batch(**d) + if key not in self.__dict__: + raise AttributeError(key) + return self.__dict__[key] def __repr__(self) -> str: """Return str(self).""" s = self.__class__.__name__ + '(\n' flag = False - for k in sorted(list(self.__dict__) + list(self._meta)): - if k[0] != '_' and (self.__dict__.get(k, None) is not None or - k in self._meta): + for k in sorted(self.__dict__.keys()): + if self.__dict__.get(k, None) is not None: rpl = '\n' + ' ' * (6 + len(k)) obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl) s += f' {k}: {obj},\n' @@ -161,16 +162,19 @@ class Batch: def keys(self) -> List[str]: """Return self.keys().""" - return sorted(list(self._meta.keys()) + - [k for k in self.__dict__.keys() if k[0] != '_']) + return self.__dict__.keys() def values(self) -> List[Any]: """Return self.values().""" - return [self[k] for k in self.keys()] + return self.__dict__.values() + + def items(self) -> Any: + """Return self.items().""" + return self.__dict__.items() def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]: """Return self[k] if k in self else d. d defaults to None.""" - if k in self.__dict__ or k in self._meta: + if k in self.__dict__: return self.__getattr__(k) return d @@ -220,35 +224,55 @@ class Batch: def append(self, batch: 'Batch') -> None: warnings.warn('Method append will be removed soon, please use ' ':meth:`~tianshou.data.Batch.cat`') - return self.cat(batch) + return self.cat_(batch) - def cat(self, batch: 'Batch') -> None: + def cat_(self, batch: 'Batch') -> None: """Concatenate a :class:`~tianshou.data.Batch` object to current batch. """ assert isinstance(batch, Batch), \ - 'Only Batch is allowed to be concatenated!' + 'Only Batch is allowed to be concatenated in-place!' for k, v in batch.__dict__.items(): - if k == '_meta': - self._meta.update(batch._meta) - continue if v is None: continue if not hasattr(self, k) or self.__dict__[k] is None: - self.__dict__[k] = v + self.__dict__[k] = copy.deepcopy(v) elif isinstance(v, np.ndarray): self.__dict__[k] = np.concatenate([self.__dict__[k], v]) elif isinstance(v, torch.Tensor): self.__dict__[k] = torch.cat([self.__dict__[k], v]) elif isinstance(v, list): - self.__dict__[k] += v + self.__dict__[k] += copy.deepcopy(v) elif isinstance(v, Batch): - self.__dict__[k].cat(v) + self.__dict__[k].cat_(v) else: - s = f'No support for method "cat" with type \ - {type(v)} in class Batch.' + s = 'No support for method "cat" with type '\ + f'{type(v)} in class Batch.' raise TypeError(s) + @staticmethod + def cat(batches: List['Batch']) -> None: + """Concatenate a :class:`~tianshou.data.Batch` object into a + single new batch. + """ + assert isinstance(batches, (tuple, list)), \ + 'Only list of Batch instances is allowed to be '\ + 'concatenated out-of-place!' + batch = Batch() + for batch_ in batches: + batch.cat_(batch_) + return batch + + @staticmethod + def stack(batches: List['Batch']): + """Stack a :class:`~tianshou.data.Batch` object into a + single new batch. + """ + assert isinstance(batches, (tuple, list)), \ + 'Only list of Batch instances is allowed to be '\ + 'stacked out-of-place!' + return Batch(np.array([batch.__dict__ for batch in batches])) + def __len__(self) -> int: """Return len(self).""" r = [len(v) for k, v in self.__dict__.items() if hasattr(v, '__len__')] diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 92756d5..bd77d45 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -413,7 +413,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): impt_weight=1 / np.power( self._size * (batch.weight / self._weight_sum), self._beta)) - batch.cat(impt_weight) + batch.cat_(impt_weight) self._check_weight_sum() return batch, indice diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 227e9e1..cc6c7f5 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -416,7 +416,7 @@ class Collector(object): if batch_size and cur_batch or batch_size <= 0: batch, indice = b.sample(cur_batch) batch = self.process_fn(batch, b, indice) - batch_data.cat(batch) + batch_data.cat_(batch) else: batch_data, indice = self.buffer.sample(batch_size) batch_data = self.process_fn(batch_data, self.buffer, indice)