Batch refactoring (#87)
* Enable stacking of Batch instances. Add a static Batch.cat method. Rename cat to cat_ since it is in-place.
* Properly handle Batch init using an np.array of dict.
* WIP
* Get rid of metadata.
* Update unit tests. Replace cat with cat_ everywhere.
* Do not sort Batch keys anymore, for efficiency. Add an items method.
* Fix cat copy issue.
* Add a unit test to check the cat and stack methods.
* Remove unused import.
* Fix linter issues.
* Fix unit tests.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
parent
13828f6309
commit
ec270759ab
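As a quick orientation before the diff: the in-place concatenation method is renamed from cat to cat_ (the trailing underscore marks in-place mutation, matching the PyTorch convention). A minimal usage sketch, assuming a tianshou checkout that includes this commit; it simply mirrors the updated test_batch below.

import numpy as np
from tianshou.data import Batch

# In-place concatenation: cat() was renamed to cat_() because it
# mutates the receiver instead of returning a new Batch.
batch = Batch(obs=[1], np=np.zeros([3, 4]))
batch.cat_(batch)                 # extends `batch` itself
assert batch.obs == [1, 1]        # list keys are extended
assert batch.np.shape == (6, 4)   # ndarray keys are np.concatenate-d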
@@ -1,4 +1,5 @@
 import torch
+import copy
 import pickle
 import pytest
 import numpy as np
@@ -11,7 +12,7 @@ def test_batch():
     assert batch.obs == batch["obs"]
     batch.obs = [1]
     assert batch.obs == [1]
-    batch.cat(batch)
+    batch.cat_(batch)
     assert batch.obs == [1, 1]
     assert batch.np.shape == (6, 4)
     assert batch[0].obs == batch[1].obs
@@ -25,27 +26,48 @@ def test_batch():
     with pytest.raises(AttributeError):
         b.obs
     print(batch)
+    batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])}
+    batch_item = Batch({'a': [batch_dict]})[0]
+    assert isinstance(batch_item.a.b, np.ndarray)
+    assert batch_item.a.b == batch_dict['b']
+    assert isinstance(batch_item.a.c, float)
+    assert batch_item.a.c == batch_dict['c']
+    assert isinstance(batch_item.a.d, torch.Tensor)
+    assert batch_item.a.d == batch_dict['d']


 def test_batch_over_batch():
     batch = Batch(a=[3, 4, 5], b=[4, 5, 6])
-    batch2 = Batch(c=[6, 7, 8], b=batch)
+    batch2 = Batch({'c': [6, 7, 8], 'b': batch})
     batch2.b.b[-1] = 0
     print(batch2)
     assert batch2.values()[-1] == batch2.c
+    for k, v in batch2.items():
+        assert batch2[k] == v
     assert batch2[-1].b.b == 0
-    batch2.cat(Batch(c=[6, 7, 8], b=batch))
+    batch2.cat_(Batch(c=[6, 7, 8], b=batch))
     assert batch2.c == [6, 7, 8, 6, 7, 8]
     assert batch2.b.a == [3, 4, 5, 3, 4, 5]
     assert batch2.b.b == [4, 5, 0, 4, 5, 0]
     d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
     batch3 = Batch(c=[6, 7, 8], b=d)
-    batch3.cat(Batch(c=[6, 7, 8], b=d))
+    batch3.cat_(Batch(c=[6, 7, 8], b=d))
     assert batch3.c == [6, 7, 8, 6, 7, 8]
     assert batch3.b.a == [3, 4, 5, 3, 4, 5]
     assert batch3.b.b == [4, 5, 6, 4, 5, 6]


+def test_batch_cat_and_stack():
+    b1 = Batch(a=[{'b': np.array([1.0]), 'd': Batch(e=np.array([3.0]))}])
+    b2 = Batch(a=[{'b': np.array([4.0]), 'd': Batch(e=np.array([6.0]))}])
+    b_cat_out = Batch.cat((b1, b2))
+    b_cat_in = copy.deepcopy(b1)
+    b_cat_in.cat_(b2)
+    assert np.all(b_cat_in.a.d.e == b_cat_out.a.d.e)
+    assert b_cat_in.a.d.e.ndim == 2
+    b_stack = Batch.stack((b1, b2))
+    assert b_stack.a.d.e.ndim == 3
+
+
 def test_batch_over_batch_to_torch():
     batch = Batch(
         a=np.ones((1,), dtype=np.float64),
@@ -18,8 +18,8 @@ class MyPolicy(BasePolicy):

     def forward(self, batch, state=None):
         if self.dict_state:
-            return Batch(act=np.ones(batch.obs['index'].shape[0]))
-        return Batch(act=np.ones(batch.obs.shape[0]))
+            return Batch(act=np.ones(len(batch.obs['index'])))
+        return Batch(act=np.ones(len(batch.obs)))

     def learn(self):
         pass
@@ -1,4 +1,5 @@
 import torch
+import copy
 import pprint
 import warnings
 import numpy as np
@@ -73,29 +74,37 @@ class Batch:
         [11 22] [6 6]
     """

-    def __new__(cls, **kwargs) -> None:
-        self = super().__new__(cls)
-        self._meta = {}
-        return self
-
-    def __init__(self, **kwargs) -> None:
-        super().__init__()
-        for k, v in kwargs.items():
-            if isinstance(v, (list, np.ndarray)) \
-                    and len(v) > 0 and isinstance(v[0], dict) and k != 'info':
-                self._meta[k] = list(v[0].keys())
-                for k_ in v[0].keys():
-                    k__ = '_' + k + '@' + k_
-                    self.__dict__[k__] = np.array([
-                        v[i][k_] for i in range(len(v))
-                    ])
-            elif isinstance(v, dict):
-                self._meta[k] = list(v.keys())
-                for k_, v_ in v.items():
-                    k__ = '_' + k + '@' + k_
-                    self.__dict__[k__] = v_
-            else:
-                self.__dict__[k] = v
+    def __init__(self,
+                 batch_dict: Optional[
+                     Union[dict, List[dict], np.ndarray]] = None,
+                 **kwargs) -> None:
+        if isinstance(batch_dict, (list, np.ndarray)) \
+                and len(batch_dict) > 0 and isinstance(batch_dict[0], dict):
+            for k, v in zip(batch_dict[0].keys(),
+                            zip(*[e.values() for e in batch_dict])):
+                if isinstance(v, (list, np.ndarray)) \
+                        and len(v) > 0 and isinstance(v[0], dict):
+                    self.__dict__[k] = Batch.stack([Batch(v_) for v_ in v])
+                elif isinstance(v[0], np.ndarray):
+                    self.__dict__[k] = np.stack(v, axis=0)
+                elif isinstance(v[0], torch.Tensor):
+                    self.__dict__[k] = torch.stack(v, dim=0)
+                elif isinstance(v[0], Batch):
+                    self.__dict__[k] = Batch.stack(v)
+                elif isinstance(v[0], dict):
+                    self.__dict__[k] = Batch(v)
+                else:
+                    self.__dict__[k] = list(v)
+        elif isinstance(batch_dict, dict):
+            for k, v in batch_dict.items():
+                if isinstance(v, dict) \
+                        or (isinstance(v, (list, np.ndarray))
+                            and len(v) > 0 and isinstance(v[0], dict)):
+                    self.__dict__[k] = Batch(v)
+                else:
+                    self.__dict__[k] = v
+        if len(kwargs) > 0:
+            self.__init__(kwargs)

     def __getstate__(self):
         """Pickling interface. Only the actual data are serialized
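The hunk above replaces the old metadata-based constructor. As a hedged illustration of what the new __init__ accepts (a dict, a list of dicts, or an np.array of dicts, e.g. the info batch returned by a vectorized environment), based on the new unit test in this diff:

import numpy as np
import torch
from tianshou.data import Batch

# A list (or np.array) of dicts is merged key by key: ndarray values
# are stacked with np.stack, tensors with torch.stack, nested dicts
# become nested Batch objects, and everything else stays a list.
batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])}
batch = Batch({'a': [batch_dict, batch_dict]})
assert batch.a.b.shape == (2, 1)               # np.stack over the two dicts
assert isinstance(batch[0].a.d, torch.Tensor)  # tensors stay tensors
assert batch[0].a.c == 2.0                     # plain floats end up in a list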
@@ -122,33 +131,25 @@ class Batch:
             return self.__getattr__(index)
         b = Batch()
         for k, v in self.__dict__.items():
-            if k != '_meta' and hasattr(v, '__len__'):
+            if hasattr(v, '__len__'):
                 try:
                     b.__dict__.update(**{k: v[index]})
                 except IndexError:
                     continue
-        b._meta = self._meta
         return b

     def __getattr__(self, key: str) -> Union['Batch', Any]:
         """Return self.key"""
-        if key not in self._meta.keys():
-            if key not in self.__dict__:
-                raise AttributeError(key)
-            return self.__dict__[key]
-        d = {}
-        for k_ in self._meta[key]:
-            k__ = '_' + key + '@' + k_
-            d[k_] = self.__dict__[k__]
-        return Batch(**d)
+        if key not in self.__dict__:
+            raise AttributeError(key)
+        return self.__dict__[key]

     def __repr__(self) -> str:
         """Return str(self)."""
         s = self.__class__.__name__ + '(\n'
         flag = False
-        for k in sorted(list(self.__dict__) + list(self._meta)):
-            if k[0] != '_' and (self.__dict__.get(k, None) is not None or
-                                k in self._meta):
+        for k in sorted(self.__dict__.keys()):
+            if self.__dict__.get(k, None) is not None:
                 rpl = '\n' + ' ' * (6 + len(k))
                 obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
                 s += f'    {k}: {obj},\n'
@@ -161,16 +162,19 @@ class Batch:

     def keys(self) -> List[str]:
         """Return self.keys()."""
-        return sorted(list(self._meta.keys()) +
-                      [k for k in self.__dict__.keys() if k[0] != '_'])
+        return self.__dict__.keys()

     def values(self) -> List[Any]:
         """Return self.values()."""
-        return [self[k] for k in self.keys()]
+        return self.__dict__.values()

+    def items(self) -> Any:
+        """Return self.items()."""
+        return self.__dict__.items()
+
     def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]:
         """Return self[k] if k in self else d. d defaults to None."""
-        if k in self.__dict__ or k in self._meta:
+        if k in self.__dict__:
             return self.__getattr__(k)
         return d

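Since keys are no longer sorted or filtered through _meta, a Batch now behaves as a thin wrapper around its __dict__. A small sketch of the dict-style iteration enabled by the new items() method, mirroring test_batch_over_batch:

from tianshou.data import Batch

# keys()/values()/items() now delegate directly to __dict__, so a
# Batch iterates like a plain dict and preserves insertion order.
batch = Batch(a=[3, 4, 5], b=[4, 5, 6])
for k, v in batch.items():
    assert batch[k] == v
assert list(batch.keys()) == ['a', 'b']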
@@ -220,35 +224,55 @@ class Batch:
     def append(self, batch: 'Batch') -> None:
         warnings.warn('Method append will be removed soon, please use '
                       ':meth:`~tianshou.data.Batch.cat`')
-        return self.cat(batch)
+        return self.cat_(batch)

-    def cat(self, batch: 'Batch') -> None:
+    def cat_(self, batch: 'Batch') -> None:
         """Concatenate a :class:`~tianshou.data.Batch` object to current
         batch.
         """
         assert isinstance(batch, Batch), \
-            'Only Batch is allowed to be concatenated!'
+            'Only Batch is allowed to be concatenated in-place!'
         for k, v in batch.__dict__.items():
-            if k == '_meta':
-                self._meta.update(batch._meta)
-                continue
             if v is None:
                 continue
             if not hasattr(self, k) or self.__dict__[k] is None:
-                self.__dict__[k] = v
+                self.__dict__[k] = copy.deepcopy(v)
             elif isinstance(v, np.ndarray):
                 self.__dict__[k] = np.concatenate([self.__dict__[k], v])
             elif isinstance(v, torch.Tensor):
                 self.__dict__[k] = torch.cat([self.__dict__[k], v])
             elif isinstance(v, list):
-                self.__dict__[k] += v
+                self.__dict__[k] += copy.deepcopy(v)
             elif isinstance(v, Batch):
-                self.__dict__[k].cat(v)
+                self.__dict__[k].cat_(v)
             else:
-                s = f'No support for method "cat" with type \
-                    {type(v)} in class Batch.'
+                s = 'No support for method "cat" with type '\
+                    f'{type(v)} in class Batch.'
                 raise TypeError(s)

+    @staticmethod
+    def cat(batches: List['Batch']) -> None:
+        """Concatenate a :class:`~tianshou.data.Batch` object into a
+        single new batch.
+        """
+        assert isinstance(batches, (tuple, list)), \
+            'Only list of Batch instances is allowed to be '\
+            'concatenated out-of-place!'
+        batch = Batch()
+        for batch_ in batches:
+            batch.cat_(batch_)
+        return batch
+
+    @staticmethod
+    def stack(batches: List['Batch']):
+        """Stack a :class:`~tianshou.data.Batch` object into a
+        single new batch.
+        """
+        assert isinstance(batches, (tuple, list)), \
+            'Only list of Batch instances is allowed to be '\
+            'stacked out-of-place!'
+        return Batch(np.array([batch.__dict__ for batch in batches]))
+
     def __len__(self) -> int:
         """Return len(self)."""
         r = [len(v) for k, v in self.__dict__.items() if hasattr(v, '__len__')]
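The two new static methods are the out-of-place counterparts: Batch.cat concatenates along the existing batch dimension, while Batch.stack feeds an np.array of the batches' __dict__s back through __init__ and therefore adds a leading dimension. A sketch taken from the new test_batch_cat_and_stack above:

import copy
import numpy as np
from tianshou.data import Batch

b1 = Batch(a=[{'b': np.array([1.0]), 'd': Batch(e=np.array([3.0]))}])
b2 = Batch(a=[{'b': np.array([4.0]), 'd': Batch(e=np.array([6.0]))}])

# Out-of-place concatenation is equivalent to copying the first batch
# and calling the in-place cat_ on the copy.
b_cat = Batch.cat((b1, b2))
b_cat_in = copy.deepcopy(b1)
b_cat_in.cat_(b2)
assert np.all(b_cat.a.d.e == b_cat_in.a.d.e)

# stack adds a new leading axis instead of extending the existing one.
assert b_cat.a.d.e.ndim == 2
assert Batch.stack((b1, b2)).a.d.e.ndim == 3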
@@ -413,7 +413,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
             impt_weight=1 / np.power(
                 self._size * (batch.weight / self._weight_sum),
                 self._beta))
-        batch.cat(impt_weight)
+        batch.cat_(impt_weight)
         self._check_weight_sum()
         return batch, indice

@@ -416,7 +416,7 @@ class Collector(object):
                 if batch_size and cur_batch or batch_size <= 0:
                     batch, indice = b.sample(cur_batch)
                     batch = self.process_fn(batch, b, indice)
-                    batch_data.cat(batch)
+                    batch_data.cat_(batch)
         else:
             batch_data, indice = self.buffer.sample(batch_size)
             batch_data = self.process_fn(batch_data, self.buffer, indice)