Improve Batch (#128)

* minor polish

* improve and implement Batch.cat_

* bugfix for buffer.sample with field impt_weight

* restore the usage of a.cat_(b)

* fix 2 bugs in batch and add corresponding unittest

* code fix for update

* update is_empty to recognize empty over empty; bugfix for len

* bugfix for update and add testcase

* add testcase of update

* fix docs

* fix docs

* fix docs [ci skip]

* fix docs [ci skip]

Co-authored-by: Trinkle23897 <463003665@qq.com>
youkaichao 2020-07-11 21:46:01 +08:00 committed by n+e
parent 2564e989fb
commit affeec13de
4 changed files with 162 additions and 70 deletions
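For reference, the reworked Batch.update mirrors dict.update; a minimal sketch of the behavior, distilled from the new test cases below:

import numpy as np
from tianshou.data import Batch

b = Batch()
b.update(c=[3, 5])       # keyword-only update
assert np.allclose(b.c, [3, 5])
b.update({'a': 2}, a=3)  # kwargs overwrite the positional dict, as in dict.update
assert b.a == 3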

.github/workflows/pytest.yml

@@ -5,6 +5,7 @@ on: [push, pull_request]
 jobs:
   build:
     runs-on: ubuntu-latest
+    if: "!contains(github.event.head_commit.message, 'ci skip')"
     strategy:
       matrix:
         python-version: [3.6, 3.7, 3.8]
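With this guard, any push whose head commit message contains 'ci skip' (as in the doc-only commits listed above) no longer triggers the test matrix.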

test/base/test_batch.py

@@ -10,7 +10,17 @@ from tianshou.data import Batch, to_torch
 def test_batch():
     assert list(Batch()) == []
     assert Batch().is_empty()
+    assert Batch(b={'c': {}}).is_empty()
+    assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
     assert not Batch(a=[1, 2, 3]).is_empty()
+    b = Batch()
+    b.update()
+    assert b.is_empty()
+    b.update(c=[3, 5])
+    assert np.allclose(b.c, [3, 5])
+    # mimic the behavior of dict.update, where kwargs can overwrite keys
+    b.update({'a': 2}, a=3)
+    assert b.a == 3
     with pytest.raises(AssertionError):
         Batch({1: 2})
     batch = Batch(a=[torch.ones(3), torch.ones(3)])
@@ -86,6 +96,18 @@ def test_batch():
     assert batch3.a.d.f[0] == 5.0
     with pytest.raises(KeyError):
         batch3.a.d[0] = Batch(f=5.0, g=0.0)
+    # auto convert
+    batch4 = Batch(a=np.array(['a', 'b']))
+    assert batch4.a.dtype == np.object  # auto convert to np.object
+    batch4.update(a=np.array(['c', 'd']))
+    assert list(batch4.a) == ['c', 'd']
+    assert batch4.a.dtype == np.object  # auto convert to np.object
+    batch5 = Batch(a=np.array([{'index': 0}]))
+    assert isinstance(batch5.a, Batch)
+    assert np.allclose(batch5.a.index, [0])
+    batch5.b = np.array([{'index': 1}])
+    assert isinstance(batch5.b, Batch)
+    assert np.allclose(batch5.b.index, [1])


 def test_batch_over_batch():
@@ -100,6 +122,11 @@ def test_batch_over_batch():
     assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
     assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5])
     assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0])
+    batch2.update(batch2.b, six=[6, 6, 6])
+    assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
+    assert np.allclose(batch2.a, [3, 4, 5, 3, 4, 5])
+    assert np.allclose(batch2.b, [4, 5, 0, 4, 5, 0])
+    assert np.allclose(batch2.six, [6, 6, 6])
     d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
     batch3 = Batch(c=[6, 7, 8], b=d)
     batch3.cat_(Batch(c=[6, 7, 8], b=d))
@@ -124,18 +151,32 @@ def test_batch_over_batch():
 def test_batch_cat_and_stack():
     # test cat with compatible keys
     b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
     b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
-    b12_cat_out = Batch.cat((b1, b2))
+    b12_cat_out = Batch.cat([b1, b2])
     b12_cat_in = copy.deepcopy(b1)
     b12_cat_in.cat_(b2)
     assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
+    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
+    assert b12_cat_in.a.d.e.ndim == 1
     b12_stack = Batch.stack((b1, b2))
     assert isinstance(b12_stack.a.d.e, np.ndarray)
     assert b12_stack.a.d.e.ndim == 2
+    # test batch with incompatible keys
+    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
+    b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
+    test = Batch.cat([b1, b2])
+    ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
+                b=torch.cat([torch.zeros(3, 3), b2.b]),
+                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
+    assert np.allclose(test.a, ans.a)
+    assert torch.allclose(test.b, ans.b)
+    assert np.allclose(test.common.c, ans.common.c)
     b3 = Batch(a=np.zeros((3, 4)),
                b=torch.ones((2, 5)),
                c=Batch(d=[[1], [2]]))

tianshou/data/batch.py

@@ -259,8 +259,7 @@ class Batch:
                 v_ = None
                 if not isinstance(v, np.ndarray) and \
                         all(isinstance(e, torch.Tensor) for e in v):
-                    v_ = torch.stack(v)
-                    self.__dict__[k] = v_
+                    self.__dict__[k] = torch.stack(v)
                     continue
                 else:
                     v_ = np.asanyarray(v)
@@ -294,7 +293,8 @@ class Batch:
             value = np.array(value)
             if not issubclass(value.dtype.type, (np.bool_, np.number)):
                 value = value.astype(np.object)
-        elif isinstance(value, dict):
+        elif isinstance(value, dict) or isinstance(value, np.ndarray) \
+                and value.dtype == np.object and _is_batch_set(value):
             value = Batch(value)
         self.__dict__[key] = value
@@ -333,9 +333,8 @@ class Batch:
         else:
             raise IndexError("Cannot access item from empty Batch object.")

-    def __setitem__(
-            self,
-            index: Union[str, slice, int, np.integer, np.ndarray, List[int]],
+    def __setitem__(self, index: Union[
+            str, slice, int, np.integer, np.ndarray, List[int]],
             value: Any) -> None:
         """Assign value to self[index]."""
         if isinstance(value, np.ndarray):
@@ -454,10 +453,8 @@ class Batch:
         elif isinstance(v, Batch):
             v.to_numpy()

-    def to_torch(self,
-                 dtype: Optional[torch.dtype] = None,
-                 device: Union[str, int, torch.device] = 'cpu'
-                 ) -> None:
+    def to_torch(self, dtype: Optional[torch.dtype] = None,
+                 device: Union[str, int, torch.device] = 'cpu') -> None:
         """Change all numpy.ndarray to torch.Tensor. This is an in-place
         operation.
         """
@@ -473,66 +470,111 @@ class Batch:
                     v = v.type(dtype)
                 self.__dict__[k] = v
             elif isinstance(v, torch.Tensor):
-                if dtype is not None and v.dtype != dtype:
-                    must_update_tensor = True
-                elif v.device.type != device.type:
-                    must_update_tensor = True
-                elif device.index is not None and \
+                if dtype is not None and v.dtype != dtype or \
+                        v.device.type != device.type or \
+                        device.index is not None and \
                         device.index != v.device.index:
                     must_update_tensor = True
                 else:
                     must_update_tensor = False
                 if must_update_tensor:
                     if dtype is not None:
                         v = v.type(dtype)
                     self.__dict__[k] = v.to(device)
             elif isinstance(v, Batch):
                 v.to_torch(dtype, device)

     def append(self, batch: 'Batch') -> None:
         warnings.warn('Method :meth:`~tianshou.data.Batch.append` will be '
                       'removed soon, please use '
                       ':meth:`~tianshou.data.Batch.cat`')
         return self.cat_(batch)

-    def cat_(self, batch: 'Batch') -> None:
-        """Concatenate a :class:`~tianshou.data.Batch` object into current
-        batch.
+    def cat_(self,
+             batches: Union['Batch', List[Union[dict, 'Batch']]]) -> None:
+        """Concatenate a list of (or one) :class:`~tianshou.data.Batch` objects
+        into current batch.
         """
-        assert isinstance(batch, Batch), \
-            'Only Batch is allowed to be concatenated in-place!'
-        for k, v in batch.items():
-            if v is None:
-                continue
-            if not hasattr(self, k) or self.__dict__[k] is None:
-                self.__dict__[k] = deepcopy(v)
-            elif isinstance(v, np.ndarray) and v.ndim > 0:
-                self.__dict__[k] = np.concatenate([self.__dict__[k], v])
-            elif isinstance(v, torch.Tensor):
-                self.__dict__[k] = torch.cat([self.__dict__[k], v])
-            elif isinstance(v, Batch):
-                self.__dict__[k].cat_(v)
+        if isinstance(batches, Batch):
+            batches = [batches]
+        if len(batches) == 0:
+            return
+        batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
+        if len(self.__dict__) > 0:
+            batches = [self] + list(batches)
+        # partial keys will be padded by zeros
+        # with the shape of [len, rest_shape]
+        lens = [len(x) for x in batches]
+        keys_map = list(map(lambda e: set(e.keys()), batches))
+        keys_shared = set.intersection(*keys_map)
+        values_shared = [
+            [e[k] for e in batches] for k in keys_shared]
+        _assert_type_keys(keys_shared)
+        for k, v in zip(keys_shared, values_shared):
+            if all(isinstance(e, (dict, Batch)) for e in v):
+                self.__dict__[k] = Batch.cat(v)
+            elif all(isinstance(e, torch.Tensor) for e in v):
+                self.__dict__[k] = torch.cat(v)
             else:
-                s = 'No support for method "cat" with type '\
-                    f'{type(v)} in class Batch.'
-                raise TypeError(s)
+                v = np.concatenate(v)
+                if not issubclass(v.dtype.type, (np.bool_, np.number)):
+                    v = v.astype(np.object)
+                self.__dict__[k] = v
+        keys_partial = set.union(*keys_map) - keys_shared
+        _assert_type_keys(keys_partial)
+        for k in keys_partial:
+            is_dict = False
+            value = None
+            for i, e in enumerate(batches):
+                val = e.get(k, None)
+                if val is not None:
+                    if isinstance(val, (dict, Batch)):
+                        is_dict = True
+                    else:  # np.ndarray or torch.Tensor
+                        value = val
+                    break
+            if is_dict:
+                self.__dict__[k] = Batch.cat(
+                    [e.get(k, Batch()) for e in batches])
+            else:
+                if isinstance(value, np.ndarray):
+                    arrs = []
+                    for i, e in enumerate(batches):
+                        shape = [lens[i]] + list(value.shape[1:])
+                        pad = np.zeros(shape, dtype=value.dtype)
+                        arrs.append(e.get(k, pad))
+                    self.__dict__[k] = np.concatenate(arrs)
+                elif isinstance(value, torch.Tensor):
+                    arrs = []
+                    for i, e in enumerate(batches):
+                        shape = [lens[i]] + list(value.shape[1:])
+                        pad = torch.zeros(shape,
+                                          dtype=value.dtype,
+                                          device=value.device)
+                        arrs.append(e.get(k, pad))
+                    self.__dict__[k] = torch.cat(arrs)
+                else:
+                    raise TypeError(
+                        f"cannot cat value with type {type(value)}, we only "
+                        "support dict, Batch, np.ndarray, and torch.Tensor")
     @staticmethod
     def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':
-        """Concatenate a list of :class:`~tianshou.data.Batch` object into a single
-        new batch.
+        """Concatenate a list of :class:`~tianshou.data.Batch` object into a
+        single new batch. For keys that are not shared across all batches,
+        batches that do not have these keys will be padded by zeros with
+        appropriate shapes. E.g.
+        ::
+
+            >>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
+            >>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
+            >>> c = Batch.cat([a, b])
+            >>> c.a.shape
+            (7, 4)
+            >>> c.b.shape
+            (7, 3)
+            >>> c.common.c.shape
+            (7, 5)
         """
         batch = Batch()
-        for batch_ in batches:
-            if isinstance(batch_, dict):
-                batch_ = Batch(batch_)
-            batch.cat_(batch_)
+        batch.cat_(batches)
         return batch

     def stack_(self,
                batches: List[Union[dict, 'Batch']],
                axis: int = 0) -> None:
-        """Stack a :class:`~tianshou.data.Batch` object into current batch.
+        """Stack a list of :class:`~tianshou.data.Batch` object into current
+        batch.
         """
         if len(self.__dict__) > 0:
             batches = [self] + list(batches)
@@ -566,8 +608,8 @@ class Batch:
     @staticmethod
     def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch':
-        """Stack a :class:`~tianshou.data.Batch` object into a single new
-        batch.
+        """Stack a list of :class:`~tianshou.data.Batch` object into a single
+        new batch.
         """
         batch = Batch()
         batch.stack_(batches, axis)
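A short usage sketch of the cat/stack distinction documented above, distilled from the updated tests (assumes this commit's tianshou): cat joins along the existing first axis, while stack adds a new one.

import numpy as np
from tianshou.data import Batch

b1, b2 = Batch(a=np.zeros((3, 4))), Batch(a=np.ones((3, 4)))
assert Batch.cat([b1, b2]).a.shape == (6, 4)       # concatenate along axis 0
assert Batch.stack([b1, b2]).a.shape == (2, 3, 4)  # new leading axis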
@@ -611,11 +653,24 @@ class Batch:
         """
         return deepcopy(batch).empty_(index)

+    def update(self, batch: Optional[Union[dict, 'Batch']] = None,
+               **kwargs) -> None:
+        """Update this batch from another dict/Batch."""
+        if batch is None:
+            self.update(kwargs)
+            return
+        if isinstance(batch, dict):
+            batch = Batch(batch)
+        for k, v in batch.items():
+            self.__dict__[k] = v
+        if kwargs:
+            self.update(kwargs)
+
     def __len__(self) -> int:
         """Return len(self)."""
         r = []
         for v in self.__dict__.values():
-            if isinstance(v, Batch) and len(v.__dict__) == 0:
+            if isinstance(v, Batch) and v.is_empty():
                 continue
             elif hasattr(v, '__len__') and (not isinstance(
                     v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
@@ -627,7 +682,9 @@ class Batch:
         return min(r)

     def is_empty(self):
-        return len(self.__dict__.keys()) == 0
+        return not any(
+            not x.is_empty() if isinstance(x, Batch)
+            else hasattr(x, '__len__') and len(x) > 0 for x in self.values())

     @property
     def shape(self) -> List[int]:
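The reworked is_empty/__len__ pair in a short sketch (assertions taken from the tests in this commit): a Batch is empty iff every value is an empty Batch or has zero length, and __len__ skips empty members instead of counting them.

from tianshou.data import Batch

assert Batch().is_empty()
assert Batch(b={'c': {}}).is_empty()              # empty over empty
assert not Batch(a=[1, 2, 3]).is_empty()
assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3  # empty member is skipped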

tianshou/data/buffer.py

@@ -108,8 +108,7 @@ class ReplayBuffer:
         super().__init__()
         self._maxsize = size
         self._stack = stack_num
-        assert stack_num != 1, \
-            'stack_num should greater than 1'
+        assert stack_num != 1, 'stack_num should greater than 1'
         self._avail = sample_avail and stack_num > 1
         self._avail_index = []
         self._save_s_ = not ignore_obs_next
@@ -136,12 +135,11 @@ class ReplayBuffer:
         except KeyError:
             self._meta.__dict__[name] = _create_value(inst, self._maxsize)
             value = self._meta.__dict__[name]
-        if isinstance(inst, np.ndarray) and \
-                value.shape[1:] != inst.shape:
+        if isinstance(inst, np.ndarray) and value.shape[1:] != inst.shape:
             raise ValueError(
                 "Cannot add data to a buffer with different shape, key: "
-                f"{name}, expect shape: {value.shape[1:]}"
-                f", given shape: {inst.shape}.")
+                f"{name}, expect shape: {value.shape[1:]}, "
+                f"given shape: {inst.shape}.")
         try:
             value[self._index] = inst
         except KeyError:
@@ -357,7 +355,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         self._weight_sum = 0.0
         self._amortization_freq = 50
         self._replace = replace
-        self._meta.__dict__['weight'] = np.zeros(size, dtype=np.float64)
+        self._meta.weight = np.zeros(size, dtype=np.float64)

     def add(self,
             obs: Union[dict, np.ndarray],
@@ -372,7 +370,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         """Add a batch of data into replay buffer."""
         # we have to sacrifice some convenience for speed
         self._weight_sum += np.abs(weight) ** self._alpha - \
-            self._meta.__dict__['weight'][self._index]
+            self._meta.weight[self._index]
         self._add_to_buffer('weight', np.abs(weight) ** self._alpha)
         super().add(obs, act, rew, done, obs_next, info, policy)
@@ -410,14 +408,9 @@ class PrioritizedReplayBuffer(ReplayBuffer):
                 f"batch_size should be less than {len(self)}, \
                 or set replace=True")
         batch = self[indice]
-        impt_weight = Batch(
-            impt_weight=(self._size * p) ** (-self._beta))
-        batch.cat_(impt_weight)
+        batch["impt_weight"] = (self._size * p) ** (-self._beta)
         return batch, indice

-    def reset(self) -> None:
-        super().reset()
-
     def update_weight(self, indice: Union[slice, np.ndarray],
                       new_weight: np.ndarray) -> None:
         """Update priority weight by indice in this buffer.