Advanced Batch slicing & minor fix of RNN support (#106)

* add shape property and modify __getitem__

* change Batch.size to Batch.shape

* setattr

* Batch.empty

* remove scalar in advanced slicing

* modify empty_ and __getitem__

* missing testcase

* fix empty
n+e 2020-06-30 18:02:44 +08:00 committed by GitHub
parent c639446c66
commit db0e2e5cd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 183 additions and 101 deletions
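For orientation, here is a minimal sketch of the API surface this commit adds (the shape property, advanced slicing, and empty_/empty), pieced together from the examples in the diffs below; it assumes only numpy and tianshou.data.Batch.

    import numpy as np
    from tianshou.data import Batch

    b = Batch(a=[[1, 2]], b={'c': np.zeros([3, 2, 1])})
    print(b.shape)   # [1, 2], the element-wise minimum over the shapes of all values
    b[:, -1] += 1    # advanced slicing is applied to every stored array
    print(b.a)       # [[1 3]]
    b.empty_()       # reset every value to 0 (or None for object arrays) in place
    print(b.a)       # [[0 0]]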

View File

@@ -5,3 +5,4 @@ We always welcome contributions to help make Tianshou better. Below are an incom
* Jiayi Weng (`Trinkle23897 <https://github.com/Trinkle23897>`_)
* Minghao Zhang (`Mehooz <https://github.com/Mehooz>`_)
* Alexis Duburcq (`duburcqa <https://github.com/duburcqa>`_)

View File

@@ -13,9 +13,9 @@ def test_batch():
batch.obs = [1]
assert batch.obs == [1]
batch.cat_(batch)
assert batch.obs == [1, 1]
assert np.allclose(batch.obs, [1, 1])
assert batch.np.shape == (6, 4)
assert batch[0].obs == batch[1].obs
assert np.allclose(batch[0].obs, batch[1].obs)
batch.obs = np.arange(5)
for i, b in enumerate(batch.split(1, shuffle=False)):
if i != 5:
@@ -39,14 +39,14 @@ def test_batch():
'c': np.zeros(1),
'd': Batch(e=np.array(3.0))}])
assert len(batch2) == 1
assert Batch().size == 0
assert batch2.size == 1
assert Batch().shape == []
assert batch2.shape[0] == 1
with pytest.raises(IndexError):
batch2[-2]
with pytest.raises(IndexError):
batch2[1]
assert batch2[0].size == 1
with pytest.raises(TypeError):
assert batch2[0].shape == []
with pytest.raises(IndexError):
batch2[0][0]
with pytest.raises(TypeError):
len(batch2[0])
@@ -87,24 +87,36 @@ def test_batch_over_batch():
batch2.b.b[-1] = 0
print(batch2)
for k, v in batch2.items():
assert batch2[k] == v
assert np.all(batch2[k] == v)
assert batch2[-1].b.b == 0
batch2.cat_(Batch(c=[6, 7, 8], b=batch))
assert batch2.c == [6, 7, 8, 6, 7, 8]
assert batch2.b.a == [3, 4, 5, 3, 4, 5]
assert batch2.b.b == [4, 5, 0, 4, 5, 0]
assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5])
assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0])
d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
batch3 = Batch(c=[6, 7, 8], b=d)
batch3.cat_(Batch(c=[6, 7, 8], b=d))
assert batch3.c == [6, 7, 8, 6, 7, 8]
assert batch3.b.a == [3, 4, 5, 3, 4, 5]
assert batch3.b.b == [4, 5, 6, 4, 5, 6]
assert np.allclose(batch3.c, [6, 7, 8, 6, 7, 8])
assert np.allclose(batch3.b.a, [3, 4, 5, 3, 4, 5])
assert np.allclose(batch3.b.b, [4, 5, 6, 4, 5, 6])
batch4 = Batch(({'a': {'b': np.array([1.0])}},))
assert batch4.a.b.ndim == 2
assert batch4.a.b[0, 0] == 1.0
# advanced slicing
batch5 = Batch(a=[[1, 2]], b={'c': np.zeros([3, 2, 1])})
assert batch5.shape == [1, 2]
with pytest.raises(IndexError):
batch5[2]
with pytest.raises(IndexError):
batch5[:, 3]
with pytest.raises(IndexError):
batch5[:, :, -1]
batch5[:, -1] += 1
assert np.allclose(batch5.a, [1, 3])
assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3)
def test_batch_cat_and_stack():
def test_batch_cat_and_stack_and_empty():
b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
b12_cat_out = Batch.cat((b1, b2))
@@ -133,6 +145,24 @@ def test_batch_cat_and_stack():
assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
assert b5.b.d[0] == b5_dict[0]['b']['d']
assert b5.b.d[1] == 0.0
b5[1] = Batch.empty(b5[0])
assert np.allclose(b5.a, [False, False])
assert np.allclose(b5.b.c, [2, 0])
assert np.allclose(b5.b.d, [1, 0])
data = Batch(a=[False, True],
b={'c': [2., 'st'], 'd': [1, None], 'e': [2., float('nan')]},
c=np.array([1, 3, 4], dtype=np.int),
t=torch.tensor([4, 5, 6, 7.]))
data[-1] = Batch.empty(data[1])
assert np.allclose(data.c, [1, 3, 0])
assert np.allclose(data.a, [False, False])
assert list(data.b.c) == ['2.0', '']
assert list(data.b.d) == [1, None]
assert np.allclose(data.b.e, [2, 0])
assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
b0 = Batch()
b0.empty_()
assert b0.shape == []
def test_batch_over_batch_to_torch():
@@ -215,3 +245,5 @@ if __name__ == '__main__':
test_utils_to_torch()
test_batch_pickle()
test_batch_from_to_numpy_without_copy()
test_batch_numpy_compatibility()
test_batch_cat_and_stack_and_empty()
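The new assertions above exercise the dtype-dependent reset rules of Batch.empty/empty_; a condensed sketch of those rules, mirroring the added test cases (numeric data becomes 0, string entries become '', None entries stay None, torch tensors are zeroed):

    import numpy as np
    import torch
    from tianshou.data import Batch

    data = Batch(a=[False, True],
                 b={'c': [2., 'st'], 'd': [1, None]},
                 t=torch.tensor([4., 5.]))
    data[-1] = Batch.empty(data[1])          # reset only the last sample
    assert np.allclose(data.a, [False, False])
    assert list(data.b.c) == ['2.0', '']     # string entries are emptied
    assert list(data.b.d) == [1, None]       # None entries are kept as None
    assert torch.allclose(data.t, torch.tensor([4., 0.]))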

View File

@@ -74,8 +74,9 @@ class Batch:
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=4, b=[5, 5], c='2312312')
>>> # the list will automatically be converted to a numpy array
>>> data.b
[5, 5]
array([5, 5])
>>> data.b = np.array([3, 4, 5])
>>> print(data)
Batch(
@@ -104,8 +105,6 @@ class Batch:
together:
::
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch([{'a': {'b': [0.0, "info"]}}])
>>> print(data[0])
Batch(
@@ -119,7 +118,6 @@ class Batch:
key, or iterate over stored data:
::
>>> from tianshou.data import Batch
>>> data = Batch(a=4, b=[5, 5])
>>> print(data["a"])
4
@@ -130,28 +128,36 @@ class Batch:
:class:`~tianshou.data.Batch` also partially reproduces the Numpy API for
arrays. You can access or iterate over the individual samples, if any:
arrays. It also supports the advanced slicing method, such as batch[:, i],
if the index is valid. You can access or iterate over the individual
samples, if any:
::
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[5, -5])
>>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[[5, -5]])
>>> print(data[0])
Batch(
a: np.array([0.0, 2.0])
b: 5
a: array([0., 2.])
b: array([ 5, -5]),
)
>>> for sample in data:
>>> print(sample.a)
[0.0, 2.0]
[1.0, 3.0]
[0., 2.]
[1., 3.]
>>> print(data.shape)
[1, 2]
>>> data[:, 1] += 1
>>> print(data)
Batch(
a: array([[0., 3.],
[1., 4.]]),
b: array([[ 5, -4]]),
)
Similarly, one can also perform simple algebra on it, and stack, split or
concatenate multiple instances:
::
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data_1 = Batch(a=np.array([0.0, 2.0]), b=5)
>>> data_2 = Batch(a=np.array([1.0, 3.0]), b=-5)
>>> data = Batch.stack((data_1, data_2))
@@ -169,11 +175,10 @@ class Batch:
>>> data_split = list(data.split(1, False))
>>> print(list(data.split(1, False)))
[Batch(
b: [5],
b: array([5]),
a: array([[0., 2.]]),
),
Batch(
b: [-5],
), Batch(
b: array([-5]),
a: array([[1., 3.]]),
)]
>>> data_cat = Batch.cat(data_split)
@@ -188,8 +193,6 @@ class Batch:
None is added in list or :class:`np.ndarray` of objects, 0 otherwise.
::
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data_1 = Batch(a=np.array([0.0, 2.0]))
>>> data_2 = Batch(a=np.array([1.0, 3.0]), b='done')
>>> data = Batch.stack((data_1, data_2))
@@ -200,23 +203,40 @@ class Batch:
b: array([None, 'done'], dtype=object),
)
:meth:`~tianshou.data.Batch.size` and :meth:`~tianshou.data.Batch.__len__`
methods are also provided to respectively get the length and the size of
a :class:`Batch` instance. It mimics the Numpy API for Numpy arrays, which
means that getting the length of a scalar Batch raises an exception, while
the size is 1. The size is only 0 if empty. Note that the size and length
are the identical if multiple samples are stored:
The :meth:`~tianshou.data.Batch.empty_` and :meth:`~tianshou.data.Batch.empty` methods reset values to 0, or to ``None`` for ``np.object`` arrays:
::
>>> data.empty_()
>>> print(data)
Batch(
a: array([[0., 0.],
[0., 0.]]),
b: array([None, None], dtype=object),
)
>>> data = Batch(a=[False, True], b={'c': [2., 'st'], 'd': [1., 0.]})
>>> data[0] = Batch.empty(data[1])
>>> data
Batch(
a: array([False, True]),
b: Batch(
c: array([0., 3.]),
d: array([0., 0.]),
),
)
:meth:`~tianshou.data.Batch.shape` and :meth:`~tianshou.data.Batch.__len__`
methods are also provided to respectively get the shape and the length of
a :class:`Batch` instance. It mimics the Numpy API for Numpy arrays, which
means that getting the length of a scalar Batch raises an exception.
::
>>> import numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=[5., 4.], b=np.zeros((2, 3, 4)))
>>> data.size
2
>>> data.shape
[2]
>>> len(data)
2
>>> data[0].size
1
>>> data[0].shape
[]
>>> len(data[0])
TypeError: Object of type 'Batch' has no len()
@@ -240,13 +260,26 @@ class Batch:
if isinstance(v, dict) or _is_batch_set(v):
self.__dict__[k] = Batch(v)
else:
if isinstance(v, list):
v = np.array(v)
self.__dict__[k] = v
if len(kwargs) > 0:
self.__init__(kwargs)
def __setattr__(self, key: str, value: Any):
"""self[key] = value"""
if isinstance(value, list):
if _is_batch_set(value):
value = Batch(value)
else:
value = np.array(value)
elif isinstance(value, dict):
value = Batch(value)
self.__dict__[key] = value
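For reference, the new __setattr__ above normalizes plain Python containers at assignment time; a short illustration of the conversions it performs (a sketch that assumes only numpy and tianshou.data.Batch):

    import numpy as np
    from tianshou.data import Batch

    b = Batch()
    b.x = [1, 2, 3]           # a plain list of scalars becomes a numpy array
    assert isinstance(b.x, np.ndarray)
    b.y = {'z': [0.0, 1.0]}   # a dict becomes a nested Batch
    assert isinstance(b.y, Batch) and isinstance(b.y.z, np.ndarray)
    b.w = [{'k': 1.0}]        # a list of dicts (a "batch set") also becomes a Batch
    assert isinstance(b.w, Batch)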
def __getstate__(self):
"""Pickling interface. Only the actual data are serialized
for both efficiency and simplicity.
"""Pickling interface. Only the actual data are serialized for both
efficiency and simplicity.
"""
state = {}
for k, v in self.items():
@@ -256,9 +289,9 @@ class Batch:
return state
def __setstate__(self, state):
"""Unpickling interface. At this point, self is an empty Batch
instance that has not been initialized, so it can safely be
initialized by the pickle state.
"""Unpickling interface. At this point, self is an empty Batch instance
that has not been initialized, so it can safely be initialized by the
pickle state.
"""
self.__init__(**state)
@@ -267,26 +300,18 @@ class Batch:
"""Return self[index]."""
if isinstance(index, str):
return self.__dict__[index]
b = Batch()
for k, v in self.items():
if isinstance(v, Batch) and len(v.__dict__) == 0:
b.__dict__[k] = Batch()
else:
b.__dict__[k] = v[index]
return b
if not _valid_bounds(len(self), index):
raise IndexError(
f"Index {index} out of bounds for Batch of len {len(self)}.")
else:
b = Batch()
is_index_scalar = isinstance(index, (int, np.integer)) or \
(isinstance(index, np.ndarray) and index.ndim == 0)
for k, v in self.items():
if isinstance(v, Batch) and len(v.__dict__) == 0:
b.__dict__[k] = Batch()
elif is_index_scalar or not isinstance(v, list):
b.__dict__[k] = v[index]
else:
b.__dict__[k] = [v[i] for i in index]
return b
def __setitem__(self, index: Union[
str, slice, int, np.integer, np.ndarray, List[int]],
value: Any) -> None:
def __setitem__(
self,
index: Union[str, slice, int, np.integer, np.ndarray, List[int]],
value: Any) -> None:
"""Assign value to self[index]."""
if isinstance(index, str):
self.__dict__[index] = value
@@ -319,8 +344,6 @@ class Batch:
other.__dict__.values()):
if r is None:
continue
elif isinstance(r, list):
self.__dict__[k] = [r_ + v_ for r_, v_ in zip(r, v)]
else:
self.__dict__[k] += v
return self
@@ -328,8 +351,6 @@ class Batch:
for k, r in self.items():
if r is None:
continue
elif isinstance(r, list):
self.__dict__[k] = [r_ + other for r_ in r]
else:
self.__dict__[k] += other
return self
@@ -440,13 +461,14 @@ class Batch:
v.to_torch(dtype, device)
def append(self, batch: 'Batch') -> None:
warnings.warn('Method append will be removed soon, please use '
warnings.warn('Method :meth:`~tianshou.data.Batch.append` will be '
'removed soon, please use '
':meth:`~tianshou.data.Batch.cat`')
return self.cat_(batch)
def cat_(self, batch: 'Batch') -> None:
"""Concatenate a :class:`~tianshou.data.Batch` object into
current batch.
"""Concatenate a :class:`~tianshou.data.Batch` object into current
batch.
"""
assert isinstance(batch, Batch), \
'Only Batch is allowed to be concatenated in-place!'
@@ -459,8 +481,6 @@ class Batch:
self.__dict__[k] = np.concatenate([self.__dict__[k], v])
elif isinstance(v, torch.Tensor):
self.__dict__[k] = torch.cat([self.__dict__[k], v])
elif isinstance(v, list):
self.__dict__[k] += copy.deepcopy(v)
elif isinstance(v, Batch):
self.__dict__[k].cat_(v)
else:
@ -468,12 +488,12 @@ class Batch:
f'{type(v)} in class Batch.'
raise TypeError(s)
@classmethod
def cat(cls, batches: List['Batch']) -> 'Batch':
"""Concatenate a :class:`~tianshou.data.Batch` object into a
single new batch.
@staticmethod
def cat(batches: List['Batch']) -> 'Batch':
"""Concatenate a :class:`~tianshou.data.Batch` object into a single
new batch.
"""
batch = cls()
batch = Batch()
for batch_ in batches:
batch.cat_(batch_)
return batch
@@ -481,8 +501,7 @@ class Batch:
def stack_(self,
batches: List[Union[dict, 'Batch']],
axis: int = 0) -> None:
"""Stack a :class:`~tianshou.data.Batch` object i into current
batch.
"""Stack a :class:`~tianshou.data.Batch` object i into current batch.
"""
if len(self.__dict__) > 0:
batches = [self] + list(batches)
@@ -511,13 +530,42 @@ class Batch:
@staticmethod
def stack(batches: List['Batch'], axis: int = 0) -> 'Batch':
"""Stack a :class:`~tianshou.data.Batch` object into a
single new batch.
"""Stack a :class:`~tianshou.data.Batch` object into a single new
batch.
"""
batch = Batch()
batch.stack_(batches, axis)
return batch
def empty_(self) -> 'Batch':
"""Return an empty a :class:`~tianshou.data.Batch` object with 0 or
``None`` filled.
"""
for k, v in self.items():
if v is None:
continue
if isinstance(v, Batch):
self.__dict__[k].empty_()
elif isinstance(v, np.ndarray) and v.dtype == np.object:
self.__dict__[k].fill(None)
elif isinstance(v, torch.Tensor): # cannot apply fill_ directly
self.__dict__[k] = torch.zeros_like(self.__dict__[k])
else: # np
self.__dict__[k] *= 0
if hasattr(v, 'dtype') and v.dtype.kind in 'fc':
self.__dict__[k] = np.nan_to_num(self.__dict__[k])
return self
@staticmethod
def empty(batch: 'Batch') -> 'Batch':
"""Return an empty :class:`~tianshou.data.Batch` object with 0 or
``None`` filled; the shape is the same as that of the given
:class:`~tianshou.data.Batch`.
"""
batch = Batch(**batch)
batch.empty_()
return batch
def __len__(self) -> int:
"""Return len(self)."""
r = []
@@ -534,21 +582,20 @@ class Batch:
return min(r)
@property
def size(self) -> int:
"""Return self.size."""
def shape(self) -> List[int]:
"""Return self.shape."""
if len(self.__dict__.keys()) == 0:
return 0
return []
else:
r = []
data_shape = []
for v in self.__dict__.values():
if isinstance(v, Batch):
r.append(v.size)
elif hasattr(v, '__len__') and (not isinstance(
v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
r.append(len(v))
else:
r.append(1)
return min(r) if len(r) > 0 else 0
try:
data_shape.append(v.shape)
except AttributeError:
raise TypeError("No support for 'shape' method with "
f"type {type(v)} in class Batch.")
return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \
else data_shape[0]
def split(self, size: Optional[int] = None,
shuffle: bool = True) -> Iterator['Batch']:
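To make the new shape property above concrete: it collects the shape of every stored value (recursing into nested Batch objects) and returns the axis-wise minimum, so heterogeneous values still yield a single common shape. A small sketch, assuming only numpy and tianshou.data.Batch:

    import numpy as np
    from tianshou.data import Batch

    data = Batch(a=np.zeros((2, 5)), b=np.zeros((3, 4, 7)))
    # per-value shapes are (2, 5) and (3, 4, 7); the common shape is the
    # axis-wise minimum over the shorter rank, hence [2, 4]
    print(data.shape)     # [2, 4]
    print(len(data))      # 2, the smallest leading dimension
    print(Batch().shape)  # [] for an empty Batch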

View File

@@ -200,8 +200,10 @@ class Collector(object):
return
if isinstance(self.state, list):
self.state[id] = None
elif isinstance(self.state, (Batch, torch.Tensor, np.ndarray)):
elif isinstance(self.state, (torch.Tensor, np.ndarray)):
self.state[id] *= 0
else: # Batch
self.state[id].empty_()
def collect(self,
n_step: int = 0,
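The change above makes the collector clear cached RNN state per environment according to its container type; a condensed sketch of that dispatch (the function name and standalone signature are illustrative, the real logic lives inside Collector):

    import numpy as np
    import torch
    from tianshou.data import Batch

    def reset_state(state, id):
        """Clear the cached recurrent state of environment `id` in place."""
        if isinstance(state, list):
            state[id] = None                   # list of per-env states
        elif isinstance(state, (torch.Tensor, np.ndarray)):
            state[id] *= 0                     # zero this environment's slice
        elif isinstance(state, Batch):
            state[id].empty_()                 # fill the sub-batch with 0 / None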