Batch refactoring (#87)

* Enable stacking of Batch instances. Add Batch.cat static method. Rename the in-place cat to cat_.

* Properly handle Batch init using np.array of dict.

* WIP

* Get rid of metadata.

* Update UT. Replace cat by cat_ everywhere.

* Do not sort Batch keys anymore for efficiency. Add items method.

* Fix cat copy issue.

* Add unit test to check cat and stack methods.

* Remove unused import.

* Fix linter issues.

* Fix unit tests.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
Alexis DUBURCQ 2020-06-23 16:50:59 +02:00 committed by GitHub
parent 13828f6309
commit ec270759ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 107 additions and 61 deletions

View File

@ -1,4 +1,5 @@
import torch
import copy
import pickle
import pytest
import numpy as np
@ -11,7 +12,7 @@ def test_batch():
assert batch.obs == batch["obs"]
batch.obs = [1]
assert batch.obs == [1]
batch.cat(batch)
batch.cat_(batch)
assert batch.obs == [1, 1]
assert batch.np.shape == (6, 4)
assert batch[0].obs == batch[1].obs
@ -25,27 +26,48 @@ def test_batch():
with pytest.raises(AttributeError):
b.obs
print(batch)
batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])}
batch_item = Batch({'a': [batch_dict]})[0]
assert isinstance(batch_item.a.b, np.ndarray)
assert batch_item.a.b == batch_dict['b']
assert isinstance(batch_item.a.c, float)
assert batch_item.a.c == batch_dict['c']
assert isinstance(batch_item.a.d, torch.Tensor)
assert batch_item.a.d == batch_dict['d']
def test_batch_over_batch():
batch = Batch(a=[3, 4, 5], b=[4, 5, 6])
batch2 = Batch(c=[6, 7, 8], b=batch)
batch2 = Batch({'c': [6, 7, 8], 'b': batch})
batch2.b.b[-1] = 0
print(batch2)
assert batch2.values()[-1] == batch2.c
for k, v in batch2.items():
assert batch2[k] == v
assert batch2[-1].b.b == 0
batch2.cat(Batch(c=[6, 7, 8], b=batch))
batch2.cat_(Batch(c=[6, 7, 8], b=batch))
assert batch2.c == [6, 7, 8, 6, 7, 8]
assert batch2.b.a == [3, 4, 5, 3, 4, 5]
assert batch2.b.b == [4, 5, 0, 4, 5, 0]
d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
batch3 = Batch(c=[6, 7, 8], b=d)
batch3.cat(Batch(c=[6, 7, 8], b=d))
batch3.cat_(Batch(c=[6, 7, 8], b=d))
assert batch3.c == [6, 7, 8, 6, 7, 8]
assert batch3.b.a == [3, 4, 5, 3, 4, 5]
assert batch3.b.b == [4, 5, 6, 4, 5, 6]
def test_batch_cat_and_stack():
    """Check that in-place and out-of-place concatenation agree, and that
    stacking yields one more dimension than concatenation.
    """
    first = Batch(a=[{'b': np.array([1.0]), 'd': Batch(e=np.array([3.0]))}])
    second = Batch(a=[{'b': np.array([4.0]), 'd': Batch(e=np.array([6.0]))}])

    # Out-of-place concatenation through the static method.
    merged_new = Batch.cat((first, second))

    # In-place concatenation on a deep copy, so `first` itself is untouched.
    merged_inplace = copy.deepcopy(first)
    merged_inplace.cat_(second)

    inplace_leaf = merged_inplace.a.d.e
    assert np.all(inplace_leaf == merged_new.a.d.e)
    assert inplace_leaf.ndim == 2

    # Stacking the same two batches adds a dimension to the nested leaf.
    stacked = Batch.stack((first, second))
    assert stacked.a.d.e.ndim == 3
def test_batch_over_batch_to_torch():
batch = Batch(
a=np.ones((1,), dtype=np.float64),

View File

@ -18,8 +18,8 @@ class MyPolicy(BasePolicy):
def forward(self, batch, state=None):
if self.dict_state:
return Batch(act=np.ones(batch.obs['index'].shape[0]))
return Batch(act=np.ones(batch.obs.shape[0]))
return Batch(act=np.ones(len(batch.obs['index'])))
return Batch(act=np.ones(len(batch.obs)))
def learn(self):
pass

View File

@ -1,4 +1,5 @@
import torch
import copy
import pprint
import warnings
import numpy as np
@ -73,29 +74,37 @@ class Batch:
[11 22] [6 6]
"""
def __new__(cls, **kwargs) -> None:
self = super().__new__(cls)
self._meta = {}
return self
def __init__(self, **kwargs) -> None:
super().__init__()
for k, v in kwargs.items():
if isinstance(v, (list, np.ndarray)) \
and len(v) > 0 and isinstance(v[0], dict) and k != 'info':
self._meta[k] = list(v[0].keys())
for k_ in v[0].keys():
k__ = '_' + k + '@' + k_
self.__dict__[k__] = np.array([
v[i][k_] for i in range(len(v))
])
elif isinstance(v, dict):
self._meta[k] = list(v.keys())
for k_, v_ in v.items():
k__ = '_' + k + '@' + k_
self.__dict__[k__] = v_
else:
self.__dict__[k] = v
def __init__(self,
             batch_dict: Optional[
                 Union[dict, List[dict], np.ndarray]] = None,
             **kwargs) -> None:
    """Build a Batch from a dict, a sequence of dicts and/or keywords.

    :param batch_dict: either a single dict whose items become the
        attributes of this Batch, or a non-empty list / np.ndarray of
        dicts that is merged key by key (arrays and tensors are stacked
        along a new first axis, Batch and dict values are stacked
        recursively, anything else is collected into a list). Any other
        value, including ``None``, adds nothing.
    :param kwargs: extra key/value pairs, processed exactly like a dict
        passed as ``batch_dict``.
    """
    if isinstance(batch_dict, (list, np.ndarray)) \
            and len(batch_dict) > 0 and isinstance(batch_dict[0], dict):
        # Transpose the sequence of dicts into one column of values per
        # key. Keys of the first dict are paired positionally with the
        # values of every dict, so all dicts are expected to share the
        # same keys in the same order.
        columns = zip(*[e.values() for e in batch_dict])
        for k, v in zip(batch_dict[0].keys(), columns):
            # `v` is a tuple here (it comes out of zip), never a list.
            first = v[0]
            if isinstance(first, np.ndarray):
                self.__dict__[k] = np.stack(v, axis=0)
            elif isinstance(first, torch.Tensor):
                self.__dict__[k] = torch.stack(v, dim=0)
            elif isinstance(first, Batch):
                self.__dict__[k] = Batch.stack(v)
            elif isinstance(first, dict):
                # Convert to a list so the recursive constructor takes
                # the sequence-of-dicts path; a bare tuple matches no
                # constructor branch and would silently produce an empty
                # Batch, dropping the nested data.
                self.__dict__[k] = Batch(list(v))
            else:
                self.__dict__[k] = list(v)
    elif isinstance(batch_dict, dict):
        for k, v in batch_dict.items():
            # Nested dicts (or sequences of dicts) become nested Batch
            # objects; every other value is stored as given.
            if isinstance(v, dict) \
                    or (isinstance(v, (list, np.ndarray))
                        and len(v) > 0 and isinstance(v[0], dict)):
                self.__dict__[k] = Batch(v)
            else:
                self.__dict__[k] = v
    if len(kwargs) > 0:
        # Keyword arguments go through the same dict-handling path.
        self.__init__(kwargs)
def __getstate__(self):
"""Pickling interface. Only the actual data are serialized
@ -122,33 +131,25 @@ class Batch:
return self.__getattr__(index)
b = Batch()
for k, v in self.__dict__.items():
if k != '_meta' and hasattr(v, '__len__'):
if hasattr(v, '__len__'):
try:
b.__dict__.update(**{k: v[index]})
except IndexError:
continue
b._meta = self._meta
return b
def __getattr__(self, key: str) -> Union['Batch', Any]:
"""Return self.key"""
if key not in self._meta.keys():
if key not in self.__dict__:
raise AttributeError(key)
return self.__dict__[key]
d = {}
for k_ in self._meta[key]:
k__ = '_' + key + '@' + k_
d[k_] = self.__dict__[k__]
return Batch(**d)
if key not in self.__dict__:
raise AttributeError(key)
return self.__dict__[key]
def __repr__(self) -> str:
"""Return str(self)."""
s = self.__class__.__name__ + '(\n'
flag = False
for k in sorted(list(self.__dict__) + list(self._meta)):
if k[0] != '_' and (self.__dict__.get(k, None) is not None or
k in self._meta):
for k in sorted(self.__dict__.keys()):
if self.__dict__.get(k, None) is not None:
rpl = '\n' + ' ' * (6 + len(k))
obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
s += f' {k}: {obj},\n'
@ -161,16 +162,19 @@ class Batch:
def keys(self) -> List[str]:
"""Return self.keys()."""
return sorted(list(self._meta.keys()) +
[k for k in self.__dict__.keys() if k[0] != '_'])
return self.__dict__.keys()
def values(self) -> List[Any]:
"""Return self.values()."""
return [self[k] for k in self.keys()]
return self.__dict__.values()
def items(self) -> Any:
    """Return the (key, value) view of the stored attributes, as
    ``dict.items()`` would on the underlying instance dictionary.
    """
    return vars(self).items()
def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]:
"""Return self[k] if k in self else d. d defaults to None."""
if k in self.__dict__ or k in self._meta:
if k in self.__dict__:
return self.__getattr__(k)
return d
@ -220,35 +224,55 @@ class Batch:
def append(self, batch: 'Batch') -> None:
warnings.warn('Method append will be removed soon, please use '
':meth:`~tianshou.data.Batch.cat`')
return self.cat(batch)
return self.cat_(batch)
def cat(self, batch: 'Batch') -> None:
def cat_(self, batch: 'Batch') -> None:
"""Concatenate a :class:`~tianshou.data.Batch` object to current
batch.
"""
assert isinstance(batch, Batch), \
'Only Batch is allowed to be concatenated!'
'Only Batch is allowed to be concatenated in-place!'
for k, v in batch.__dict__.items():
if k == '_meta':
self._meta.update(batch._meta)
continue
if v is None:
continue
if not hasattr(self, k) or self.__dict__[k] is None:
self.__dict__[k] = v
self.__dict__[k] = copy.deepcopy(v)
elif isinstance(v, np.ndarray):
self.__dict__[k] = np.concatenate([self.__dict__[k], v])
elif isinstance(v, torch.Tensor):
self.__dict__[k] = torch.cat([self.__dict__[k], v])
elif isinstance(v, list):
self.__dict__[k] += v
self.__dict__[k] += copy.deepcopy(v)
elif isinstance(v, Batch):
self.__dict__[k].cat(v)
self.__dict__[k].cat_(v)
else:
s = f'No support for method "cat" with type \
{type(v)} in class Batch.'
s = 'No support for method "cat" with type '\
f'{type(v)} in class Batch.'
raise TypeError(s)
@staticmethod
def cat(batches: List['Batch']) -> 'Batch':
    """Concatenate a list (or tuple) of :class:`~tianshou.data.Batch`
    objects into a single new batch.

    The result starts as an empty Batch and every element of ``batches``
    is merged into it, in order, with :meth:`cat_`.

    :param batches: list or tuple of Batch instances to concatenate.
    :return: the newly created Batch holding the concatenated data.
    """
    assert isinstance(batches, (tuple, list)), \
        'Only list of Batch instances is allowed to be '\
        'concatenated out-of-place!'
    batch = Batch()
    for batch_ in batches:
        batch.cat_(batch_)
    return batch
@staticmethod
def stack(batches: List['Batch']):
    """Stack a list (or tuple) of :class:`~tianshou.data.Batch` objects
    into a single new batch.

    The attribute dictionaries of the inputs are gathered into a numpy
    array and handed to the Batch constructor, which merges them key by
    key.
    """
    assert isinstance(batches, (tuple, list)), \
        'Only list of Batch instances is allowed to be '\
        'stacked out-of-place!'
    attribute_dicts = np.array([item.__dict__ for item in batches])
    return Batch(attribute_dicts)
def __len__(self) -> int:
"""Return len(self)."""
r = [len(v) for k, v in self.__dict__.items() if hasattr(v, '__len__')]

View File

@ -413,7 +413,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
impt_weight=1 / np.power(
self._size * (batch.weight / self._weight_sum),
self._beta))
batch.cat(impt_weight)
batch.cat_(impt_weight)
self._check_weight_sum()
return batch, indice

View File

@ -416,7 +416,7 @@ class Collector(object):
if batch_size and cur_batch or batch_size <= 0:
batch, indice = b.sample(cur_batch)
batch = self.process_fn(batch, b, indice)
batch_data.cat(batch)
batch_data.cat_(batch)
else:
batch_data, indice = self.buffer.sample(batch_size)
batch_data = self.process_fn(batch_data, self.buffer, indice)