change Batch.empty to in-place fill; add copy option for Batch construction (#110)
* in-place empty_ for Batch * change Batch.empty to in-place fill; add copy option for Batch construction * type signiture & remove shadow names for copy * add doc for data type (only support numbers and object data type) * add unit test for Batch copy * fix pep8 * add test case for Batch.empty * doc fix * fix pep8 * use object to test Batch * test commit * refact * change Batch(copy) testcase * minor fix Co-authored-by: Trinkle23897 <463003665@qq.com>
This commit is contained in:
parent
5b1373924e
commit
8913bf36b1
@ -116,7 +116,7 @@ def test_batch_over_batch():
|
||||
assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3)
|
||||
|
||||
|
||||
def test_batch_cat_and_stack_and_empty():
|
||||
def test_batch_cat_and_stack():
|
||||
b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
|
||||
b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
|
||||
b12_cat_out = Batch.cat((b1, b2))
|
||||
@ -145,24 +145,6 @@ def test_batch_cat_and_stack_and_empty():
|
||||
assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
|
||||
assert b5.b.d[0] == b5_dict[0]['b']['d']
|
||||
assert b5.b.d[1] == 0.0
|
||||
b5[1] = Batch.empty(b5[0])
|
||||
assert np.allclose(b5.a, [False, False])
|
||||
assert np.allclose(b5.b.c, [2, 0])
|
||||
assert np.allclose(b5.b.d, [1, 0])
|
||||
data = Batch(a=[False, True],
|
||||
b={'c': [2., 'st'], 'd': [1, None], 'e': [2., float('nan')]},
|
||||
c=np.array([1, 3, 4], dtype=np.int),
|
||||
t=torch.tensor([4, 5, 6, 7.]))
|
||||
data[-1] = Batch.empty(data[1])
|
||||
assert np.allclose(data.c, [1, 3, 0])
|
||||
assert np.allclose(data.a, [False, False])
|
||||
assert list(data.b.c) == ['2.0', '']
|
||||
assert list(data.b.d) == [1, None]
|
||||
assert np.allclose(data.b.e, [2, 0])
|
||||
assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
|
||||
b0 = Batch()
|
||||
b0.empty_()
|
||||
assert b0.shape == []
|
||||
|
||||
|
||||
def test_batch_over_batch_to_torch():
|
||||
@ -225,6 +207,71 @@ def test_batch_from_to_numpy_without_copy():
|
||||
assert c_mem_addr_new == c_mem_addr_orig
|
||||
|
||||
|
||||
def test_batch_copy():
|
||||
batch = Batch(a=np.array([3, 4, 5]), b=np.array([4, 5, 6]))
|
||||
batch2 = Batch({'c': np.array([6, 7, 8]), 'b': batch})
|
||||
orig_c_addr = batch2.c.__array_interface__['data'][0]
|
||||
orig_b_a_addr = batch2.b.a.__array_interface__['data'][0]
|
||||
orig_b_b_addr = batch2.b.b.__array_interface__['data'][0]
|
||||
# test with copy=False
|
||||
batch3 = Batch(copy=False, **batch2)
|
||||
curr_c_addr = batch3.c.__array_interface__['data'][0]
|
||||
curr_b_a_addr = batch3.b.a.__array_interface__['data'][0]
|
||||
curr_b_b_addr = batch3.b.b.__array_interface__['data'][0]
|
||||
assert batch2.c is batch3.c
|
||||
assert batch2.b is batch3.b
|
||||
assert batch2.b.a is batch3.b.a
|
||||
assert batch2.b.b is batch3.b.b
|
||||
assert orig_c_addr == curr_c_addr
|
||||
assert orig_b_a_addr == curr_b_a_addr
|
||||
assert orig_b_b_addr == curr_b_b_addr
|
||||
# test with copy=True
|
||||
batch3 = Batch(copy=True, **batch2)
|
||||
curr_c_addr = batch3.c.__array_interface__['data'][0]
|
||||
curr_b_a_addr = batch3.b.a.__array_interface__['data'][0]
|
||||
curr_b_b_addr = batch3.b.b.__array_interface__['data'][0]
|
||||
assert batch2.c is not batch3.c
|
||||
assert batch2.b is not batch3.b
|
||||
assert batch2.b.a is not batch3.b.a
|
||||
assert batch2.b.b is not batch3.b.b
|
||||
assert orig_c_addr != curr_c_addr
|
||||
assert orig_b_a_addr != curr_b_a_addr
|
||||
assert orig_b_b_addr != curr_b_b_addr
|
||||
|
||||
|
||||
def test_batch_empty():
|
||||
b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
|
||||
{'a': True, 'b': {'c': 3.0}}])
|
||||
b5 = Batch(b5_dict)
|
||||
b5[1] = Batch.empty(b5[0])
|
||||
assert np.allclose(b5.a, [False, False])
|
||||
assert np.allclose(b5.b.c, [2, 0])
|
||||
assert np.allclose(b5.b.d, [1, 0])
|
||||
data = Batch(a=[False, True],
|
||||
b={'c': np.array([2., 'st'], dtype=np.object),
|
||||
'd': [1, None],
|
||||
'e': [2., float('nan')]},
|
||||
c=np.array([1, 3, 4], dtype=np.int),
|
||||
t=torch.tensor([4, 5, 6, 7.]))
|
||||
data[-1] = Batch.empty(data[1])
|
||||
assert np.allclose(data.c, [1, 3, 0])
|
||||
assert np.allclose(data.a, [False, False])
|
||||
assert list(data.b.c) == [2.0, None]
|
||||
assert list(data.b.d) == [1, None]
|
||||
assert np.allclose(data.b.e, [2, 0])
|
||||
assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
|
||||
data[0].empty_() # which will fail in a, b.c, b.d, b.e, c
|
||||
assert torch.allclose(data.t, torch.tensor([0., 5, 6, 0]))
|
||||
data.empty_(index=0)
|
||||
assert np.allclose(data.c, [0, 3, 0])
|
||||
assert list(data.b.c) == [None, None]
|
||||
assert list(data.b.d) == [None, None]
|
||||
assert list(data.b.e) == [0, 0]
|
||||
b0 = Batch()
|
||||
b0.empty_()
|
||||
assert b0.shape == []
|
||||
|
||||
|
||||
def test_batch_numpy_compatibility():
|
||||
batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]),
|
||||
b=Batch(),
|
||||
@ -246,4 +293,6 @@ if __name__ == '__main__':
|
||||
test_batch_pickle()
|
||||
test_batch_from_to_numpy_without_copy()
|
||||
test_batch_numpy_compatibility()
|
||||
test_batch_cat_and_stack_and_empty()
|
||||
test_batch_cat_and_stack()
|
||||
test_batch_copy()
|
||||
test_batch_empty()
|
||||
|
@ -1,8 +1,8 @@
|
||||
import torch
|
||||
import copy
|
||||
import pprint
|
||||
import warnings
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from functools import reduce
|
||||
from numbers import Number
|
||||
from typing import Any, List, Tuple, Union, Iterator, Optional
|
||||
@ -85,8 +85,13 @@ class Batch:
|
||||
c: '2312312',
|
||||
)
|
||||
|
||||
In short, you can define a :class:`Batch` with any key-value pair. The
|
||||
current implementation of Tianshou typically use 7 reserved keys in
|
||||
In short, you can define a :class:`Batch` with any key-value pair.
|
||||
|
||||
For Numpy arrays, only data types with ``np.object`` and numbers are
|
||||
supported. For strings or other data types, however, they can be held
|
||||
in ``np.object`` arrays.
|
||||
|
||||
The current implementation of Tianshou typically use 7 reserved keys in
|
||||
:class:`~tianshou.data.Batch`:
|
||||
|
||||
* ``obs`` the observation of step :math:`t` ;
|
||||
@ -252,7 +257,10 @@ class Batch:
|
||||
batch_dict: Optional[Union[
|
||||
dict, 'Batch', Tuple[Union[dict, 'Batch']],
|
||||
List[Union[dict, 'Batch']], np.ndarray]] = None,
|
||||
copy: bool = False,
|
||||
**kwargs) -> None:
|
||||
if copy:
|
||||
batch_dict = deepcopy(batch_dict)
|
||||
if _is_batch_set(batch_dict):
|
||||
self.stack_(batch_dict)
|
||||
elif isinstance(batch_dict, (dict, Batch)):
|
||||
@ -264,7 +272,7 @@ class Batch:
|
||||
v = np.array(v)
|
||||
self.__dict__[k] = v
|
||||
if len(kwargs) > 0:
|
||||
self.__init__(kwargs)
|
||||
self.__init__(kwargs, copy=copy)
|
||||
|
||||
def __setattr__(self, key: str, value: Any):
|
||||
"""self[key] = value"""
|
||||
@ -360,7 +368,7 @@ class Batch:
|
||||
def __add__(self, other: Union['Batch', Number]):
|
||||
"""Algebraic addition with another :class:`~tianshou.data.Batch`
|
||||
instance out-of-place."""
|
||||
return copy.deepcopy(self).__iadd__(other)
|
||||
return deepcopy(self).__iadd__(other)
|
||||
|
||||
def __imul__(self, val: Number):
|
||||
"""Algebraic multiplication with a scalar value in-place."""
|
||||
@ -372,7 +380,7 @@ class Batch:
|
||||
|
||||
def __mul__(self, val: Number):
|
||||
"""Algebraic multiplication with a scalar value out-of-place."""
|
||||
return copy.deepcopy(self).__imul__(val)
|
||||
return deepcopy(self).__imul__(val)
|
||||
|
||||
def __itruediv__(self, val: Number):
|
||||
"""Algebraic division wibyth a scalar value in-place."""
|
||||
@ -384,7 +392,7 @@ class Batch:
|
||||
|
||||
def __truediv__(self, val: Number):
|
||||
"""Algebraic division wibyth a scalar value out-of-place."""
|
||||
return copy.deepcopy(self).__itruediv__(val)
|
||||
return deepcopy(self).__itruediv__(val)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return str(self)."""
|
||||
@ -476,7 +484,7 @@ class Batch:
|
||||
if v is None:
|
||||
continue
|
||||
if not hasattr(self, k) or self.__dict__[k] is None:
|
||||
self.__dict__[k] = copy.deepcopy(v)
|
||||
self.__dict__[k] = deepcopy(v)
|
||||
elif isinstance(v, np.ndarray) and v.ndim > 0:
|
||||
self.__dict__[k] = np.concatenate([self.__dict__[k], v])
|
||||
elif isinstance(v, torch.Tensor):
|
||||
@ -537,34 +545,45 @@ class Batch:
|
||||
batch.stack_(batches, axis)
|
||||
return batch
|
||||
|
||||
def empty_(self) -> 'Batch':
|
||||
def empty_(self, index: Union[
|
||||
str, slice, int, np.integer, np.ndarray, List[int]] = None
|
||||
) -> 'Batch':
|
||||
"""Return an empty a :class:`~tianshou.data.Batch` object with 0 or
|
||||
``None`` filled.
|
||||
``None`` filled. If ``index`` is specified, it will only reset the
|
||||
specific indexed-data.
|
||||
"""
|
||||
for k, v in self.items():
|
||||
if v is None:
|
||||
continue
|
||||
if isinstance(v, Batch):
|
||||
self.__dict__[k].empty_()
|
||||
elif isinstance(v, np.ndarray) and v.dtype == np.object:
|
||||
self.__dict__[k].fill(None)
|
||||
elif isinstance(v, torch.Tensor): # cannot apply fill_ directly
|
||||
self.__dict__[k] = torch.zeros_like(self.__dict__[k])
|
||||
else: # np
|
||||
self.__dict__[k] *= 0
|
||||
if hasattr(v, 'dtype') and v.dtype.kind in 'fc':
|
||||
self.__dict__[k] = np.nan_to_num(self.__dict__[k])
|
||||
self.__dict__[k].empty_(index=index)
|
||||
elif isinstance(v, torch.Tensor):
|
||||
self.__dict__[k][index] = 0
|
||||
elif isinstance(v, np.ndarray):
|
||||
if v.dtype == np.object:
|
||||
self.__dict__[k][index] = None
|
||||
else:
|
||||
self.__dict__[k][index] = 0
|
||||
else: # scalar value
|
||||
warnings.warn('You are calling Batch.empty on a NumPy scalar, '
|
||||
'which may cause undefined behaviors.')
|
||||
if isinstance(v, (np.generic, Number)):
|
||||
self.__dict__[k] *= 0
|
||||
if np.isnan(self.__dict__[k]):
|
||||
self.__dict__[k] = 0
|
||||
else:
|
||||
self.__dict__[k] = None
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def empty(batch: 'Batch') -> 'Batch':
|
||||
def empty(batch: 'Batch', index: Union[
|
||||
str, slice, int, np.integer, np.ndarray, List[int]] = None
|
||||
) -> 'Batch':
|
||||
"""Return an empty :class:`~tianshou.data.Batch` object with 0 or
|
||||
``None`` filled, the shape is the same as the given
|
||||
:class:`~tianshou.data.Batch`.
|
||||
"""
|
||||
batch = Batch(**batch)
|
||||
batch.empty_()
|
||||
return batch
|
||||
return deepcopy(batch).empty_(index)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return len(self)."""
|
||||
|
@ -200,10 +200,15 @@ class Collector(object):
|
||||
return
|
||||
if isinstance(self.state, list):
|
||||
self.state[id] = None
|
||||
elif isinstance(self.state, (torch.Tensor, np.ndarray)):
|
||||
self.state[id] *= 0
|
||||
else: # Batch
|
||||
self.state[id].empty_()
|
||||
elif isinstance(self.state, torch.Tensor):
|
||||
self.state[id].zero_()
|
||||
elif isinstance(self.state, np.ndarray):
|
||||
if isinstance(self.state.dtype == np.object):
|
||||
self.state[id] = None
|
||||
else:
|
||||
self.state[id] = 0
|
||||
elif isinstance(self.state, Batch):
|
||||
self.state.empty_(id)
|
||||
|
||||
def collect(self,
|
||||
n_step: int = 0,
|
||||
|
Loading…
x
Reference in New Issue
Block a user