change Batch.empty to in-place fill; add copy option for Batch construction (#110)

* in-place empty_ for Batch

* change Batch.empty to in-place fill; add copy option for Batch construction

* type signature & remove shadow names for copy

* add doc for data type (only support numbers and object data type)

* add unit test for Batch copy

* fix pep8

* add test case for Batch.empty

* doc fix

* fix pep8

* use object to test Batch

* test commit

* refactor

* change Batch(copy) testcase

* minor fix

Co-authored-by: Trinkle23897 <463003665@qq.com>
This commit is contained in:
youkaichao 2020-07-06 20:30:15 +08:00 committed by GitHub
parent 5b1373924e
commit 8913bf36b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 120 additions and 47 deletions

View File

@ -116,7 +116,7 @@ def test_batch_over_batch():
assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3)
def test_batch_cat_and_stack_and_empty():
def test_batch_cat_and_stack():
b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
b12_cat_out = Batch.cat((b1, b2))
@ -145,24 +145,6 @@ def test_batch_cat_and_stack_and_empty():
assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
assert b5.b.d[0] == b5_dict[0]['b']['d']
assert b5.b.d[1] == 0.0
b5[1] = Batch.empty(b5[0])
assert np.allclose(b5.a, [False, False])
assert np.allclose(b5.b.c, [2, 0])
assert np.allclose(b5.b.d, [1, 0])
data = Batch(a=[False, True],
b={'c': [2., 'st'], 'd': [1, None], 'e': [2., float('nan')]},
c=np.array([1, 3, 4], dtype=np.int),
t=torch.tensor([4, 5, 6, 7.]))
data[-1] = Batch.empty(data[1])
assert np.allclose(data.c, [1, 3, 0])
assert np.allclose(data.a, [False, False])
assert list(data.b.c) == ['2.0', '']
assert list(data.b.d) == [1, None]
assert np.allclose(data.b.e, [2, 0])
assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
b0 = Batch()
b0.empty_()
assert b0.shape == []
def test_batch_over_batch_to_torch():
@ -225,6 +207,71 @@ def test_batch_from_to_numpy_without_copy():
assert c_mem_addr_new == c_mem_addr_orig
def test_batch_copy():
    """Check that ``Batch(copy=...)`` shares or duplicates the stored arrays.

    With ``copy=False`` the new Batch must hold the very same objects (and
    therefore the same data buffers) as the source; with ``copy=True`` every
    nested Batch and array must be a distinct object backed by a distinct
    buffer.
    """
    def _buffer_addr(array):
        # Address of the first byte of the array's data buffer.
        return array.__array_interface__['data'][0]

    inner = Batch(a=np.array([3, 4, 5]), b=np.array([4, 5, 6]))
    source = Batch({'c': np.array([6, 7, 8]), 'b': inner})
    orig_addrs = (
        _buffer_addr(source.c),
        _buffer_addr(source.b.a),
        _buffer_addr(source.b.b),
    )

    # copy=False: objects and underlying buffers are shared with the source
    shared = Batch(copy=False, **source)
    shared_addrs = (
        _buffer_addr(shared.c),
        _buffer_addr(shared.b.a),
        _buffer_addr(shared.b.b),
    )
    assert source.c is shared.c
    assert source.b is shared.b
    assert source.b.a is shared.b.a
    assert source.b.b is shared.b.b
    for before, after in zip(orig_addrs, shared_addrs):
        assert before == after

    # copy=True: every level is a fresh object with its own buffer
    cloned = Batch(copy=True, **source)
    cloned_addrs = (
        _buffer_addr(cloned.c),
        _buffer_addr(cloned.b.a),
        _buffer_addr(cloned.b.b),
    )
    assert source.c is not cloned.c
    assert source.b is not cloned.b
    assert source.b.a is not cloned.b.a
    assert source.b.b is not cloned.b.b
    for before, after in zip(orig_addrs, cloned_addrs):
        assert before != after
def test_batch_empty():
    """Check that ``Batch.empty`` / ``Batch.empty_`` reset data to 0 or None.

    Covers the out-of-place ``Batch.empty`` on an indexed Batch, the in-place
    ``empty_`` on an indexed view, the in-place ``empty_(index=...)`` on the
    whole Batch, and ``empty_`` on a Batch without any key.
    """
    b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
                        {'a': True, 'b': {'c': 3.0}}])
    b5 = Batch(b5_dict)
    b5[1] = Batch.empty(b5[0])
    assert np.allclose(b5.a, [False, False])
    assert np.allclose(b5.b.c, [2, 0])
    assert np.allclose(b5.b.d, [1, 0])
    # Use the builtin ``object`` / ``int`` instead of the ``np.object`` /
    # ``np.int`` aliases: the aliases were deprecated in NumPy 1.20 and
    # removed in NumPy 1.24, while the builtins work on every version.
    data = Batch(a=[False, True],
                 b={'c': np.array([2., 'st'], dtype=object),
                    'd': [1, None],
                    'e': [2., float('nan')]},
                 c=np.array([1, 3, 4], dtype=int),
                 t=torch.tensor([4, 5, 6, 7.]))
    data[-1] = Batch.empty(data[1])
    assert np.allclose(data.c, [1, 3, 0])
    assert np.allclose(data.a, [False, False])
    assert list(data.b.c) == [2.0, None]
    assert list(data.b.d) == [1, None]
    assert np.allclose(data.b.e, [2, 0])
    assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
    # Only the torch tensor is reset through the indexed Batch; this call
    # fails to reset a, b.c, b.d, b.e and c (checked by the asserts below).
    data[0].empty_()
    assert torch.allclose(data.t, torch.tensor([0., 5, 6, 0]))
    # Resetting through ``index`` on the Batch itself updates every key.
    data.empty_(index=0)
    assert np.allclose(data.c, [0, 3, 0])
    assert list(data.b.c) == [None, None]
    assert list(data.b.d) == [None, None]
    assert list(data.b.e) == [0, 0]
    # A Batch without any key stays empty.
    b0 = Batch()
    b0.empty_()
    assert b0.shape == []
def test_batch_numpy_compatibility():
batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]),
b=Batch(),
@ -246,4 +293,6 @@ if __name__ == '__main__':
test_batch_pickle()
test_batch_from_to_numpy_without_copy()
test_batch_numpy_compatibility()
test_batch_cat_and_stack_and_empty()
test_batch_cat_and_stack()
test_batch_copy()
test_batch_empty()

View File

@ -1,8 +1,8 @@
import torch
import copy
import pprint
import warnings
import numpy as np
from copy import deepcopy
from functools import reduce
from numbers import Number
from typing import Any, List, Tuple, Union, Iterator, Optional
@ -85,8 +85,13 @@ class Batch:
c: '2312312',
)
In short, you can define a :class:`Batch` with any key-value pair. The
current implementation of Tianshou typically use 7 reserved keys in
In short, you can define a :class:`Batch` with any key-value pair.
For NumPy arrays, only numeric and ``np.object`` data types are
supported; strings and other data types can instead be held in
``np.object`` arrays.
The current implementation of Tianshou typically uses 7 reserved keys in
:class:`~tianshou.data.Batch`:
* ``obs`` the observation of step :math:`t` ;
@ -252,7 +257,10 @@ class Batch:
batch_dict: Optional[Union[
dict, 'Batch', Tuple[Union[dict, 'Batch']],
List[Union[dict, 'Batch']], np.ndarray]] = None,
copy: bool = False,
**kwargs) -> None:
if copy:
batch_dict = deepcopy(batch_dict)
if _is_batch_set(batch_dict):
self.stack_(batch_dict)
elif isinstance(batch_dict, (dict, Batch)):
@ -264,7 +272,7 @@ class Batch:
v = np.array(v)
self.__dict__[k] = v
if len(kwargs) > 0:
self.__init__(kwargs)
self.__init__(kwargs, copy=copy)
def __setattr__(self, key: str, value: Any):
"""self[key] = value"""
@ -360,7 +368,7 @@ class Batch:
def __add__(self, other: Union['Batch', Number]):
"""Algebraic addition with another :class:`~tianshou.data.Batch`
instance out-of-place."""
return copy.deepcopy(self).__iadd__(other)
return deepcopy(self).__iadd__(other)
def __imul__(self, val: Number):
"""Algebraic multiplication with a scalar value in-place."""
@ -372,7 +380,7 @@ class Batch:
def __mul__(self, val: Number):
"""Algebraic multiplication with a scalar value out-of-place."""
return copy.deepcopy(self).__imul__(val)
return deepcopy(self).__imul__(val)
def __itruediv__(self, val: Number):
"""Algebraic division wibyth a scalar value in-place."""
@ -384,7 +392,7 @@ class Batch:
def __truediv__(self, val: Number):
"""Algebraic division wibyth a scalar value out-of-place."""
return copy.deepcopy(self).__itruediv__(val)
return deepcopy(self).__itruediv__(val)
def __repr__(self) -> str:
"""Return str(self)."""
@ -476,7 +484,7 @@ class Batch:
if v is None:
continue
if not hasattr(self, k) or self.__dict__[k] is None:
self.__dict__[k] = copy.deepcopy(v)
self.__dict__[k] = deepcopy(v)
elif isinstance(v, np.ndarray) and v.ndim > 0:
self.__dict__[k] = np.concatenate([self.__dict__[k], v])
elif isinstance(v, torch.Tensor):
@ -537,34 +545,45 @@ class Batch:
batch.stack_(batches, axis)
return batch
def empty_(self, index: Optional[Union[
        str, slice, int, np.integer, np.ndarray, List[int]]] = None
           ) -> 'Batch':
    """Fill this :class:`~tianshou.data.Batch` in-place with 0 or ``None``
    and return it.

    Nested :class:`~tianshou.data.Batch` values are reset recursively,
    tensors and numeric arrays are filled with 0, and object arrays are
    filled with ``None``.

    :param index: if specified, only the data selected by ``index`` is
        reset; with the default ``None`` every entry of each array or
        tensor is reset.

    :return: the Batch itself.
    """
    for k, v in self.items():
        if v is None:
            continue
        if isinstance(v, Batch):
            self.__dict__[k].empty_(index=index)
        elif isinstance(v, torch.Tensor):
            self.__dict__[k][index] = 0
        elif isinstance(v, np.ndarray):
            # Compare against the builtin ``object``: the ``np.object``
            # alias was removed in NumPy 1.24 and was the same type before.
            if v.dtype == object:
                self.__dict__[k][index] = None
            else:
                self.__dict__[k][index] = 0
        else:  # scalar value
            warnings.warn('You are calling Batch.empty on a NumPy scalar, '
                          'which may cause undefined behaviors.')
            if isinstance(v, (np.generic, Number)):
                self.__dict__[k] *= 0
                # inf * 0 and nan * 0 are nan, so force those back to 0
                if np.isnan(self.__dict__[k]):
                    self.__dict__[k] = 0
            else:
                self.__dict__[k] = None
    return self
@staticmethod
def empty(batch: 'Batch') -> 'Batch':
def empty(batch: 'Batch', index: Union[
str, slice, int, np.integer, np.ndarray, List[int]] = None
) -> 'Batch':
"""Return an empty :class:`~tianshou.data.Batch` object with 0 or
``None`` filled, the shape is the same as the given
:class:`~tianshou.data.Batch`.
"""
batch = Batch(**batch)
batch.empty_()
return batch
return deepcopy(batch).empty_(index)
def __len__(self) -> int:
"""Return len(self)."""

View File

@ -200,10 +200,15 @@ class Collector(object):
return
if isinstance(self.state, list):
self.state[id] = None
elif isinstance(self.state, (torch.Tensor, np.ndarray)):
self.state[id] *= 0
else: # Batch
self.state[id].empty_()
elif isinstance(self.state, torch.Tensor):
self.state[id].zero_()
elif isinstance(self.state, np.ndarray):
if isinstance(self.state.dtype == np.object):
self.state[id] = None
else:
self.state[id] = 0
elif isinstance(self.state, Batch):
self.state.empty_(id)
def collect(self,
n_step: int = 0,