Improve Batch (#128)
* minor polish * improve and implement Batch.cat_ * bugfix for buffer.sample with field impt_weight * restore the usage of a.cat_(b) * fix 2 bugs in batch and add corresponding unittest * code fix for update * update is_empty to recognize empty over empty; bugfix for len * bugfix for update and add testcase * add testcase of update * fix docs * fix docs * fix docs [ci skip] * fix docs [ci skip] Co-authored-by: Trinkle23897 <463003665@qq.com>
This commit is contained in:
parent
2564e989fb
commit
affeec13de
1
.github/workflows/pytest.yml
vendored
1
.github/workflows/pytest.yml
vendored
@ -5,6 +5,7 @@ on: [push, pull_request]
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
if: "!contains(github.event.head_commit.message, 'ci skip')"
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
|
@ -10,7 +10,17 @@ from tianshou.data import Batch, to_torch
|
||||
def test_batch():
|
||||
assert list(Batch()) == []
|
||||
assert Batch().is_empty()
|
||||
assert Batch(b={'c': {}}).is_empty()
|
||||
assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
|
||||
assert not Batch(a=[1, 2, 3]).is_empty()
|
||||
b = Batch()
|
||||
b.update()
|
||||
assert b.is_empty()
|
||||
b.update(c=[3, 5])
|
||||
assert np.allclose(b.c, [3, 5])
|
||||
# mimic the behavior of dict.update, where kwargs can overwrite keys
|
||||
b.update({'a': 2}, a=3)
|
||||
assert b.a == 3
|
||||
with pytest.raises(AssertionError):
|
||||
Batch({1: 2})
|
||||
batch = Batch(a=[torch.ones(3), torch.ones(3)])
|
||||
@ -86,6 +96,18 @@ def test_batch():
|
||||
assert batch3.a.d.f[0] == 5.0
|
||||
with pytest.raises(KeyError):
|
||||
batch3.a.d[0] = Batch(f=5.0, g=0.0)
|
||||
# auto convert
|
||||
batch4 = Batch(a=np.array(['a', 'b']))
|
||||
assert batch4.a.dtype == np.object # auto convert to np.object
|
||||
batch4.update(a=np.array(['c', 'd']))
|
||||
assert list(batch4.a) == ['c', 'd']
|
||||
assert batch4.a.dtype == np.object # auto convert to np.object
|
||||
batch5 = Batch(a=np.array([{'index': 0}]))
|
||||
assert isinstance(batch5.a, Batch)
|
||||
assert np.allclose(batch5.a.index, [0])
|
||||
batch5.b = np.array([{'index': 1}])
|
||||
assert isinstance(batch5.b, Batch)
|
||||
assert np.allclose(batch5.b.index, [1])
|
||||
|
||||
|
||||
def test_batch_over_batch():
|
||||
@ -100,6 +122,11 @@ def test_batch_over_batch():
|
||||
assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
|
||||
assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5])
|
||||
assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0])
|
||||
batch2.update(batch2.b, six=[6, 6, 6])
|
||||
assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
|
||||
assert np.allclose(batch2.a, [3, 4, 5, 3, 4, 5])
|
||||
assert np.allclose(batch2.b, [4, 5, 0, 4, 5, 0])
|
||||
assert np.allclose(batch2.six, [6, 6, 6])
|
||||
d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
|
||||
batch3 = Batch(c=[6, 7, 8], b=d)
|
||||
batch3.cat_(Batch(c=[6, 7, 8], b=d))
|
||||
@ -124,18 +151,32 @@ def test_batch_over_batch():
|
||||
|
||||
|
||||
def test_batch_cat_and_stack():
|
||||
# test cat with compatible keys
|
||||
b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
|
||||
b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
|
||||
b12_cat_out = Batch.cat((b1, b2))
|
||||
b12_cat_out = Batch.cat([b1, b2])
|
||||
b12_cat_in = copy.deepcopy(b1)
|
||||
b12_cat_in.cat_(b2)
|
||||
assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
|
||||
assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
|
||||
assert isinstance(b12_cat_in.a.d.e, np.ndarray)
|
||||
assert b12_cat_in.a.d.e.ndim == 1
|
||||
|
||||
b12_stack = Batch.stack((b1, b2))
|
||||
assert isinstance(b12_stack.a.d.e, np.ndarray)
|
||||
assert b12_stack.a.d.e.ndim == 2
|
||||
|
||||
# test batch with incompatible keys
|
||||
b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
|
||||
b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
|
||||
test = Batch.cat([b1, b2])
|
||||
ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
|
||||
b=torch.cat([torch.zeros(3, 3), b2.b]),
|
||||
common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
|
||||
assert np.allclose(test.a, ans.a)
|
||||
assert torch.allclose(test.b, ans.b)
|
||||
assert np.allclose(test.common.c, ans.common.c)
|
||||
|
||||
b3 = Batch(a=np.zeros((3, 4)),
|
||||
b=torch.ones((2, 5)),
|
||||
c=Batch(d=[[1], [2]]))
|
||||
|
@ -259,8 +259,7 @@ class Batch:
|
||||
v_ = None
|
||||
if not isinstance(v, np.ndarray) and \
|
||||
all(isinstance(e, torch.Tensor) for e in v):
|
||||
v_ = torch.stack(v)
|
||||
self.__dict__[k] = v_
|
||||
self.__dict__[k] = torch.stack(v)
|
||||
continue
|
||||
else:
|
||||
v_ = np.asanyarray(v)
|
||||
@ -294,7 +293,8 @@ class Batch:
|
||||
value = np.array(value)
|
||||
if not issubclass(value.dtype.type, (np.bool_, np.number)):
|
||||
value = value.astype(np.object)
|
||||
elif isinstance(value, dict):
|
||||
elif isinstance(value, dict) or isinstance(value, np.ndarray) \
|
||||
and value.dtype == np.object and _is_batch_set(value):
|
||||
value = Batch(value)
|
||||
self.__dict__[key] = value
|
||||
|
||||
@ -333,9 +333,8 @@ class Batch:
|
||||
else:
|
||||
raise IndexError("Cannot access item from empty Batch object.")
|
||||
|
||||
def __setitem__(
|
||||
self,
|
||||
index: Union[str, slice, int, np.integer, np.ndarray, List[int]],
|
||||
def __setitem__(self, index: Union[
|
||||
str, slice, int, np.integer, np.ndarray, List[int]],
|
||||
value: Any) -> None:
|
||||
"""Assign value to self[index]."""
|
||||
if isinstance(value, np.ndarray):
|
||||
@ -454,10 +453,8 @@ class Batch:
|
||||
elif isinstance(v, Batch):
|
||||
v.to_numpy()
|
||||
|
||||
def to_torch(self,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Union[str, int, torch.device] = 'cpu'
|
||||
) -> None:
|
||||
def to_torch(self, dtype: Optional[torch.dtype] = None,
|
||||
device: Union[str, int, torch.device] = 'cpu') -> None:
|
||||
"""Change all numpy.ndarray to torch.Tensor. This is an in-place
|
||||
operation.
|
||||
"""
|
||||
@ -473,66 +470,111 @@ class Batch:
|
||||
v = v.type(dtype)
|
||||
self.__dict__[k] = v
|
||||
elif isinstance(v, torch.Tensor):
|
||||
if dtype is not None and v.dtype != dtype:
|
||||
must_update_tensor = True
|
||||
elif v.device.type != device.type:
|
||||
must_update_tensor = True
|
||||
elif device.index is not None and \
|
||||
if dtype is not None and v.dtype != dtype or \
|
||||
v.device.type != device.type or \
|
||||
device.index is not None and \
|
||||
device.index != v.device.index:
|
||||
must_update_tensor = True
|
||||
else:
|
||||
must_update_tensor = False
|
||||
if must_update_tensor:
|
||||
if dtype is not None:
|
||||
v = v.type(dtype)
|
||||
self.__dict__[k] = v.to(device)
|
||||
elif isinstance(v, Batch):
|
||||
v.to_torch(dtype, device)
|
||||
|
||||
def append(self, batch: 'Batch') -> None:
|
||||
warnings.warn('Method :meth:`~tianshou.data.Batch.append` will be '
|
||||
'removed soon, please use '
|
||||
':meth:`~tianshou.data.Batch.cat`')
|
||||
return self.cat_(batch)
|
||||
|
||||
def cat_(self, batch: 'Batch') -> None:
|
||||
"""Concatenate a :class:`~tianshou.data.Batch` object into current
|
||||
batch.
|
||||
def cat_(self,
|
||||
batches: Union['Batch', List[Union[dict, 'Batch']]]) -> None:
|
||||
"""Concatenate a list of (or one) :class:`~tianshou.data.Batch` objects
|
||||
into current batch.
|
||||
"""
|
||||
assert isinstance(batch, Batch), \
|
||||
'Only Batch is allowed to be concatenated in-place!'
|
||||
for k, v in batch.items():
|
||||
if v is None:
|
||||
continue
|
||||
if not hasattr(self, k) or self.__dict__[k] is None:
|
||||
self.__dict__[k] = deepcopy(v)
|
||||
elif isinstance(v, np.ndarray) and v.ndim > 0:
|
||||
self.__dict__[k] = np.concatenate([self.__dict__[k], v])
|
||||
elif isinstance(v, torch.Tensor):
|
||||
self.__dict__[k] = torch.cat([self.__dict__[k], v])
|
||||
elif isinstance(v, Batch):
|
||||
self.__dict__[k].cat_(v)
|
||||
if isinstance(batches, Batch):
|
||||
batches = [batches]
|
||||
if len(batches) == 0:
|
||||
return
|
||||
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
|
||||
if len(self.__dict__) > 0:
|
||||
batches = [self] + list(batches)
|
||||
# partial keys will be padded by zeros
|
||||
# with the shape of [len, rest_shape]
|
||||
lens = [len(x) for x in batches]
|
||||
keys_map = list(map(lambda e: set(e.keys()), batches))
|
||||
keys_shared = set.intersection(*keys_map)
|
||||
values_shared = [
|
||||
[e[k] for e in batches] for k in keys_shared]
|
||||
_assert_type_keys(keys_shared)
|
||||
for k, v in zip(keys_shared, values_shared):
|
||||
if all(isinstance(e, (dict, Batch)) for e in v):
|
||||
self.__dict__[k] = Batch.cat(v)
|
||||
elif all(isinstance(e, torch.Tensor) for e in v):
|
||||
self.__dict__[k] = torch.cat(v)
|
||||
else:
|
||||
s = 'No support for method "cat" with type '\
|
||||
f'{type(v)} in class Batch.'
|
||||
raise TypeError(s)
|
||||
v = np.concatenate(v)
|
||||
if not issubclass(v.dtype.type, (np.bool_, np.number)):
|
||||
v = v.astype(np.object)
|
||||
self.__dict__[k] = v
|
||||
keys_partial = set.union(*keys_map) - keys_shared
|
||||
_assert_type_keys(keys_partial)
|
||||
for k in keys_partial:
|
||||
is_dict = False
|
||||
value = None
|
||||
for i, e in enumerate(batches):
|
||||
val = e.get(k, None)
|
||||
if val is not None:
|
||||
if isinstance(val, (dict, Batch)):
|
||||
is_dict = True
|
||||
else: # np.ndarray or torch.Tensor
|
||||
value = val
|
||||
break
|
||||
if is_dict:
|
||||
self.__dict__[k] = Batch.cat(
|
||||
[e.get(k, Batch()) for e in batches])
|
||||
else:
|
||||
if isinstance(value, np.ndarray):
|
||||
arrs = []
|
||||
for i, e in enumerate(batches):
|
||||
shape = [lens[i]] + list(value.shape[1:])
|
||||
pad = np.zeros(shape, dtype=value.dtype)
|
||||
arrs.append(e.get(k, pad))
|
||||
self.__dict__[k] = np.concatenate(arrs)
|
||||
elif isinstance(value, torch.Tensor):
|
||||
arrs = []
|
||||
for i, e in enumerate(batches):
|
||||
shape = [lens[i]] + list(value.shape[1:])
|
||||
pad = torch.zeros(shape,
|
||||
dtype=value.dtype,
|
||||
device=value.device)
|
||||
arrs.append(e.get(k, pad))
|
||||
self.__dict__[k] = torch.cat(arrs)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"cannot cat value with type {type(value)}, we only "
|
||||
"support dict, Batch, np.ndarray, and torch.Tensor")
|
||||
|
||||
@staticmethod
|
||||
def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':
|
||||
"""Concatenate a list of :class:`~tianshou.data.Batch` object into a single
|
||||
new batch.
|
||||
"""Concatenate a list of :class:`~tianshou.data.Batch` object into a
|
||||
single new batch. For keys that are not shared across all batches,
|
||||
batches that do not have these keys will be padded by zeros with
|
||||
appropriate shapes. E.g.
|
||||
::
|
||||
|
||||
>>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
|
||||
>>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
|
||||
>>> c = Batch.cat([a, b])
|
||||
>>> c.a.shape
|
||||
(7, 4)
|
||||
>>> c.b.shape
|
||||
(7, 3)
|
||||
>>> c.common.c.shape
|
||||
(7, 5)
|
||||
"""
|
||||
batch = Batch()
|
||||
for batch_ in batches:
|
||||
if isinstance(batch_, dict):
|
||||
batch_ = Batch(batch_)
|
||||
batch.cat_(batch_)
|
||||
batch.cat_(batches)
|
||||
return batch
|
||||
|
||||
def stack_(self,
|
||||
batches: List[Union[dict, 'Batch']],
|
||||
axis: int = 0) -> None:
|
||||
"""Stack a :class:`~tianshou.data.Batch` object i into current batch.
|
||||
"""Stack a list of :class:`~tianshou.data.Batch` object into current
|
||||
batch.
|
||||
"""
|
||||
if len(self.__dict__) > 0:
|
||||
batches = [self] + list(batches)
|
||||
@ -566,8 +608,8 @@ class Batch:
|
||||
|
||||
@staticmethod
|
||||
def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch':
|
||||
"""Stack a :class:`~tianshou.data.Batch` object into a single new
|
||||
batch.
|
||||
"""Stack a list of :class:`~tianshou.data.Batch` object into a single
|
||||
new batch.
|
||||
"""
|
||||
batch = Batch()
|
||||
batch.stack_(batches, axis)
|
||||
@ -611,11 +653,24 @@ class Batch:
|
||||
"""
|
||||
return deepcopy(batch).empty_(index)
|
||||
|
||||
def update(self, batch: Optional[Union[dict, 'Batch']] = None,
|
||||
**kwargs) -> None:
|
||||
"""Update this batch from another dict/Batch."""
|
||||
if batch is None:
|
||||
self.update(kwargs)
|
||||
return
|
||||
if isinstance(batch, dict):
|
||||
batch = Batch(batch)
|
||||
for k, v in batch.items():
|
||||
self.__dict__[k] = v
|
||||
if kwargs:
|
||||
self.update(kwargs)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return len(self)."""
|
||||
r = []
|
||||
for v in self.__dict__.values():
|
||||
if isinstance(v, Batch) and len(v.__dict__) == 0:
|
||||
if isinstance(v, Batch) and v.is_empty():
|
||||
continue
|
||||
elif hasattr(v, '__len__') and (not isinstance(
|
||||
v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
|
||||
@ -627,7 +682,9 @@ class Batch:
|
||||
return min(r)
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.__dict__.keys()) == 0
|
||||
return not any(
|
||||
not x.is_empty() if isinstance(x, Batch)
|
||||
else hasattr(x, '__len__') and len(x) > 0 for x in self.values())
|
||||
|
||||
@property
|
||||
def shape(self) -> List[int]:
|
||||
|
@ -108,8 +108,7 @@ class ReplayBuffer:
|
||||
super().__init__()
|
||||
self._maxsize = size
|
||||
self._stack = stack_num
|
||||
assert stack_num != 1, \
|
||||
'stack_num should greater than 1'
|
||||
assert stack_num != 1, 'stack_num should greater than 1'
|
||||
self._avail = sample_avail and stack_num > 1
|
||||
self._avail_index = []
|
||||
self._save_s_ = not ignore_obs_next
|
||||
@ -136,12 +135,11 @@ class ReplayBuffer:
|
||||
except KeyError:
|
||||
self._meta.__dict__[name] = _create_value(inst, self._maxsize)
|
||||
value = self._meta.__dict__[name]
|
||||
if isinstance(inst, np.ndarray) and \
|
||||
value.shape[1:] != inst.shape:
|
||||
if isinstance(inst, np.ndarray) and value.shape[1:] != inst.shape:
|
||||
raise ValueError(
|
||||
"Cannot add data to a buffer with different shape, key: "
|
||||
f"{name}, expect shape: {value.shape[1:]}"
|
||||
f", given shape: {inst.shape}.")
|
||||
f"{name}, expect shape: {value.shape[1:]}, "
|
||||
f"given shape: {inst.shape}.")
|
||||
try:
|
||||
value[self._index] = inst
|
||||
except KeyError:
|
||||
@ -357,7 +355,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
self._weight_sum = 0.0
|
||||
self._amortization_freq = 50
|
||||
self._replace = replace
|
||||
self._meta.__dict__['weight'] = np.zeros(size, dtype=np.float64)
|
||||
self._meta.weight = np.zeros(size, dtype=np.float64)
|
||||
|
||||
def add(self,
|
||||
obs: Union[dict, np.ndarray],
|
||||
@ -372,7 +370,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
"""Add a batch of data into replay buffer."""
|
||||
# we have to sacrifice some convenience for speed
|
||||
self._weight_sum += np.abs(weight) ** self._alpha - \
|
||||
self._meta.__dict__['weight'][self._index]
|
||||
self._meta.weight[self._index]
|
||||
self._add_to_buffer('weight', np.abs(weight) ** self._alpha)
|
||||
super().add(obs, act, rew, done, obs_next, info, policy)
|
||||
|
||||
@ -410,14 +408,9 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
f"batch_size should be less than {len(self)}, \
|
||||
or set replace=True")
|
||||
batch = self[indice]
|
||||
impt_weight = Batch(
|
||||
impt_weight=(self._size * p) ** (-self._beta))
|
||||
batch.cat_(impt_weight)
|
||||
batch["impt_weight"] = (self._size * p) ** (-self._beta)
|
||||
return batch, indice
|
||||
|
||||
def reset(self) -> None:
|
||||
super().reset()
|
||||
|
||||
def update_weight(self, indice: Union[slice, np.ndarray],
|
||||
new_weight: np.ndarray) -> None:
|
||||
"""Update priority weight by indice in this buffer.
|
||||
|
Loading…
x
Reference in New Issue
Block a user