change Batch.empty to in-place fill; add copy option for Batch construction (#110)
* in-place empty_ for Batch * change Batch.empty to in-place fill; add copy option for Batch construction * type signiture & remove shadow names for copy * add doc for data type (only support numbers and object data type) * add unit test for Batch copy * fix pep8 * add test case for Batch.empty * doc fix * fix pep8 * use object to test Batch * test commit * refact * change Batch(copy) testcase * minor fix Co-authored-by: Trinkle23897 <463003665@qq.com>
This commit is contained in:
parent
5b1373924e
commit
8913bf36b1
@ -116,7 +116,7 @@ def test_batch_over_batch():
|
|||||||
assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3)
|
assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3)
|
||||||
|
|
||||||
|
|
||||||
def test_batch_cat_and_stack_and_empty():
|
def test_batch_cat_and_stack():
|
||||||
b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
|
b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
|
||||||
b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
|
b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
|
||||||
b12_cat_out = Batch.cat((b1, b2))
|
b12_cat_out = Batch.cat((b1, b2))
|
||||||
@ -145,24 +145,6 @@ def test_batch_cat_and_stack_and_empty():
|
|||||||
assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
|
assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
|
||||||
assert b5.b.d[0] == b5_dict[0]['b']['d']
|
assert b5.b.d[0] == b5_dict[0]['b']['d']
|
||||||
assert b5.b.d[1] == 0.0
|
assert b5.b.d[1] == 0.0
|
||||||
b5[1] = Batch.empty(b5[0])
|
|
||||||
assert np.allclose(b5.a, [False, False])
|
|
||||||
assert np.allclose(b5.b.c, [2, 0])
|
|
||||||
assert np.allclose(b5.b.d, [1, 0])
|
|
||||||
data = Batch(a=[False, True],
|
|
||||||
b={'c': [2., 'st'], 'd': [1, None], 'e': [2., float('nan')]},
|
|
||||||
c=np.array([1, 3, 4], dtype=np.int),
|
|
||||||
t=torch.tensor([4, 5, 6, 7.]))
|
|
||||||
data[-1] = Batch.empty(data[1])
|
|
||||||
assert np.allclose(data.c, [1, 3, 0])
|
|
||||||
assert np.allclose(data.a, [False, False])
|
|
||||||
assert list(data.b.c) == ['2.0', '']
|
|
||||||
assert list(data.b.d) == [1, None]
|
|
||||||
assert np.allclose(data.b.e, [2, 0])
|
|
||||||
assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
|
|
||||||
b0 = Batch()
|
|
||||||
b0.empty_()
|
|
||||||
assert b0.shape == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_batch_over_batch_to_torch():
|
def test_batch_over_batch_to_torch():
|
||||||
@ -225,6 +207,71 @@ def test_batch_from_to_numpy_without_copy():
|
|||||||
assert c_mem_addr_new == c_mem_addr_orig
|
assert c_mem_addr_new == c_mem_addr_orig
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_copy():
|
||||||
|
batch = Batch(a=np.array([3, 4, 5]), b=np.array([4, 5, 6]))
|
||||||
|
batch2 = Batch({'c': np.array([6, 7, 8]), 'b': batch})
|
||||||
|
orig_c_addr = batch2.c.__array_interface__['data'][0]
|
||||||
|
orig_b_a_addr = batch2.b.a.__array_interface__['data'][0]
|
||||||
|
orig_b_b_addr = batch2.b.b.__array_interface__['data'][0]
|
||||||
|
# test with copy=False
|
||||||
|
batch3 = Batch(copy=False, **batch2)
|
||||||
|
curr_c_addr = batch3.c.__array_interface__['data'][0]
|
||||||
|
curr_b_a_addr = batch3.b.a.__array_interface__['data'][0]
|
||||||
|
curr_b_b_addr = batch3.b.b.__array_interface__['data'][0]
|
||||||
|
assert batch2.c is batch3.c
|
||||||
|
assert batch2.b is batch3.b
|
||||||
|
assert batch2.b.a is batch3.b.a
|
||||||
|
assert batch2.b.b is batch3.b.b
|
||||||
|
assert orig_c_addr == curr_c_addr
|
||||||
|
assert orig_b_a_addr == curr_b_a_addr
|
||||||
|
assert orig_b_b_addr == curr_b_b_addr
|
||||||
|
# test with copy=True
|
||||||
|
batch3 = Batch(copy=True, **batch2)
|
||||||
|
curr_c_addr = batch3.c.__array_interface__['data'][0]
|
||||||
|
curr_b_a_addr = batch3.b.a.__array_interface__['data'][0]
|
||||||
|
curr_b_b_addr = batch3.b.b.__array_interface__['data'][0]
|
||||||
|
assert batch2.c is not batch3.c
|
||||||
|
assert batch2.b is not batch3.b
|
||||||
|
assert batch2.b.a is not batch3.b.a
|
||||||
|
assert batch2.b.b is not batch3.b.b
|
||||||
|
assert orig_c_addr != curr_c_addr
|
||||||
|
assert orig_b_a_addr != curr_b_a_addr
|
||||||
|
assert orig_b_b_addr != curr_b_b_addr
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_empty():
|
||||||
|
b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
|
||||||
|
{'a': True, 'b': {'c': 3.0}}])
|
||||||
|
b5 = Batch(b5_dict)
|
||||||
|
b5[1] = Batch.empty(b5[0])
|
||||||
|
assert np.allclose(b5.a, [False, False])
|
||||||
|
assert np.allclose(b5.b.c, [2, 0])
|
||||||
|
assert np.allclose(b5.b.d, [1, 0])
|
||||||
|
data = Batch(a=[False, True],
|
||||||
|
b={'c': np.array([2., 'st'], dtype=np.object),
|
||||||
|
'd': [1, None],
|
||||||
|
'e': [2., float('nan')]},
|
||||||
|
c=np.array([1, 3, 4], dtype=np.int),
|
||||||
|
t=torch.tensor([4, 5, 6, 7.]))
|
||||||
|
data[-1] = Batch.empty(data[1])
|
||||||
|
assert np.allclose(data.c, [1, 3, 0])
|
||||||
|
assert np.allclose(data.a, [False, False])
|
||||||
|
assert list(data.b.c) == [2.0, None]
|
||||||
|
assert list(data.b.d) == [1, None]
|
||||||
|
assert np.allclose(data.b.e, [2, 0])
|
||||||
|
assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
|
||||||
|
data[0].empty_() # which will fail in a, b.c, b.d, b.e, c
|
||||||
|
assert torch.allclose(data.t, torch.tensor([0., 5, 6, 0]))
|
||||||
|
data.empty_(index=0)
|
||||||
|
assert np.allclose(data.c, [0, 3, 0])
|
||||||
|
assert list(data.b.c) == [None, None]
|
||||||
|
assert list(data.b.d) == [None, None]
|
||||||
|
assert list(data.b.e) == [0, 0]
|
||||||
|
b0 = Batch()
|
||||||
|
b0.empty_()
|
||||||
|
assert b0.shape == []
|
||||||
|
|
||||||
|
|
||||||
def test_batch_numpy_compatibility():
|
def test_batch_numpy_compatibility():
|
||||||
batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]),
|
batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]),
|
||||||
b=Batch(),
|
b=Batch(),
|
||||||
@ -246,4 +293,6 @@ if __name__ == '__main__':
|
|||||||
test_batch_pickle()
|
test_batch_pickle()
|
||||||
test_batch_from_to_numpy_without_copy()
|
test_batch_from_to_numpy_without_copy()
|
||||||
test_batch_numpy_compatibility()
|
test_batch_numpy_compatibility()
|
||||||
test_batch_cat_and_stack_and_empty()
|
test_batch_cat_and_stack()
|
||||||
|
test_batch_copy()
|
||||||
|
test_batch_empty()
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
import copy
|
|
||||||
import pprint
|
import pprint
|
||||||
import warnings
|
import warnings
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from copy import deepcopy
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
from typing import Any, List, Tuple, Union, Iterator, Optional
|
from typing import Any, List, Tuple, Union, Iterator, Optional
|
||||||
@ -85,8 +85,13 @@ class Batch:
|
|||||||
c: '2312312',
|
c: '2312312',
|
||||||
)
|
)
|
||||||
|
|
||||||
In short, you can define a :class:`Batch` with any key-value pair. The
|
In short, you can define a :class:`Batch` with any key-value pair.
|
||||||
current implementation of Tianshou typically use 7 reserved keys in
|
|
||||||
|
For Numpy arrays, only data types with ``np.object`` and numbers are
|
||||||
|
supported. For strings or other data types, however, they can be held
|
||||||
|
in ``np.object`` arrays.
|
||||||
|
|
||||||
|
The current implementation of Tianshou typically use 7 reserved keys in
|
||||||
:class:`~tianshou.data.Batch`:
|
:class:`~tianshou.data.Batch`:
|
||||||
|
|
||||||
* ``obs`` the observation of step :math:`t` ;
|
* ``obs`` the observation of step :math:`t` ;
|
||||||
@ -252,7 +257,10 @@ class Batch:
|
|||||||
batch_dict: Optional[Union[
|
batch_dict: Optional[Union[
|
||||||
dict, 'Batch', Tuple[Union[dict, 'Batch']],
|
dict, 'Batch', Tuple[Union[dict, 'Batch']],
|
||||||
List[Union[dict, 'Batch']], np.ndarray]] = None,
|
List[Union[dict, 'Batch']], np.ndarray]] = None,
|
||||||
|
copy: bool = False,
|
||||||
**kwargs) -> None:
|
**kwargs) -> None:
|
||||||
|
if copy:
|
||||||
|
batch_dict = deepcopy(batch_dict)
|
||||||
if _is_batch_set(batch_dict):
|
if _is_batch_set(batch_dict):
|
||||||
self.stack_(batch_dict)
|
self.stack_(batch_dict)
|
||||||
elif isinstance(batch_dict, (dict, Batch)):
|
elif isinstance(batch_dict, (dict, Batch)):
|
||||||
@ -264,7 +272,7 @@ class Batch:
|
|||||||
v = np.array(v)
|
v = np.array(v)
|
||||||
self.__dict__[k] = v
|
self.__dict__[k] = v
|
||||||
if len(kwargs) > 0:
|
if len(kwargs) > 0:
|
||||||
self.__init__(kwargs)
|
self.__init__(kwargs, copy=copy)
|
||||||
|
|
||||||
def __setattr__(self, key: str, value: Any):
|
def __setattr__(self, key: str, value: Any):
|
||||||
"""self[key] = value"""
|
"""self[key] = value"""
|
||||||
@ -360,7 +368,7 @@ class Batch:
|
|||||||
def __add__(self, other: Union['Batch', Number]):
|
def __add__(self, other: Union['Batch', Number]):
|
||||||
"""Algebraic addition with another :class:`~tianshou.data.Batch`
|
"""Algebraic addition with another :class:`~tianshou.data.Batch`
|
||||||
instance out-of-place."""
|
instance out-of-place."""
|
||||||
return copy.deepcopy(self).__iadd__(other)
|
return deepcopy(self).__iadd__(other)
|
||||||
|
|
||||||
def __imul__(self, val: Number):
|
def __imul__(self, val: Number):
|
||||||
"""Algebraic multiplication with a scalar value in-place."""
|
"""Algebraic multiplication with a scalar value in-place."""
|
||||||
@ -372,7 +380,7 @@ class Batch:
|
|||||||
|
|
||||||
def __mul__(self, val: Number):
|
def __mul__(self, val: Number):
|
||||||
"""Algebraic multiplication with a scalar value out-of-place."""
|
"""Algebraic multiplication with a scalar value out-of-place."""
|
||||||
return copy.deepcopy(self).__imul__(val)
|
return deepcopy(self).__imul__(val)
|
||||||
|
|
||||||
def __itruediv__(self, val: Number):
|
def __itruediv__(self, val: Number):
|
||||||
"""Algebraic division wibyth a scalar value in-place."""
|
"""Algebraic division wibyth a scalar value in-place."""
|
||||||
@ -384,7 +392,7 @@ class Batch:
|
|||||||
|
|
||||||
def __truediv__(self, val: Number):
|
def __truediv__(self, val: Number):
|
||||||
"""Algebraic division wibyth a scalar value out-of-place."""
|
"""Algebraic division wibyth a scalar value out-of-place."""
|
||||||
return copy.deepcopy(self).__itruediv__(val)
|
return deepcopy(self).__itruediv__(val)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""Return str(self)."""
|
"""Return str(self)."""
|
||||||
@ -476,7 +484,7 @@ class Batch:
|
|||||||
if v is None:
|
if v is None:
|
||||||
continue
|
continue
|
||||||
if not hasattr(self, k) or self.__dict__[k] is None:
|
if not hasattr(self, k) or self.__dict__[k] is None:
|
||||||
self.__dict__[k] = copy.deepcopy(v)
|
self.__dict__[k] = deepcopy(v)
|
||||||
elif isinstance(v, np.ndarray) and v.ndim > 0:
|
elif isinstance(v, np.ndarray) and v.ndim > 0:
|
||||||
self.__dict__[k] = np.concatenate([self.__dict__[k], v])
|
self.__dict__[k] = np.concatenate([self.__dict__[k], v])
|
||||||
elif isinstance(v, torch.Tensor):
|
elif isinstance(v, torch.Tensor):
|
||||||
@ -537,34 +545,45 @@ class Batch:
|
|||||||
batch.stack_(batches, axis)
|
batch.stack_(batches, axis)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def empty_(self) -> 'Batch':
|
def empty_(self, index: Union[
|
||||||
|
str, slice, int, np.integer, np.ndarray, List[int]] = None
|
||||||
|
) -> 'Batch':
|
||||||
"""Return an empty a :class:`~tianshou.data.Batch` object with 0 or
|
"""Return an empty a :class:`~tianshou.data.Batch` object with 0 or
|
||||||
``None`` filled.
|
``None`` filled. If ``index`` is specified, it will only reset the
|
||||||
|
specific indexed-data.
|
||||||
"""
|
"""
|
||||||
for k, v in self.items():
|
for k, v in self.items():
|
||||||
if v is None:
|
if v is None:
|
||||||
continue
|
continue
|
||||||
if isinstance(v, Batch):
|
if isinstance(v, Batch):
|
||||||
self.__dict__[k].empty_()
|
self.__dict__[k].empty_(index=index)
|
||||||
elif isinstance(v, np.ndarray) and v.dtype == np.object:
|
elif isinstance(v, torch.Tensor):
|
||||||
self.__dict__[k].fill(None)
|
self.__dict__[k][index] = 0
|
||||||
elif isinstance(v, torch.Tensor): # cannot apply fill_ directly
|
elif isinstance(v, np.ndarray):
|
||||||
self.__dict__[k] = torch.zeros_like(self.__dict__[k])
|
if v.dtype == np.object:
|
||||||
else: # np
|
self.__dict__[k][index] = None
|
||||||
self.__dict__[k] *= 0
|
else:
|
||||||
if hasattr(v, 'dtype') and v.dtype.kind in 'fc':
|
self.__dict__[k][index] = 0
|
||||||
self.__dict__[k] = np.nan_to_num(self.__dict__[k])
|
else: # scalar value
|
||||||
|
warnings.warn('You are calling Batch.empty on a NumPy scalar, '
|
||||||
|
'which may cause undefined behaviors.')
|
||||||
|
if isinstance(v, (np.generic, Number)):
|
||||||
|
self.__dict__[k] *= 0
|
||||||
|
if np.isnan(self.__dict__[k]):
|
||||||
|
self.__dict__[k] = 0
|
||||||
|
else:
|
||||||
|
self.__dict__[k] = None
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def empty(batch: 'Batch') -> 'Batch':
|
def empty(batch: 'Batch', index: Union[
|
||||||
|
str, slice, int, np.integer, np.ndarray, List[int]] = None
|
||||||
|
) -> 'Batch':
|
||||||
"""Return an empty :class:`~tianshou.data.Batch` object with 0 or
|
"""Return an empty :class:`~tianshou.data.Batch` object with 0 or
|
||||||
``None`` filled, the shape is the same as the given
|
``None`` filled, the shape is the same as the given
|
||||||
:class:`~tianshou.data.Batch`.
|
:class:`~tianshou.data.Batch`.
|
||||||
"""
|
"""
|
||||||
batch = Batch(**batch)
|
return deepcopy(batch).empty_(index)
|
||||||
batch.empty_()
|
|
||||||
return batch
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""Return len(self)."""
|
"""Return len(self)."""
|
||||||
|
@ -200,10 +200,15 @@ class Collector(object):
|
|||||||
return
|
return
|
||||||
if isinstance(self.state, list):
|
if isinstance(self.state, list):
|
||||||
self.state[id] = None
|
self.state[id] = None
|
||||||
elif isinstance(self.state, (torch.Tensor, np.ndarray)):
|
elif isinstance(self.state, torch.Tensor):
|
||||||
self.state[id] *= 0
|
self.state[id].zero_()
|
||||||
else: # Batch
|
elif isinstance(self.state, np.ndarray):
|
||||||
self.state[id].empty_()
|
if isinstance(self.state.dtype == np.object):
|
||||||
|
self.state[id] = None
|
||||||
|
else:
|
||||||
|
self.state[id] = 0
|
||||||
|
elif isinstance(self.state, Batch):
|
||||||
|
self.state.empty_(id)
|
||||||
|
|
||||||
def collect(self,
|
def collect(self,
|
||||||
n_step: int = 0,
|
n_step: int = 0,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user