Batch refactoring (#87)

* Enable stacking of Batch instances. Add a static Batch.cat method. Rename cat to cat_ since it operates in place (usage sketch below).

* Properly handle Batch init from an np.array of dicts.

* WIP

* Get rid of the internal _meta bookkeeping.

* Update unit tests. Replace cat with cat_ everywhere.

* Do not sort Batch keys anymore, for efficiency. Add an items method.

* Fix cat copy issue.

* Add unit test to check the cat and stack methods.

* Remove unused import.

* Fix linter issues.

* Fix unit tests.
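
Usage sketch of the reworked API (illustrative only; values and variable names are made up, not taken from the diff):

    import numpy as np
    from tianshou.data import Batch

    b1 = Batch(a=np.array([1.0, 2.0]), b=[{'c': np.zeros(3)}, {'c': np.ones(3)}])
    b2 = Batch(a=np.array([3.0, 4.0]), b=[{'c': np.ones(3)}, {'c': np.zeros(3)}])

    # Out-of-place concatenation: builds a new Batch from a list/tuple of batches.
    b_cat = Batch.cat((b1, b2))
    print(b_cat.a)            # [1. 2. 3. 4.]
    print(b_cat.b.c.shape)    # (4, 3)

    # Stacking adds a leading dimension (per-key shapes should match).
    b_stack = Batch.stack((b1, b2))
    print(b_stack.b.c.shape)  # (2, 2, 3)

    # In-place concatenation: the old cat, renamed cat_ because it mutates self.
    b1.cat_(b2)
    print(len(b1.a))          # 4

    # Keys are no longer sorted; iterate in insertion order via items().
    for k, v in b1.items():
        print(k, type(v))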

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
Authored by Alexis DUBURCQ on 2020-06-23 16:50:59 +02:00; committed by GitHub
parent 13828f6309
commit ec270759ab
5 changed files with 107 additions and 61 deletions


@@ -1,4 +1,5 @@
 import torch
+import copy
 import pickle
 import pytest
 import numpy as np
@@ -11,7 +12,7 @@ def test_batch():
     assert batch.obs == batch["obs"]
     batch.obs = [1]
     assert batch.obs == [1]
-    batch.cat(batch)
+    batch.cat_(batch)
     assert batch.obs == [1, 1]
     assert batch.np.shape == (6, 4)
     assert batch[0].obs == batch[1].obs
@@ -25,27 +26,48 @@ def test_batch():
     with pytest.raises(AttributeError):
         b.obs
     print(batch)
+    batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])}
+    batch_item = Batch({'a': [batch_dict]})[0]
+    assert isinstance(batch_item.a.b, np.ndarray)
+    assert batch_item.a.b == batch_dict['b']
+    assert isinstance(batch_item.a.c, float)
+    assert batch_item.a.c == batch_dict['c']
+    assert isinstance(batch_item.a.d, torch.Tensor)
+    assert batch_item.a.d == batch_dict['d']


 def test_batch_over_batch():
     batch = Batch(a=[3, 4, 5], b=[4, 5, 6])
-    batch2 = Batch(c=[6, 7, 8], b=batch)
+    batch2 = Batch({'c': [6, 7, 8], 'b': batch})
     batch2.b.b[-1] = 0
     print(batch2)
-    assert batch2.values()[-1] == batch2.c
+    for k, v in batch2.items():
+        assert batch2[k] == v
     assert batch2[-1].b.b == 0
-    batch2.cat(Batch(c=[6, 7, 8], b=batch))
+    batch2.cat_(Batch(c=[6, 7, 8], b=batch))
     assert batch2.c == [6, 7, 8, 6, 7, 8]
     assert batch2.b.a == [3, 4, 5, 3, 4, 5]
     assert batch2.b.b == [4, 5, 0, 4, 5, 0]
     d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
     batch3 = Batch(c=[6, 7, 8], b=d)
-    batch3.cat(Batch(c=[6, 7, 8], b=d))
+    batch3.cat_(Batch(c=[6, 7, 8], b=d))
     assert batch3.c == [6, 7, 8, 6, 7, 8]
     assert batch3.b.a == [3, 4, 5, 3, 4, 5]
     assert batch3.b.b == [4, 5, 6, 4, 5, 6]


+def test_batch_cat_and_stack():
+    b1 = Batch(a=[{'b': np.array([1.0]), 'd': Batch(e=np.array([3.0]))}])
+    b2 = Batch(a=[{'b': np.array([4.0]), 'd': Batch(e=np.array([6.0]))}])
+    b_cat_out = Batch.cat((b1, b2))
+    b_cat_in = copy.deepcopy(b1)
+    b_cat_in.cat_(b2)
+    assert np.all(b_cat_in.a.d.e == b_cat_out.a.d.e)
+    assert b_cat_in.a.d.e.ndim == 2
+    b_stack = Batch.stack((b1, b2))
+    assert b_stack.a.d.e.ndim == 3
+
+
 def test_batch_over_batch_to_torch():
     batch = Batch(
         a=np.ones((1,), dtype=np.float64),


@@ -18,8 +18,8 @@ class MyPolicy(BasePolicy):

     def forward(self, batch, state=None):
         if self.dict_state:
-            return Batch(act=np.ones(batch.obs['index'].shape[0]))
-        return Batch(act=np.ones(batch.obs.shape[0]))
+            return Batch(act=np.ones(len(batch.obs['index'])))
+        return Batch(act=np.ones(len(batch.obs)))

     def learn(self):
         pass
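
Note: switching from .shape[0] to len() keeps MyPolicy agnostic to whether batch.obs arrives as an ndarray or a plain Python sequence. A quick illustration with made-up values:

    import numpy as np

    obs_array = np.zeros((5, 4))
    obs_list = [[0.0] * 4] * 5
    assert len(obs_array) == obs_array.shape[0] == 5
    assert len(obs_list) == 5   # obs_list has no .shape attribute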


@@ -1,4 +1,5 @@
 import torch
+import copy
 import pprint
 import warnings
 import numpy as np
@@ -73,29 +74,37 @@ class Batch:
         [11 22] [6 6]
     """

-    def __new__(cls, **kwargs) -> None:
-        self = super().__new__(cls)
-        self._meta = {}
-        return self
-
-    def __init__(self, **kwargs) -> None:
-        super().__init__()
-        for k, v in kwargs.items():
-            if isinstance(v, (list, np.ndarray)) \
-                    and len(v) > 0 and isinstance(v[0], dict) and k != 'info':
-                self._meta[k] = list(v[0].keys())
-                for k_ in v[0].keys():
-                    k__ = '_' + k + '@' + k_
-                    self.__dict__[k__] = np.array([
-                        v[i][k_] for i in range(len(v))
-                    ])
-            elif isinstance(v, dict):
-                self._meta[k] = list(v.keys())
-                for k_, v_ in v.items():
-                    k__ = '_' + k + '@' + k_
-                    self.__dict__[k__] = v_
-            else:
-                self.__dict__[k] = v
+    def __init__(self,
+                 batch_dict: Optional[
+                     Union[dict, List[dict], np.ndarray]] = None,
+                 **kwargs) -> None:
+        if isinstance(batch_dict, (list, np.ndarray)) \
+                and len(batch_dict) > 0 and isinstance(batch_dict[0], dict):
+            for k, v in zip(batch_dict[0].keys(),
+                            zip(*[e.values() for e in batch_dict])):
+                if isinstance(v, (list, np.ndarray)) \
+                        and len(v) > 0 and isinstance(v[0], dict):
+                    self.__dict__[k] = Batch.stack([Batch(v_) for v_ in v])
+                elif isinstance(v[0], np.ndarray):
+                    self.__dict__[k] = np.stack(v, axis=0)
+                elif isinstance(v[0], torch.Tensor):
+                    self.__dict__[k] = torch.stack(v, dim=0)
+                elif isinstance(v[0], Batch):
+                    self.__dict__[k] = Batch.stack(v)
+                elif isinstance(v[0], dict):
+                    self.__dict__[k] = Batch(v)
+                else:
+                    self.__dict__[k] = list(v)
+        elif isinstance(batch_dict, dict):
+            for k, v in batch_dict.items():
+                if isinstance(v, dict) \
+                        or (isinstance(v, (list, np.ndarray))
+                            and len(v) > 0 and isinstance(v[0], dict)):
+                    self.__dict__[k] = Batch(v)
+                else:
+                    self.__dict__[k] = v
+        if len(kwargs) > 0:
+            self.__init__(kwargs)

     def __getstate__(self):
         """Pickling interface. Only the actual data are serialized
@@ -122,33 +131,25 @@ class Batch:
             return self.__getattr__(index)
         b = Batch()
         for k, v in self.__dict__.items():
-            if k != '_meta' and hasattr(v, '__len__'):
+            if hasattr(v, '__len__'):
                 try:
                     b.__dict__.update(**{k: v[index]})
                 except IndexError:
                     continue
-        b._meta = self._meta
         return b

     def __getattr__(self, key: str) -> Union['Batch', Any]:
         """Return self.key"""
-        if key not in self._meta.keys():
-            if key not in self.__dict__:
-                raise AttributeError(key)
-            return self.__dict__[key]
-        d = {}
-        for k_ in self._meta[key]:
-            k__ = '_' + key + '@' + k_
-            d[k_] = self.__dict__[k__]
-        return Batch(**d)
+        if key not in self.__dict__:
+            raise AttributeError(key)
+        return self.__dict__[key]

     def __repr__(self) -> str:
         """Return str(self)."""
         s = self.__class__.__name__ + '(\n'
         flag = False
-        for k in sorted(list(self.__dict__) + list(self._meta)):
-            if k[0] != '_' and (self.__dict__.get(k, None) is not None or
-                                k in self._meta):
+        for k in sorted(self.__dict__.keys()):
+            if self.__dict__.get(k, None) is not None:
                 rpl = '\n' + ' ' * (6 + len(k))
                 obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
                 s += f' {k}: {obj},\n'
@@ -161,16 +162,19 @@ class Batch:

     def keys(self) -> List[str]:
         """Return self.keys()."""
-        return sorted(list(self._meta.keys()) +
-                      [k for k in self.__dict__.keys() if k[0] != '_'])
+        return self.__dict__.keys()

     def values(self) -> List[Any]:
         """Return self.values()."""
-        return [self[k] for k in self.keys()]
+        return self.__dict__.values()
+
+    def items(self) -> Any:
+        """Return self.items()."""
+        return self.__dict__.items()

     def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]:
         """Return self[k] if k in self else d. d defaults to None."""
-        if k in self.__dict__ or k in self._meta:
+        if k in self.__dict__:
             return self.__getattr__(k)
         return d
@@ -220,35 +224,55 @@ class Batch:
     def append(self, batch: 'Batch') -> None:
         warnings.warn('Method append will be removed soon, please use '
                       ':meth:`~tianshou.data.Batch.cat`')
-        return self.cat(batch)
+        return self.cat_(batch)

-    def cat(self, batch: 'Batch') -> None:
+    def cat_(self, batch: 'Batch') -> None:
         """Concatenate a :class:`~tianshou.data.Batch` object to current
         batch.
         """
         assert isinstance(batch, Batch), \
-            'Only Batch is allowed to be concatenated!'
+            'Only Batch is allowed to be concatenated in-place!'
         for k, v in batch.__dict__.items():
-            if k == '_meta':
-                self._meta.update(batch._meta)
-                continue
             if v is None:
                 continue
             if not hasattr(self, k) or self.__dict__[k] is None:
-                self.__dict__[k] = v
+                self.__dict__[k] = copy.deepcopy(v)
             elif isinstance(v, np.ndarray):
                 self.__dict__[k] = np.concatenate([self.__dict__[k], v])
             elif isinstance(v, torch.Tensor):
                 self.__dict__[k] = torch.cat([self.__dict__[k], v])
             elif isinstance(v, list):
-                self.__dict__[k] += v
+                self.__dict__[k] += copy.deepcopy(v)
             elif isinstance(v, Batch):
-                self.__dict__[k].cat(v)
+                self.__dict__[k].cat_(v)
             else:
-                s = f'No support for method "cat" with type \
-                    {type(v)} in class Batch.'
+                s = 'No support for method "cat" with type '\
+                    f'{type(v)} in class Batch.'
                 raise TypeError(s)

+    @staticmethod
+    def cat(batches: List['Batch']) -> None:
+        """Concatenate a :class:`~tianshou.data.Batch` object into a
+        single new batch.
+        """
+        assert isinstance(batches, (tuple, list)), \
+            'Only list of Batch instances is allowed to be '\
+            'concatenated out-of-place!'
+        batch = Batch()
+        for batch_ in batches:
+            batch.cat_(batch_)
+        return batch
+
+    @staticmethod
+    def stack(batches: List['Batch']):
+        """Stack a :class:`~tianshou.data.Batch` object into a
+        single new batch.
+        """
+        assert isinstance(batches, (tuple, list)), \
+            'Only list of Batch instances is allowed to be '\
+            'stacked out-of-place!'
+        return Batch(np.array([batch.__dict__ for batch in batches]))
+
     def __len__(self) -> int:
         """Return len(self)."""
         r = [len(v) for k, v in self.__dict__.items() if hasattr(v, '__len__')]
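
For reference, a sketch of how the new constructor treats a list (or np.array) of dicts, in the spirit of the test added above (values are made up):

    import numpy as np
    import torch
    from tianshou.data import Batch

    # Entries are stacked key-wise: ndarrays via np.stack, tensors via
    # torch.stack, nested dicts/Batches recursively; scalars fall back to a list.
    rows = [
        {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])},
        {'b': np.array([4.0]), 'c': 5.0, 'd': torch.Tensor([6.0])},
    ]
    batch = Batch(a=rows)
    print(batch.a.b.shape)  # (2, 1)
    print(batch.a.d.shape)  # torch.Size([2, 1])
    print(batch.a.c)        # [2.0, 5.0]
    print(batch[0].a.b)     # [1.]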


@@ -413,7 +413,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
            impt_weight=1 / np.power(
                self._size * (batch.weight / self._weight_sum),
                self._beta))
-        batch.cat(impt_weight)
+        batch.cat_(impt_weight)
         self._check_weight_sum()
         return batch, indice


@@ -416,7 +416,7 @@ class Collector(object):
                 if batch_size and cur_batch or batch_size <= 0:
                     batch, indice = b.sample(cur_batch)
                     batch = self.process_fn(batch, b, indice)
-                    batch_data.cat(batch)
+                    batch_data.cat_(batch)
         else:
             batch_data, indice = self.buffer.sample(batch_size)
             batch_data = self.process_fn(batch_data, self.buffer, indice)