change Batch.empty to in-place fill; add copy option for Batch construction (#110)

* in-place empty_ for Batch

* change Batch.empty to in-place fill; add copy option for Batch construction

* type signature & remove shadow names for copy

* add doc for data type (only support numbers and object data type)

* add unit test for Batch copy

* fix pep8

* add test case for Batch.empty

* doc fix

* fix pep8

* use object to test Batch

* test commit

* refactor

* change Batch(copy) testcase

* minor fix

Co-authored-by: Trinkle23897 <463003665@qq.com>
This commit is contained in:
youkaichao 2020-07-06 20:30:15 +08:00 committed by GitHub
parent 5b1373924e
commit 8913bf36b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 120 additions and 47 deletions

View File

@ -116,7 +116,7 @@ def test_batch_over_batch():
assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3) assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3)
def test_batch_cat_and_stack_and_empty(): def test_batch_cat_and_stack():
b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}]) b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}]) b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
b12_cat_out = Batch.cat((b1, b2)) b12_cat_out = Batch.cat((b1, b2))
@ -145,24 +145,6 @@ def test_batch_cat_and_stack_and_empty():
assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0)) assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
assert b5.b.d[0] == b5_dict[0]['b']['d'] assert b5.b.d[0] == b5_dict[0]['b']['d']
assert b5.b.d[1] == 0.0 assert b5.b.d[1] == 0.0
b5[1] = Batch.empty(b5[0])
assert np.allclose(b5.a, [False, False])
assert np.allclose(b5.b.c, [2, 0])
assert np.allclose(b5.b.d, [1, 0])
data = Batch(a=[False, True],
b={'c': [2., 'st'], 'd': [1, None], 'e': [2., float('nan')]},
c=np.array([1, 3, 4], dtype=np.int),
t=torch.tensor([4, 5, 6, 7.]))
data[-1] = Batch.empty(data[1])
assert np.allclose(data.c, [1, 3, 0])
assert np.allclose(data.a, [False, False])
assert list(data.b.c) == ['2.0', '']
assert list(data.b.d) == [1, None]
assert np.allclose(data.b.e, [2, 0])
assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
b0 = Batch()
b0.empty_()
assert b0.shape == []
def test_batch_over_batch_to_torch(): def test_batch_over_batch_to_torch():
@ -225,6 +207,71 @@ def test_batch_from_to_numpy_without_copy():
assert c_mem_addr_new == c_mem_addr_orig assert c_mem_addr_new == c_mem_addr_orig
def test_batch_copy():
    """Verify that ``Batch(copy=...)`` shares or duplicates its content.

    With ``copy=False`` the new batch must hold the very same objects and
    array buffers as the source; with ``copy=True`` none may be shared.
    """

    def buffer_addr(array):
        # memory address of the array's underlying data buffer
        return array.__array_interface__['data'][0]

    def leaf_addrs(b):
        return buffer_addr(b.c), buffer_addr(b.b.a), buffer_addr(b.b.b)

    inner = Batch(a=np.array([3, 4, 5]), b=np.array([4, 5, 6]))
    source = Batch({'c': np.array([6, 7, 8]), 'b': inner})
    source_addrs = leaf_addrs(source)

    # copy=False: same Python objects, same buffers
    shared = Batch(copy=False, **source)
    shared_addrs = leaf_addrs(shared)
    assert source.c is shared.c
    assert source.b is shared.b
    assert source.b.a is shared.b.a
    assert source.b.b is shared.b.b
    for before, after in zip(source_addrs, shared_addrs):
        assert before == after

    # copy=True: distinct Python objects, distinct buffers
    copied = Batch(copy=True, **source)
    copied_addrs = leaf_addrs(copied)
    assert source.c is not copied.c
    assert source.b is not copied.b
    assert source.b.a is not copied.b.a
    assert source.b.b is not copied.b.b
    for before, after in zip(source_addrs, copied_addrs):
        assert before != after
def test_batch_empty():
    """Check ``Batch.empty`` and the in-place ``Batch.empty_``."""
    b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
                        {'a': True, 'b': {'c': 3.0}}])
    b5 = Batch(b5_dict)
    # overwrite the second entry with an emptied copy of the first one
    b5[1] = Batch.empty(b5[0])
    assert np.allclose(b5.a, [False, False])
    assert np.allclose(b5.b.c, [2, 0])
    assert np.allclose(b5.b.d, [1, 0])
    # The builtins ``object`` / ``int`` are used instead of the ``np.object``
    # / ``np.int`` aliases: the aliases were deprecated in NumPy 1.20 and
    # removed in NumPy 1.24, while the builtins are accepted by every release.
    data = Batch(a=[False, True],
                 b={'c': np.array([2., 'st'], dtype=object),
                    'd': [1, None],
                    'e': [2., float('nan')]},
                 c=np.array([1, 3, 4], dtype=int),
                 t=torch.tensor([4, 5, 6, 7.]))
    data[-1] = Batch.empty(data[1])
    assert np.allclose(data.c, [1, 3, 0])
    assert np.allclose(data.a, [False, False])
    assert list(data.b.c) == [2.0, None]
    assert list(data.b.d) == [1, None]
    assert np.allclose(data.b.e, [2, 0])
    assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
    # Emptying the indexed sub-batch is only expected to reach ``t``; per the
    # original note it does not take effect on a, b.c, b.d, b.e and c.
    data[0].empty_()
    assert torch.allclose(data.t, torch.tensor([0., 5, 6, 0]))
    # passing the index to ``empty_`` resets the first entry of every key
    data.empty_(index=0)
    assert np.allclose(data.c, [0, 3, 0])
    assert list(data.b.c) == [None, None]
    assert list(data.b.d) == [None, None]
    assert list(data.b.e) == [0, 0]
    # an empty Batch can be emptied as well and keeps an empty shape
    b0 = Batch()
    b0.empty_()
    assert b0.shape == []
def test_batch_numpy_compatibility(): def test_batch_numpy_compatibility():
batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]),
b=Batch(), b=Batch(),
@ -246,4 +293,6 @@ if __name__ == '__main__':
test_batch_pickle() test_batch_pickle()
test_batch_from_to_numpy_without_copy() test_batch_from_to_numpy_without_copy()
test_batch_numpy_compatibility() test_batch_numpy_compatibility()
test_batch_cat_and_stack_and_empty() test_batch_cat_and_stack()
test_batch_copy()
test_batch_empty()

View File

@ -1,8 +1,8 @@
import torch import torch
import copy
import pprint import pprint
import warnings import warnings
import numpy as np import numpy as np
from copy import deepcopy
from functools import reduce from functools import reduce
from numbers import Number from numbers import Number
from typing import Any, List, Tuple, Union, Iterator, Optional from typing import Any, List, Tuple, Union, Iterator, Optional
@ -85,8 +85,13 @@ class Batch:
c: '2312312', c: '2312312',
) )
In short, you can define a :class:`Batch` with any key-value pair. The In short, you can define a :class:`Batch` with any key-value pair.
current implementation of Tianshou typically use 7 reserved keys in
For NumPy arrays, only numeric data types and ``np.object`` are
supported. Strings and other data types can still be stored by
placing them in ``np.object`` arrays.
The current implementation of Tianshou typically uses 7 reserved keys in
:class:`~tianshou.data.Batch`: :class:`~tianshou.data.Batch`:
* ``obs`` the observation of step :math:`t` ; * ``obs`` the observation of step :math:`t` ;
@ -252,7 +257,10 @@ class Batch:
batch_dict: Optional[Union[ batch_dict: Optional[Union[
dict, 'Batch', Tuple[Union[dict, 'Batch']], dict, 'Batch', Tuple[Union[dict, 'Batch']],
List[Union[dict, 'Batch']], np.ndarray]] = None, List[Union[dict, 'Batch']], np.ndarray]] = None,
copy: bool = False,
**kwargs) -> None: **kwargs) -> None:
if copy:
batch_dict = deepcopy(batch_dict)
if _is_batch_set(batch_dict): if _is_batch_set(batch_dict):
self.stack_(batch_dict) self.stack_(batch_dict)
elif isinstance(batch_dict, (dict, Batch)): elif isinstance(batch_dict, (dict, Batch)):
@ -264,7 +272,7 @@ class Batch:
v = np.array(v) v = np.array(v)
self.__dict__[k] = v self.__dict__[k] = v
if len(kwargs) > 0: if len(kwargs) > 0:
self.__init__(kwargs) self.__init__(kwargs, copy=copy)
def __setattr__(self, key: str, value: Any): def __setattr__(self, key: str, value: Any):
"""self[key] = value""" """self[key] = value"""
@ -360,7 +368,7 @@ class Batch:
def __add__(self, other: Union['Batch', Number]):
    """Return the sum of this :class:`~tianshou.data.Batch` and ``other``
    as a new object; ``self`` is left untouched."""
    # work on a deep copy so that the in-place addition cannot alias self
    duplicate = deepcopy(self)
    return duplicate.__iadd__(other)
def __imul__(self, val: Number): def __imul__(self, val: Number):
"""Algebraic multiplication with a scalar value in-place.""" """Algebraic multiplication with a scalar value in-place."""
@ -372,7 +380,7 @@ class Batch:
def __mul__(self, val: Number):
    """Return a new object scaled by the scalar ``val`` (out-of-place)."""
    # work on a deep copy so that the in-place product cannot alias self
    duplicate = deepcopy(self)
    return duplicate.__imul__(val)
def __itruediv__(self, val: Number): def __itruediv__(self, val: Number):
"""Algebraic division wibyth a scalar value in-place.""" """Algebraic division wibyth a scalar value in-place."""
@ -384,7 +392,7 @@ class Batch:
def __truediv__(self, val: Number):
    """Return a new object divided by the scalar ``val`` (out-of-place)."""
    # work on a deep copy so that the in-place division cannot alias self
    duplicate = deepcopy(self)
    return duplicate.__itruediv__(val)
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return str(self).""" """Return str(self)."""
@ -476,7 +484,7 @@ class Batch:
if v is None: if v is None:
continue continue
if not hasattr(self, k) or self.__dict__[k] is None: if not hasattr(self, k) or self.__dict__[k] is None:
self.__dict__[k] = copy.deepcopy(v) self.__dict__[k] = deepcopy(v)
elif isinstance(v, np.ndarray) and v.ndim > 0: elif isinstance(v, np.ndarray) and v.ndim > 0:
self.__dict__[k] = np.concatenate([self.__dict__[k], v]) self.__dict__[k] = np.concatenate([self.__dict__[k], v])
elif isinstance(v, torch.Tensor): elif isinstance(v, torch.Tensor):
@ -537,34 +545,45 @@ class Batch:
batch.stack_(batches, axis) batch.stack_(batches, axis)
return batch return batch
def empty_(self, index: Union[
        str, slice, int, np.integer, np.ndarray, List[int]] = None
           ) -> 'Batch':
    """Reset the stored data in-place with 0 or ``None`` filled.

    If ``index`` is specified, only the indexed part of every array and
    tensor is reset. Nested :class:`~tianshou.data.Batch` values are
    processed recursively with the same ``index``; keys whose value is
    ``None`` are skipped. Values that are neither a Batch, a tensor nor an
    array are replaced as a whole (``index`` is ignored for them) and a
    warning is emitted.

    :return: ``self``.
    """
    for k, v in self.items():
        if v is None:
            continue
        if isinstance(v, Batch):
            self.__dict__[k].empty_(index=index)
        elif isinstance(v, torch.Tensor):
            self.__dict__[k][index] = 0
        elif isinstance(v, np.ndarray):
            # compare against the builtin ``object``: the ``np.object``
            # alias was deprecated in NumPy 1.20 and removed in NumPy 1.24
            if v.dtype == object:
                self.__dict__[k][index] = None
            else:
                self.__dict__[k][index] = 0
        else:  # scalar value
            warnings.warn('You are calling Batch.empty on a NumPy scalar, '
                          'which may cause undefined behaviors.')
            if isinstance(v, (np.generic, Number)):
                # multiplying by 0 keeps the scalar's type where possible
                self.__dict__[k] *= 0
                # nan * 0 and inf * 0 are nan, so force a plain 0 then
                if np.isnan(self.__dict__[k]):
                    self.__dict__[k] = 0
            else:
                self.__dict__[k] = None
    return self
@staticmethod
def empty(batch: 'Batch', index: Union[
        str, slice, int, np.integer, np.ndarray, List[int]] = None
          ) -> 'Batch':
    """Return a deep copy of the given :class:`~tianshou.data.Batch` that
    has been passed through ``empty_`` with ``index``, i.e. the same shape
    with 0 or ``None`` filled; the given ``batch`` is not modified.
    """
    emptied = deepcopy(batch)
    return emptied.empty_(index)
def __len__(self) -> int: def __len__(self) -> int:
"""Return len(self).""" """Return len(self)."""

View File

@ -200,10 +200,15 @@ class Collector(object):
return return
if isinstance(self.state, list): if isinstance(self.state, list):
self.state[id] = None self.state[id] = None
elif isinstance(self.state, (torch.Tensor, np.ndarray)): elif isinstance(self.state, torch.Tensor):
self.state[id] *= 0 self.state[id].zero_()
else: # Batch elif isinstance(self.state, np.ndarray):
self.state[id].empty_() if isinstance(self.state.dtype == np.object):
self.state[id] = None
else:
self.state[id] = 0
elif isinstance(self.state, Batch):
self.state.empty_(id)
def collect(self, def collect(self,
n_step: int = 0, n_step: int = 0,