Batch refactoring (#87)
* Enable stacking of Batch instances; add a static Batch.cat method and rename the in-place cat to cat_.
* Properly handle Batch init from an np.array of dict.
* Get rid of metadata.
* Update unit tests; replace cat with cat_ everywhere.
* Do not sort Batch keys anymore, for efficiency; add an items method.
* Fix cat copy issue.
* Add unit test to check the cat and stack methods.
* Remove unused import.
* Fix linter issues.
* Fix unit tests.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
parent 13828f6309 · commit ec270759ab
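A minimal usage sketch of the reworked API described above, assuming a tianshou build that includes this change (illustrative only, not part of the diff):

import numpy as np
from tianshou.data import Batch

b = Batch(obs=np.zeros((2, 4)))
b.cat_(Batch(obs=np.ones((2, 4))))       # in-place concatenation (previously named `cat`)
assert b.obs.shape == (4, 4)

merged = Batch.cat((Batch(a=np.array([1.0])), Batch(a=np.array([2.0]))))
assert merged.a.shape == (2,)            # static Batch.cat builds a new Batch out-of-place

stacked = Batch.stack((Batch(a=np.array([1.0])), Batch(a=np.array([2.0]))))
assert stacked.a.shape == (2, 1)         # Batch.stack stacks along a new leading axis

for k, v in stacked.items():             # the new items() method iterates key/value pairs
    print(k, v.shape)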
@@ -1,4 +1,5 @@
 import torch
+import copy
 import pickle
 import pytest
 import numpy as np
@@ -11,7 +12,7 @@ def test_batch():
     assert batch.obs == batch["obs"]
     batch.obs = [1]
     assert batch.obs == [1]
-    batch.cat(batch)
+    batch.cat_(batch)
     assert batch.obs == [1, 1]
     assert batch.np.shape == (6, 4)
     assert batch[0].obs == batch[1].obs
@@ -25,27 +26,48 @@ def test_batch():
     with pytest.raises(AttributeError):
         b.obs
     print(batch)
+    batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])}
+    batch_item = Batch({'a': [batch_dict]})[0]
+    assert isinstance(batch_item.a.b, np.ndarray)
+    assert batch_item.a.b == batch_dict['b']
+    assert isinstance(batch_item.a.c, float)
+    assert batch_item.a.c == batch_dict['c']
+    assert isinstance(batch_item.a.d, torch.Tensor)
+    assert batch_item.a.d == batch_dict['d']
 
 
 def test_batch_over_batch():
     batch = Batch(a=[3, 4, 5], b=[4, 5, 6])
-    batch2 = Batch(c=[6, 7, 8], b=batch)
+    batch2 = Batch({'c': [6, 7, 8], 'b': batch})
     batch2.b.b[-1] = 0
     print(batch2)
-    assert batch2.values()[-1] == batch2.c
+    for k, v in batch2.items():
+        assert batch2[k] == v
     assert batch2[-1].b.b == 0
-    batch2.cat(Batch(c=[6, 7, 8], b=batch))
+    batch2.cat_(Batch(c=[6, 7, 8], b=batch))
     assert batch2.c == [6, 7, 8, 6, 7, 8]
     assert batch2.b.a == [3, 4, 5, 3, 4, 5]
     assert batch2.b.b == [4, 5, 0, 4, 5, 0]
     d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
     batch3 = Batch(c=[6, 7, 8], b=d)
-    batch3.cat(Batch(c=[6, 7, 8], b=d))
+    batch3.cat_(Batch(c=[6, 7, 8], b=d))
     assert batch3.c == [6, 7, 8, 6, 7, 8]
     assert batch3.b.a == [3, 4, 5, 3, 4, 5]
     assert batch3.b.b == [4, 5, 6, 4, 5, 6]
 
 
+def test_batch_cat_and_stack():
+    b1 = Batch(a=[{'b': np.array([1.0]), 'd': Batch(e=np.array([3.0]))}])
+    b2 = Batch(a=[{'b': np.array([4.0]), 'd': Batch(e=np.array([6.0]))}])
+    b_cat_out = Batch.cat((b1, b2))
+    b_cat_in = copy.deepcopy(b1)
+    b_cat_in.cat_(b2)
+    assert np.all(b_cat_in.a.d.e == b_cat_out.a.d.e)
+    assert b_cat_in.a.d.e.ndim == 2
+    b_stack = Batch.stack((b1, b2))
+    assert b_stack.a.d.e.ndim == 3
+
+
 def test_batch_over_batch_to_torch():
     batch = Batch(
         a=np.ones((1,), dtype=np.float64),
@@ -18,8 +18,8 @@ class MyPolicy(BasePolicy):
 
     def forward(self, batch, state=None):
         if self.dict_state:
-            return Batch(act=np.ones(batch.obs['index'].shape[0]))
-        return Batch(act=np.ones(batch.obs.shape[0]))
+            return Batch(act=np.ones(len(batch.obs['index'])))
+        return Batch(act=np.ones(len(batch.obs)))
 
     def learn(self):
         pass
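The switch from `.shape[0]` to `len(...)` relies on `len()` of an ndarray being its leading dimension, so the two forms agree for array observations; a quick illustrative check (not part of the diff):

import numpy as np

obs = np.ones((3, 4))
assert len(obs) == obs.shape[0] == 3   # len() of an ndarray equals its first dimension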
@@ -1,4 +1,5 @@
 import torch
+import copy
 import pprint
 import warnings
 import numpy as np
@@ -73,29 +74,37 @@ class Batch:
         [11 22] [6 6]
     """
 
-    def __new__(cls, **kwargs) -> None:
-        self = super().__new__(cls)
-        self._meta = {}
-        return self
-
-    def __init__(self, **kwargs) -> None:
-        super().__init__()
-        for k, v in kwargs.items():
-            if isinstance(v, (list, np.ndarray)) \
-                    and len(v) > 0 and isinstance(v[0], dict) and k != 'info':
-                self._meta[k] = list(v[0].keys())
-                for k_ in v[0].keys():
-                    k__ = '_' + k + '@' + k_
-                    self.__dict__[k__] = np.array([
-                        v[i][k_] for i in range(len(v))
-                    ])
-            elif isinstance(v, dict):
-                self._meta[k] = list(v.keys())
-                for k_, v_ in v.items():
-                    k__ = '_' + k + '@' + k_
-                    self.__dict__[k__] = v_
-            else:
-                self.__dict__[k] = v
+    def __init__(self,
+                 batch_dict: Optional[
+                     Union[dict, List[dict], np.ndarray]] = None,
+                 **kwargs) -> None:
+        if isinstance(batch_dict, (list, np.ndarray)) \
+                and len(batch_dict) > 0 and isinstance(batch_dict[0], dict):
+            for k, v in zip(batch_dict[0].keys(),
+                            zip(*[e.values() for e in batch_dict])):
+                if isinstance(v, (list, np.ndarray)) \
+                        and len(v) > 0 and isinstance(v[0], dict):
+                    self.__dict__[k] = Batch.stack([Batch(v_) for v_ in v])
+                elif isinstance(v[0], np.ndarray):
+                    self.__dict__[k] = np.stack(v, axis=0)
+                elif isinstance(v[0], torch.Tensor):
+                    self.__dict__[k] = torch.stack(v, dim=0)
+                elif isinstance(v[0], Batch):
+                    self.__dict__[k] = Batch.stack(v)
+                elif isinstance(v[0], dict):
+                    self.__dict__[k] = Batch(v)
+                else:
+                    self.__dict__[k] = list(v)
+        elif isinstance(batch_dict, dict):
+            for k, v in batch_dict.items():
+                if isinstance(v, dict) \
+                        or (isinstance(v, (list, np.ndarray))
+                            and len(v) > 0 and isinstance(v[0], dict)):
+                    self.__dict__[k] = Batch(v)
+                else:
+                    self.__dict__[k] = v
+        if len(kwargs) > 0:
+            self.__init__(kwargs)
 
     def __getstate__(self):
         """Pickling interface. Only the actual data are serialized
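An illustrative sketch of what the reworked constructor accepts (variable names here are made up; the behaviour follows the branches above: a list or np.array of dicts is grouped per key, and ndarray entries are stacked along a new axis):

import numpy as np
from tianshou.data import Batch

transitions = [
    {'obs': np.zeros(4), 'rew': np.array([0.0])},
    {'obs': np.ones(4), 'rew': np.array([1.0])},
]
batch = Batch(transitions)            # list of dicts: values are grouped per key
assert batch.obs.shape == (2, 4)      # ndarray entries are np.stack-ed along axis 0
assert batch.rew.shape == (2, 1)

kw_batch = Batch(obs=np.zeros((2, 4)))    # keyword form still works: it re-dispatches
assert kw_batch.obs.shape == (2, 4)       # through __init__(kwargs) into the dict branch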
@@ -122,33 +131,25 @@ class Batch:
             return self.__getattr__(index)
         b = Batch()
         for k, v in self.__dict__.items():
-            if k != '_meta' and hasattr(v, '__len__'):
+            if hasattr(v, '__len__'):
                 try:
                     b.__dict__.update(**{k: v[index]})
                 except IndexError:
                     continue
-        b._meta = self._meta
         return b
 
     def __getattr__(self, key: str) -> Union['Batch', Any]:
         """Return self.key"""
-        if key not in self._meta.keys():
-            if key not in self.__dict__:
-                raise AttributeError(key)
-            return self.__dict__[key]
-        d = {}
-        for k_ in self._meta[key]:
-            k__ = '_' + key + '@' + k_
-            d[k_] = self.__dict__[k__]
-        return Batch(**d)
+        if key not in self.__dict__:
+            raise AttributeError(key)
+        return self.__dict__[key]
 
     def __repr__(self) -> str:
         """Return str(self)."""
         s = self.__class__.__name__ + '(\n'
         flag = False
-        for k in sorted(list(self.__dict__) + list(self._meta)):
-            if k[0] != '_' and (self.__dict__.get(k, None) is not None or
-                                k in self._meta):
+        for k in sorted(self.__dict__.keys()):
+            if self.__dict__.get(k, None) is not None:
                 rpl = '\n' + ' ' * (6 + len(k))
                 obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
                 s += f'    {k}: {obj},\n'
@@ -161,16 +162,19 @@ class Batch:
 
     def keys(self) -> List[str]:
         """Return self.keys()."""
-        return sorted(list(self._meta.keys()) +
-                      [k for k in self.__dict__.keys() if k[0] != '_'])
+        return self.__dict__.keys()
 
     def values(self) -> List[Any]:
         """Return self.values()."""
-        return [self[k] for k in self.keys()]
+        return self.__dict__.values()
+
+    def items(self) -> Any:
+        """Return self.items()."""
+        return self.__dict__.items()
 
     def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]:
         """Return self[k] if k in self else d. d defaults to None."""
-        if k in self.__dict__ or k in self._meta:
+        if k in self.__dict__:
             return self.__getattr__(k)
         return d
 
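Since keys(), values(), and items() are now thin views over the underlying __dict__ (no sorting, no metadata filtering), ordinary dict-style iteration works; a small sketch (illustrative, not part of the diff):

import numpy as np
from tianshou.data import Batch

b = Batch(a=np.array([1, 2]), c=np.array([3, 4]))
for k, v in b.items():
    assert np.all(b[k] == v)   # string indexing and attribute access hit the same __dict__ entry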
@@ -220,35 +224,55 @@ class Batch:
     def append(self, batch: 'Batch') -> None:
         warnings.warn('Method append will be removed soon, please use '
                       ':meth:`~tianshou.data.Batch.cat`')
-        return self.cat(batch)
+        return self.cat_(batch)
 
-    def cat(self, batch: 'Batch') -> None:
+    def cat_(self, batch: 'Batch') -> None:
         """Concatenate a :class:`~tianshou.data.Batch` object to current
         batch.
         """
         assert isinstance(batch, Batch), \
-            'Only Batch is allowed to be concatenated!'
+            'Only Batch is allowed to be concatenated in-place!'
         for k, v in batch.__dict__.items():
-            if k == '_meta':
-                self._meta.update(batch._meta)
-                continue
             if v is None:
                 continue
             if not hasattr(self, k) or self.__dict__[k] is None:
-                self.__dict__[k] = v
+                self.__dict__[k] = copy.deepcopy(v)
             elif isinstance(v, np.ndarray):
                 self.__dict__[k] = np.concatenate([self.__dict__[k], v])
             elif isinstance(v, torch.Tensor):
                 self.__dict__[k] = torch.cat([self.__dict__[k], v])
             elif isinstance(v, list):
-                self.__dict__[k] += v
+                self.__dict__[k] += copy.deepcopy(v)
             elif isinstance(v, Batch):
-                self.__dict__[k].cat(v)
+                self.__dict__[k].cat_(v)
             else:
-                s = f'No support for method "cat" with type \
-                    {type(v)} in class Batch.'
+                s = 'No support for method "cat" with type '\
+                    f'{type(v)} in class Batch.'
                 raise TypeError(s)
 
+    @staticmethod
+    def cat(batches: List['Batch']) -> None:
+        """Concatenate a :class:`~tianshou.data.Batch` object into a
+        single new batch.
+        """
+        assert isinstance(batches, (tuple, list)), \
+            'Only list of Batch instances is allowed to be '\
+            'concatenated out-of-place!'
+        batch = Batch()
+        for batch_ in batches:
+            batch.cat_(batch_)
+        return batch
+
+    @staticmethod
+    def stack(batches: List['Batch']):
+        """Stack a :class:`~tianshou.data.Batch` object into a
+        single new batch.
+        """
+        assert isinstance(batches, (tuple, list)), \
+            'Only list of Batch instances is allowed to be '\
+            'stacked out-of-place!'
+        return Batch(np.array([batch.__dict__ for batch in batches]))
+
     def __len__(self) -> int:
         """Return len(self)."""
         r = [len(v) for k, v in self.__dict__.items() if hasattr(v, '__len__')]
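The deepcopy added in cat_ is what resolves the "cat copy issue" from the commit message: the out-of-place Batch.cat no longer aliases the arrays of its inputs. A small sketch of the intended behaviour (illustrative, not part of the diff):

import numpy as np
from tianshou.data import Batch

src = Batch(a=np.array([1.0]))
out = Batch.cat((src,))   # starts from an empty Batch, so the first cat_ deep-copies src's data
out.a[0] = 42.0
assert src.a[0] == 1.0    # src is left untouched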
@@ -413,7 +413,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
             impt_weight=1 / np.power(
                 self._size * (batch.weight / self._weight_sum),
                 self._beta))
-        batch.cat(impt_weight)
+        batch.cat_(impt_weight)
         self._check_weight_sum()
         return batch, indice
 
@@ -416,7 +416,7 @@ class Collector(object):
                 if batch_size and cur_batch or batch_size <= 0:
                     batch, indice = b.sample(cur_batch)
                     batch = self.process_fn(batch, b, indice)
-                    batch_data.cat(batch)
+                    batch_data.cat_(batch)
         else:
             batch_data, indice = self.buffer.sample(batch_size)
             batch_data = self.process_fn(batch_data, self.buffer, indice)