Improve Batch (#128)
* minor polish
* improve and implement Batch.cat_
* bugfix for buffer.sample with field impt_weight
* restore the usage of a.cat_(b)
* fix 2 bugs in batch and add corresponding unittest
* code fix for update
* update is_empty to recognize empty over empty; bugfix for len
* bugfix for update and add testcase
* add testcase of update
* fix docs
* fix docs
* fix docs [ci skip]
* fix docs [ci skip]

Co-authored-by: Trinkle23897 <463003665@qq.com>
This commit is contained in:
parent
2564e989fb
commit
affeec13de
1
.github/workflows/pytest.yml
vendored
1
.github/workflows/pytest.yml
vendored
@ -5,6 +5,7 @@ on: [push, pull_request]
|
|||||||
jobs:
|
jobs:
|
||||||
build:
|
build:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
if: "!contains(github.event.head_commit.message, 'ci skip')"
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: [3.6, 3.7, 3.8]
|
python-version: [3.6, 3.7, 3.8]
|
||||||
|
@ -10,7 +10,17 @@ from tianshou.data import Batch, to_torch
|
|||||||
def test_batch():
|
def test_batch():
|
||||||
assert list(Batch()) == []
|
assert list(Batch()) == []
|
||||||
assert Batch().is_empty()
|
assert Batch().is_empty()
|
||||||
|
assert Batch(b={'c': {}}).is_empty()
|
||||||
|
assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
|
||||||
assert not Batch(a=[1, 2, 3]).is_empty()
|
assert not Batch(a=[1, 2, 3]).is_empty()
|
||||||
|
b = Batch()
|
||||||
|
b.update()
|
||||||
|
assert b.is_empty()
|
||||||
|
b.update(c=[3, 5])
|
||||||
|
assert np.allclose(b.c, [3, 5])
|
||||||
|
# mimic the behavior of dict.update, where kwargs can overwrite keys
|
||||||
|
b.update({'a': 2}, a=3)
|
||||||
|
assert b.a == 3
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
Batch({1: 2})
|
Batch({1: 2})
|
||||||
batch = Batch(a=[torch.ones(3), torch.ones(3)])
|
batch = Batch(a=[torch.ones(3), torch.ones(3)])
|
||||||
@ -86,6 +96,18 @@ def test_batch():
|
|||||||
assert batch3.a.d.f[0] == 5.0
|
assert batch3.a.d.f[0] == 5.0
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
batch3.a.d[0] = Batch(f=5.0, g=0.0)
|
batch3.a.d[0] = Batch(f=5.0, g=0.0)
|
||||||
|
# auto convert
|
||||||
|
batch4 = Batch(a=np.array(['a', 'b']))
|
||||||
|
assert batch4.a.dtype == np.object # auto convert to np.object
|
||||||
|
batch4.update(a=np.array(['c', 'd']))
|
||||||
|
assert list(batch4.a) == ['c', 'd']
|
||||||
|
assert batch4.a.dtype == np.object # auto convert to np.object
|
||||||
|
batch5 = Batch(a=np.array([{'index': 0}]))
|
||||||
|
assert isinstance(batch5.a, Batch)
|
||||||
|
assert np.allclose(batch5.a.index, [0])
|
||||||
|
batch5.b = np.array([{'index': 1}])
|
||||||
|
assert isinstance(batch5.b, Batch)
|
||||||
|
assert np.allclose(batch5.b.index, [1])
|
||||||
|
|
||||||
|
|
||||||
def test_batch_over_batch():
|
def test_batch_over_batch():
|
||||||
@ -100,6 +122,11 @@ def test_batch_over_batch():
|
|||||||
assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
|
assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
|
||||||
assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5])
|
assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5])
|
||||||
assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0])
|
assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0])
|
||||||
|
batch2.update(batch2.b, six=[6, 6, 6])
|
||||||
|
assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
|
||||||
|
assert np.allclose(batch2.a, [3, 4, 5, 3, 4, 5])
|
||||||
|
assert np.allclose(batch2.b, [4, 5, 0, 4, 5, 0])
|
||||||
|
assert np.allclose(batch2.six, [6, 6, 6])
|
||||||
d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
|
d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
|
||||||
batch3 = Batch(c=[6, 7, 8], b=d)
|
batch3 = Batch(c=[6, 7, 8], b=d)
|
||||||
batch3.cat_(Batch(c=[6, 7, 8], b=d))
|
batch3.cat_(Batch(c=[6, 7, 8], b=d))
|
||||||
@ -124,18 +151,32 @@ def test_batch_over_batch():
|
|||||||
|
|
||||||
|
|
||||||
def test_batch_cat_and_stack():
|
def test_batch_cat_and_stack():
|
||||||
|
# test cat with compatible keys
|
||||||
b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
|
b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
|
||||||
b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
|
b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
|
||||||
b12_cat_out = Batch.cat((b1, b2))
|
b12_cat_out = Batch.cat([b1, b2])
|
||||||
b12_cat_in = copy.deepcopy(b1)
|
b12_cat_in = copy.deepcopy(b1)
|
||||||
b12_cat_in.cat_(b2)
|
b12_cat_in.cat_(b2)
|
||||||
assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
|
assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
|
||||||
assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
|
assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
|
||||||
assert isinstance(b12_cat_in.a.d.e, np.ndarray)
|
assert isinstance(b12_cat_in.a.d.e, np.ndarray)
|
||||||
assert b12_cat_in.a.d.e.ndim == 1
|
assert b12_cat_in.a.d.e.ndim == 1
|
||||||
|
|
||||||
b12_stack = Batch.stack((b1, b2))
|
b12_stack = Batch.stack((b1, b2))
|
||||||
assert isinstance(b12_stack.a.d.e, np.ndarray)
|
assert isinstance(b12_stack.a.d.e, np.ndarray)
|
||||||
assert b12_stack.a.d.e.ndim == 2
|
assert b12_stack.a.d.e.ndim == 2
|
||||||
|
|
||||||
|
# test batch with incompatible keys
|
||||||
|
b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
|
||||||
|
b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
|
||||||
|
test = Batch.cat([b1, b2])
|
||||||
|
ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
|
||||||
|
b=torch.cat([torch.zeros(3, 3), b2.b]),
|
||||||
|
common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
|
||||||
|
assert np.allclose(test.a, ans.a)
|
||||||
|
assert torch.allclose(test.b, ans.b)
|
||||||
|
assert np.allclose(test.common.c, ans.common.c)
|
||||||
|
|
||||||
b3 = Batch(a=np.zeros((3, 4)),
|
b3 = Batch(a=np.zeros((3, 4)),
|
||||||
b=torch.ones((2, 5)),
|
b=torch.ones((2, 5)),
|
||||||
c=Batch(d=[[1], [2]]))
|
c=Batch(d=[[1], [2]]))
|
||||||
|
@ -259,8 +259,7 @@ class Batch:
|
|||||||
v_ = None
|
v_ = None
|
||||||
if not isinstance(v, np.ndarray) and \
|
if not isinstance(v, np.ndarray) and \
|
||||||
all(isinstance(e, torch.Tensor) for e in v):
|
all(isinstance(e, torch.Tensor) for e in v):
|
||||||
v_ = torch.stack(v)
|
self.__dict__[k] = torch.stack(v)
|
||||||
self.__dict__[k] = v_
|
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
v_ = np.asanyarray(v)
|
v_ = np.asanyarray(v)
|
||||||
@ -294,7 +293,8 @@ class Batch:
|
|||||||
value = np.array(value)
|
value = np.array(value)
|
||||||
if not issubclass(value.dtype.type, (np.bool_, np.number)):
|
if not issubclass(value.dtype.type, (np.bool_, np.number)):
|
||||||
value = value.astype(np.object)
|
value = value.astype(np.object)
|
||||||
elif isinstance(value, dict):
|
elif isinstance(value, dict) or isinstance(value, np.ndarray) \
|
||||||
|
and value.dtype == np.object and _is_batch_set(value):
|
||||||
value = Batch(value)
|
value = Batch(value)
|
||||||
self.__dict__[key] = value
|
self.__dict__[key] = value
|
||||||
|
|
||||||
@ -333,9 +333,8 @@ class Batch:
|
|||||||
else:
|
else:
|
||||||
raise IndexError("Cannot access item from empty Batch object.")
|
raise IndexError("Cannot access item from empty Batch object.")
|
||||||
|
|
||||||
def __setitem__(
|
def __setitem__(self, index: Union[
|
||||||
self,
|
str, slice, int, np.integer, np.ndarray, List[int]],
|
||||||
index: Union[str, slice, int, np.integer, np.ndarray, List[int]],
|
|
||||||
value: Any) -> None:
|
value: Any) -> None:
|
||||||
"""Assign value to self[index]."""
|
"""Assign value to self[index]."""
|
||||||
if isinstance(value, np.ndarray):
|
if isinstance(value, np.ndarray):
|
||||||
@ -454,10 +453,8 @@ class Batch:
|
|||||||
elif isinstance(v, Batch):
|
elif isinstance(v, Batch):
|
||||||
v.to_numpy()
|
v.to_numpy()
|
||||||
|
|
||||||
def to_torch(self,
|
def to_torch(self, dtype: Optional[torch.dtype] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
device: Union[str, int, torch.device] = 'cpu') -> None:
|
||||||
device: Union[str, int, torch.device] = 'cpu'
|
|
||||||
) -> None:
|
|
||||||
"""Change all numpy.ndarray to torch.Tensor. This is an in-place
|
"""Change all numpy.ndarray to torch.Tensor. This is an in-place
|
||||||
operation.
|
operation.
|
||||||
"""
|
"""
|
||||||
@ -473,66 +470,111 @@ class Batch:
|
|||||||
v = v.type(dtype)
|
v = v.type(dtype)
|
||||||
self.__dict__[k] = v
|
self.__dict__[k] = v
|
||||||
elif isinstance(v, torch.Tensor):
|
elif isinstance(v, torch.Tensor):
|
||||||
if dtype is not None and v.dtype != dtype:
|
if dtype is not None and v.dtype != dtype or \
|
||||||
must_update_tensor = True
|
v.device.type != device.type or \
|
||||||
elif v.device.type != device.type:
|
device.index is not None and \
|
||||||
must_update_tensor = True
|
|
||||||
elif device.index is not None and \
|
|
||||||
device.index != v.device.index:
|
device.index != v.device.index:
|
||||||
must_update_tensor = True
|
|
||||||
else:
|
|
||||||
must_update_tensor = False
|
|
||||||
if must_update_tensor:
|
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
v = v.type(dtype)
|
v = v.type(dtype)
|
||||||
self.__dict__[k] = v.to(device)
|
self.__dict__[k] = v.to(device)
|
||||||
elif isinstance(v, Batch):
|
elif isinstance(v, Batch):
|
||||||
v.to_torch(dtype, device)
|
v.to_torch(dtype, device)
|
||||||
|
|
||||||
def append(self, batch: 'Batch') -> None:
|
def cat_(self,
|
||||||
warnings.warn('Method :meth:`~tianshou.data.Batch.append` will be '
|
batches: Union['Batch', List[Union[dict, 'Batch']]]) -> None:
|
||||||
'removed soon, please use '
|
"""Concatenate a list of (or one) :class:`~tianshou.data.Batch` objects
|
||||||
':meth:`~tianshou.data.Batch.cat`')
|
into current batch.
|
||||||
return self.cat_(batch)
|
|
||||||
|
|
||||||
def cat_(self, batch: 'Batch') -> None:
|
|
||||||
"""Concatenate a :class:`~tianshou.data.Batch` object into current
|
|
||||||
batch.
|
|
||||||
"""
|
"""
|
||||||
assert isinstance(batch, Batch), \
|
if isinstance(batches, Batch):
|
||||||
'Only Batch is allowed to be concatenated in-place!'
|
batches = [batches]
|
||||||
for k, v in batch.items():
|
if len(batches) == 0:
|
||||||
if v is None:
|
return
|
||||||
continue
|
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
|
||||||
if not hasattr(self, k) or self.__dict__[k] is None:
|
if len(self.__dict__) > 0:
|
||||||
self.__dict__[k] = deepcopy(v)
|
batches = [self] + list(batches)
|
||||||
elif isinstance(v, np.ndarray) and v.ndim > 0:
|
# partial keys will be padded by zeros
|
||||||
self.__dict__[k] = np.concatenate([self.__dict__[k], v])
|
# with the shape of [len, rest_shape]
|
||||||
elif isinstance(v, torch.Tensor):
|
lens = [len(x) for x in batches]
|
||||||
self.__dict__[k] = torch.cat([self.__dict__[k], v])
|
keys_map = list(map(lambda e: set(e.keys()), batches))
|
||||||
elif isinstance(v, Batch):
|
keys_shared = set.intersection(*keys_map)
|
||||||
self.__dict__[k].cat_(v)
|
values_shared = [
|
||||||
|
[e[k] for e in batches] for k in keys_shared]
|
||||||
|
_assert_type_keys(keys_shared)
|
||||||
|
for k, v in zip(keys_shared, values_shared):
|
||||||
|
if all(isinstance(e, (dict, Batch)) for e in v):
|
||||||
|
self.__dict__[k] = Batch.cat(v)
|
||||||
|
elif all(isinstance(e, torch.Tensor) for e in v):
|
||||||
|
self.__dict__[k] = torch.cat(v)
|
||||||
else:
|
else:
|
||||||
s = 'No support for method "cat" with type '\
|
v = np.concatenate(v)
|
||||||
f'{type(v)} in class Batch.'
|
if not issubclass(v.dtype.type, (np.bool_, np.number)):
|
||||||
raise TypeError(s)
|
v = v.astype(np.object)
|
||||||
|
self.__dict__[k] = v
|
||||||
|
keys_partial = set.union(*keys_map) - keys_shared
|
||||||
|
_assert_type_keys(keys_partial)
|
||||||
|
for k in keys_partial:
|
||||||
|
is_dict = False
|
||||||
|
value = None
|
||||||
|
for i, e in enumerate(batches):
|
||||||
|
val = e.get(k, None)
|
||||||
|
if val is not None:
|
||||||
|
if isinstance(val, (dict, Batch)):
|
||||||
|
is_dict = True
|
||||||
|
else: # np.ndarray or torch.Tensor
|
||||||
|
value = val
|
||||||
|
break
|
||||||
|
if is_dict:
|
||||||
|
self.__dict__[k] = Batch.cat(
|
||||||
|
[e.get(k, Batch()) for e in batches])
|
||||||
|
else:
|
||||||
|
if isinstance(value, np.ndarray):
|
||||||
|
arrs = []
|
||||||
|
for i, e in enumerate(batches):
|
||||||
|
shape = [lens[i]] + list(value.shape[1:])
|
||||||
|
pad = np.zeros(shape, dtype=value.dtype)
|
||||||
|
arrs.append(e.get(k, pad))
|
||||||
|
self.__dict__[k] = np.concatenate(arrs)
|
||||||
|
elif isinstance(value, torch.Tensor):
|
||||||
|
arrs = []
|
||||||
|
for i, e in enumerate(batches):
|
||||||
|
shape = [lens[i]] + list(value.shape[1:])
|
||||||
|
pad = torch.zeros(shape,
|
||||||
|
dtype=value.dtype,
|
||||||
|
device=value.device)
|
||||||
|
arrs.append(e.get(k, pad))
|
||||||
|
self.__dict__[k] = torch.cat(arrs)
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"cannot cat value with type {type(value)}, we only "
|
||||||
|
"support dict, Batch, np.ndarray, and torch.Tensor")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':
|
def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':
|
||||||
"""Concatenate a list of :class:`~tianshou.data.Batch` object into a single
|
"""Concatenate a list of :class:`~tianshou.data.Batch` object into a
|
||||||
new batch.
|
single new batch. For keys that are not shared across all batches,
|
||||||
|
batches that do not have these keys will be padded by zeros with
|
||||||
|
appropriate shapes. E.g.
|
||||||
|
::
|
||||||
|
|
||||||
|
>>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
|
||||||
|
>>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
|
||||||
|
>>> c = Batch.cat([a, b])
|
||||||
|
>>> c.a.shape
|
||||||
|
(7, 4)
|
||||||
|
>>> c.b.shape
|
||||||
|
(7, 3)
|
||||||
|
>>> c.common.c.shape
|
||||||
|
(7, 5)
|
||||||
"""
|
"""
|
||||||
batch = Batch()
|
batch = Batch()
|
||||||
for batch_ in batches:
|
batch.cat_(batches)
|
||||||
if isinstance(batch_, dict):
|
|
||||||
batch_ = Batch(batch_)
|
|
||||||
batch.cat_(batch_)
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def stack_(self,
|
def stack_(self,
|
||||||
batches: List[Union[dict, 'Batch']],
|
batches: List[Union[dict, 'Batch']],
|
||||||
axis: int = 0) -> None:
|
axis: int = 0) -> None:
|
||||||
"""Stack a :class:`~tianshou.data.Batch` object i into current batch.
|
"""Stack a list of :class:`~tianshou.data.Batch` object into current
|
||||||
|
batch.
|
||||||
"""
|
"""
|
||||||
if len(self.__dict__) > 0:
|
if len(self.__dict__) > 0:
|
||||||
batches = [self] + list(batches)
|
batches = [self] + list(batches)
|
||||||
@ -566,8 +608,8 @@ class Batch:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch':
|
def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch':
|
||||||
"""Stack a :class:`~tianshou.data.Batch` object into a single new
|
"""Stack a list of :class:`~tianshou.data.Batch` object into a single
|
||||||
batch.
|
new batch.
|
||||||
"""
|
"""
|
||||||
batch = Batch()
|
batch = Batch()
|
||||||
batch.stack_(batches, axis)
|
batch.stack_(batches, axis)
|
||||||
@ -611,11 +653,24 @@ class Batch:
|
|||||||
"""
|
"""
|
||||||
return deepcopy(batch).empty_(index)
|
return deepcopy(batch).empty_(index)
|
||||||
|
|
||||||
|
def update(self, batch: Optional[Union[dict, 'Batch']] = None,
|
||||||
|
**kwargs) -> None:
|
||||||
|
"""Update this batch from another dict/Batch."""
|
||||||
|
if batch is None:
|
||||||
|
self.update(kwargs)
|
||||||
|
return
|
||||||
|
if isinstance(batch, dict):
|
||||||
|
batch = Batch(batch)
|
||||||
|
for k, v in batch.items():
|
||||||
|
self.__dict__[k] = v
|
||||||
|
if kwargs:
|
||||||
|
self.update(kwargs)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""Return len(self)."""
|
"""Return len(self)."""
|
||||||
r = []
|
r = []
|
||||||
for v in self.__dict__.values():
|
for v in self.__dict__.values():
|
||||||
if isinstance(v, Batch) and len(v.__dict__) == 0:
|
if isinstance(v, Batch) and v.is_empty():
|
||||||
continue
|
continue
|
||||||
elif hasattr(v, '__len__') and (not isinstance(
|
elif hasattr(v, '__len__') and (not isinstance(
|
||||||
v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
|
v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
|
||||||
@ -627,7 +682,9 @@ class Batch:
|
|||||||
return min(r)
|
return min(r)
|
||||||
|
|
||||||
def is_empty(self):
|
def is_empty(self):
|
||||||
return len(self.__dict__.keys()) == 0
|
return not any(
|
||||||
|
not x.is_empty() if isinstance(x, Batch)
|
||||||
|
else hasattr(x, '__len__') and len(x) > 0 for x in self.values())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self) -> List[int]:
|
def shape(self) -> List[int]:
|
||||||
|
@ -108,8 +108,7 @@ class ReplayBuffer:
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self._maxsize = size
|
self._maxsize = size
|
||||||
self._stack = stack_num
|
self._stack = stack_num
|
||||||
assert stack_num != 1, \
|
assert stack_num != 1, 'stack_num should greater than 1'
|
||||||
'stack_num should greater than 1'
|
|
||||||
self._avail = sample_avail and stack_num > 1
|
self._avail = sample_avail and stack_num > 1
|
||||||
self._avail_index = []
|
self._avail_index = []
|
||||||
self._save_s_ = not ignore_obs_next
|
self._save_s_ = not ignore_obs_next
|
||||||
@ -136,12 +135,11 @@ class ReplayBuffer:
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
self._meta.__dict__[name] = _create_value(inst, self._maxsize)
|
self._meta.__dict__[name] = _create_value(inst, self._maxsize)
|
||||||
value = self._meta.__dict__[name]
|
value = self._meta.__dict__[name]
|
||||||
if isinstance(inst, np.ndarray) and \
|
if isinstance(inst, np.ndarray) and value.shape[1:] != inst.shape:
|
||||||
value.shape[1:] != inst.shape:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cannot add data to a buffer with different shape, key: "
|
"Cannot add data to a buffer with different shape, key: "
|
||||||
f"{name}, expect shape: {value.shape[1:]}"
|
f"{name}, expect shape: {value.shape[1:]}, "
|
||||||
f", given shape: {inst.shape}.")
|
f"given shape: {inst.shape}.")
|
||||||
try:
|
try:
|
||||||
value[self._index] = inst
|
value[self._index] = inst
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -357,7 +355,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
self._weight_sum = 0.0
|
self._weight_sum = 0.0
|
||||||
self._amortization_freq = 50
|
self._amortization_freq = 50
|
||||||
self._replace = replace
|
self._replace = replace
|
||||||
self._meta.__dict__['weight'] = np.zeros(size, dtype=np.float64)
|
self._meta.weight = np.zeros(size, dtype=np.float64)
|
||||||
|
|
||||||
def add(self,
|
def add(self,
|
||||||
obs: Union[dict, np.ndarray],
|
obs: Union[dict, np.ndarray],
|
||||||
@ -372,7 +370,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
"""Add a batch of data into replay buffer."""
|
"""Add a batch of data into replay buffer."""
|
||||||
# we have to sacrifice some convenience for speed
|
# we have to sacrifice some convenience for speed
|
||||||
self._weight_sum += np.abs(weight) ** self._alpha - \
|
self._weight_sum += np.abs(weight) ** self._alpha - \
|
||||||
self._meta.__dict__['weight'][self._index]
|
self._meta.weight[self._index]
|
||||||
self._add_to_buffer('weight', np.abs(weight) ** self._alpha)
|
self._add_to_buffer('weight', np.abs(weight) ** self._alpha)
|
||||||
super().add(obs, act, rew, done, obs_next, info, policy)
|
super().add(obs, act, rew, done, obs_next, info, policy)
|
||||||
|
|
||||||
@ -410,14 +408,9 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
f"batch_size should be less than {len(self)}, \
|
f"batch_size should be less than {len(self)}, \
|
||||||
or set replace=True")
|
or set replace=True")
|
||||||
batch = self[indice]
|
batch = self[indice]
|
||||||
impt_weight = Batch(
|
batch["impt_weight"] = (self._size * p) ** (-self._beta)
|
||||||
impt_weight=(self._size * p) ** (-self._beta))
|
|
||||||
batch.cat_(impt_weight)
|
|
||||||
return batch, indice
|
return batch, indice
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
super().reset()
|
|
||||||
|
|
||||||
def update_weight(self, indice: Union[slice, np.ndarray],
|
def update_weight(self, indice: Union[slice, np.ndarray],
|
||||||
new_weight: np.ndarray) -> None:
|
new_weight: np.ndarray) -> None:
|
||||||
"""Update priority weight by indice in this buffer.
|
"""Update priority weight by indice in this buffer.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user