2020-04-28 20:56:02 +08:00
|
|
|
import pprint
|
2020-03-11 09:09:56 +08:00
|
|
|
import numpy as np
|
2020-05-12 11:31:47 +08:00
|
|
|
from copy import deepcopy
|
|
|
|
from typing import Tuple, Union, Optional
|
|
|
|
|
2020-03-11 09:09:56 +08:00
|
|
|
from tianshou.data.batch import Batch
|
|
|
|
|
|
|
|
|
|
|
|
class ReplayBuffer(object):
    """:class:`~tianshou.data.ReplayBuffer` stores data generated from
    interaction between the policy and environment. It stores basically 7 types
    of data, as mentioned in :class:`~tianshou.data.Batch`, based on
    ``numpy.ndarray``. Here is the usage:
    ::

        >>> import numpy as np
        >>> from tianshou.data import ReplayBuffer
        >>> buf = ReplayBuffer(size=20)
        >>> for i in range(3):
        ...     buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
        >>> len(buf)
        3
        >>> buf.obs
        # since we set size = 20, len(buf.obs) == 20.
        array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
               0., 0., 0., 0.])

        >>> buf2 = ReplayBuffer(size=10)
        >>> for i in range(15):
        ...     buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
        >>> len(buf2)
        10
        >>> buf2.obs
        # since its size = 10, it only stores the last 10 steps' result.
        array([10., 11., 12., 13., 14.,  5.,  6.,  7.,  8.,  9.])

        >>> # move buf2's result into buf (meanwhile keep it chronologically)
        >>> buf.update(buf2)
        >>> buf.obs
        array([ 0.,  1.,  2.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
                0.,  0.,  0.,  0.,  0.,  0.,  0.])

        >>> # get a random sample from buffer
        >>> # the batch_data is equal to buf[indice].
        >>> batch_data, indice = buf.sample(batch_size=4)
        >>> batch_data.obs == buf[indice].obs
        array([ True,  True,  True,  True])

    :class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
    (typically for RNN usage, see issue#19), ignoring storing the next
    observation (save memory in atari tasks), and multi-modal observation (see
    issue#38, need version >= 0.2.3):
    ::

        >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
        >>> for i in range(16):
        ...     done = i % 5 == 0
        ...     buf.add(obs={'id': i}, act=i, rew=i, done=done,
        ...             obs_next={'id': i + 1})
        >>> print(buf)  # you can see obs_next is not saved in buf
        ReplayBuffer(
            act: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
            done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
            info: array([{}, {}, {}, {}, {}, {}, {}, {}, {}], dtype=object),
            obs: Batch(
                     id: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
                 ),
            policy: Batch(),
            rew: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
        )
        >>> index = np.arange(len(buf))
        >>> print(buf.get(index, 'obs').id)
        [[ 7.  7.  8.  9.]
         [ 7.  8.  9. 10.]
         [11. 11. 11. 11.]
         [11. 11. 11. 12.]
         [11. 11. 12. 13.]
         [11. 12. 13. 14.]
         [12. 13. 14. 15.]
         [ 7.  7.  7.  7.]
         [ 7.  7.  7.  8.]]
        >>> # here is another way to get the stacked data
        >>> # (stack only for obs and obs_next)
        >>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
        0.0
        >>> # we can get obs_next through __getitem__, even if it doesn't exist
        >>> print(buf[:].obs_next.id)
        [[ 7.  8.  9. 10.]
         [ 7.  8.  9. 10.]
         [11. 11. 11. 12.]
         [11. 11. 12. 13.]
         [11. 12. 13. 14.]
         [12. 13. 14. 15.]
         [12. 13. 14. 15.]
         [ 7.  7.  7.  8.]
         [ 7.  7.  8.  9.]]
    """

    def __init__(self, size: int, stack_num: int = 0,
                 ignore_obs_next: bool = False, **kwargs) -> None:
        super().__init__()
        self._maxsize = size
        self._stack = stack_num
        # when False, obs_next is rebuilt from the next slot's obs in get()
        self._save_s_ = not ignore_obs_next
        # dict-valued key -> list of its sub-keys; sub-key k of key is stored
        # as a separate column under the attribute name "_<key>@<k>"
        self._meta = {}
        self.reset()

    def __len__(self) -> int:
        """Return len(self)."""
        return self._size

    def __repr__(self) -> str:
        """Return str(self)."""
        s = self.__class__.__name__ + '(\n'
        flag = False
        for k in sorted(list(self.__dict__) + list(self._meta)):
            if k.startswith('_'):
                continue  # private state and flattened sub-key columns
            if self.__dict__.get(k, None) is None and k not in self._meta:
                continue  # columns that were never written
            # continuation lines line up under the value after "    k: "
            rpl = '\n' + ' ' * (6 + len(k))
            obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
            s += f'    {k}: {obj},\n'
            flag = True
        if flag:
            s += ')'
        else:
            s = self.__class__.__name__ + '()'
        return s

    def __getattr__(self, key: str) -> Union[Batch, np.ndarray]:
        """Return self.key.

        Dict-valued keys (those registered in ``self._meta``) are
        reassembled into a :class:`~tianshou.data.Batch` from their per-sub-key
        columns.

        :raise AttributeError: if ``key`` is not stored in the buffer.
        """
        # Read _meta through __dict__: on a half-constructed instance (e.g.
        # during copy/unpickling) ``self._meta`` would re-enter __getattr__
        # and recurse forever.
        meta = self.__dict__.get('_meta', {})
        if key not in meta:
            if key not in self.__dict__:
                raise AttributeError(key)
            return self.__dict__[key]
        d = {}
        for k_ in meta[key]:
            d[k_] = self.__dict__['_' + key + '@' + k_]
        return Batch(**d)

    def _add_to_buffer(
            self, name: str,
            inst: Union[dict, Batch, np.ndarray, float, int, bool, None]
    ) -> None:
        """Write ``inst`` into column ``name`` at the current write slot.

        Storage is allocated lazily on the first write: an ndarray gets a
        ``[maxsize, *shape]`` float array, a number gets a ``[maxsize]`` float
        array, ``info`` gets an object array of dicts, and any other
        dict/Batch is split into one column per sub-key (registered in
        ``self._meta``).

        :raise ValueError: if an ndarray's shape differs from the shape the
            column was allocated with.
        """
        if inst is None:
            # record the column as present-but-empty (e.g. omitted obs_next)
            if getattr(self, name, None) is None:
                self.__dict__[name] = None
            return
        if name in self._meta:
            for k in inst.keys():
                self._add_to_buffer('_' + name + '@' + k, inst[k])
            return
        if self.__dict__.get(name, None) is None:
            if isinstance(inst, np.ndarray):
                self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
            elif isinstance(inst, (dict, Batch)):
                if name == 'info':
                    self.__dict__[name] = np.array(
                        [{} for _ in range(self._maxsize)])
                else:
                    self._meta[name] = list(inst.keys())
                    for k in inst.keys():
                        self._add_to_buffer('_' + name + '@' + k, inst[k])
            else:  # assume `inst` is a number
                self.__dict__[name] = np.zeros([self._maxsize])
        if isinstance(inst, np.ndarray) and \
                self.__dict__[name].shape[1:] != inst.shape:
            raise ValueError(
                "Cannot add data to a buffer with different shape, "
                f"key: {name}, expect shape: {self.__dict__[name].shape[1:]}, "
                f"given shape: {inst.shape}.")
        if name not in self._meta:
            if name == 'info':
                # store a private copy so later edits to the caller's dict
                # do not leak into the buffer
                inst = deepcopy(inst)
            self.__dict__[name][self._index] = inst

    def update(self, buffer: 'ReplayBuffer') -> None:
        """Move the data from the given buffer to self.

        Transitions are copied oldest-first so their chronological order is
        kept. Updating from an empty buffer is a no-op.
        """
        size = len(buffer)
        if size == 0:
            return
        # a wrapped-around source buffer's write pointer marks its oldest
        # entry; an unwrapped one has _index == _size, i.e. it starts at 0
        begin = buffer._index % size
        for offset in range(size):
            i = (begin + offset) % size
            self.add(
                buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i],
                buffer.obs_next[i] if self._save_s_ else None,
                buffer.info[i], buffer.policy[i])

    def add(self,
            obs: Union[dict, np.ndarray],
            act: Union[np.ndarray, float],
            rew: float,
            done: bool,
            obs_next: Optional[Union[dict, np.ndarray]] = None,
            info: Optional[dict] = None,
            policy: Optional[Union[dict, Batch]] = None,
            **kwargs) -> None:
        """Add a batch of data into replay buffer.

        ``info`` and ``policy`` default to empty dicts (``None`` is accepted
        and treated the same way).
        """
        if info is None:
            info = {}
        if policy is None:
            policy = {}
        assert isinstance(info, dict), \
            'You should return a dict in the last argument of env.step().'
        self._add_to_buffer('obs', obs)
        self._add_to_buffer('act', act)
        self._add_to_buffer('rew', rew)
        self._add_to_buffer('done', done)
        if self._save_s_:
            self._add_to_buffer('obs_next', obs_next)
        self._add_to_buffer('info', info)
        self._add_to_buffer('policy', policy)
        if self._maxsize > 0:
            # ring buffer: overwrite the oldest slot once full
            self._size = min(self._size + 1, self._maxsize)
            self._index = (self._index + 1) % self._maxsize
        else:
            # size 0 means unbounded (used by list-backed subclasses)
            self._size = self._index = self._index + 1

    def reset(self) -> None:
        """Clear all the data in replay buffer."""
        self._index = self._size = 0

    def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
        """Get a random sample from buffer with size equal to batch_size. \
        Return all the data in the buffer if batch_size is ``0``.

        :return: Sample data and its corresponding index inside the buffer.
        """
        if batch_size > 0:
            indice = np.random.choice(self._size, batch_size)
        else:
            # oldest-to-newest order
            indice = np.concatenate([
                np.arange(self._index, self._size),
                np.arange(0, self._index),
            ])
        return self[indice], indice

    def _normalize_index(
            self, index: Union[slice, int, np.ndarray]) -> np.ndarray:
        """Return ``index`` as a fresh ndarray of buffer positions.

        Slices follow Python slice semantics against ``len(self)`` (negative
        bounds count back from ``len(self)``). Anything else goes through
        ``np.array``, which copies, so callers' arrays are never modified.
        """
        if isinstance(index, slice):
            return np.arange(*index.indices(self._size))
        return np.array(index)

    def get(self, indice: Union[slice, int, np.ndarray], key: str,
            stack_num: Optional[int] = None
            ) -> Union[Batch, dict, np.ndarray]:
        """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
        where s is self.key, t is indice. The stack_num (here equals to 4) is
        given from buffer initialization procedure.

        :param indice: a slice (resolved against ``len(self)``), an integer,
            or an integer array of buffer positions. It is never modified.
        :param str key: the stored field, e.g. ``'obs'``. ``'obs_next'`` is
            rebuilt from ``obs`` when the buffer does not store it.
        :param int stack_num: number of frames to stack; defaults to the
            value given at construction. ``0`` disables stacking.
        :return: with ``stack_num == 0``, the indexed data (a plain ``dict``
            of arrays for dict-valued keys); otherwise the frames stacked
            along the axis right after the index dimensions, as a ``Batch``
            for dict-valued keys or an ``np.ndarray``.
        """
        if stack_num is None:
            stack_num = self._stack
        indice = self._normalize_index(indice)
        # Temporarily mark the newest frame as terminal: it bounds the ring,
        # so its "next" frame must not be the (older) entry that follows it,
        # and stacking back from the oldest entry must not wrap onto it.
        last_index = (self._index - 1 + self._size) % self._size
        last_done, self.done[last_index] = self.done[last_index], True
        try:
            if key == 'obs_next' and not self._save_s_:
                # next obs is the following slot, unless the episode ended
                # here, in which case the current obs is reused
                indice = indice + 1 - self.done[indice].astype(int)
                indice = np.where(indice == self._size, 0, indice)
                key = 'obs'
            if stack_num == 0:
                if key in self._meta:
                    return {k: self.__dict__['_' + key + '@' + k][indice]
                            for k in self._meta[key]}
                return self.__dict__[key][indice]
            # frames are stacked right after the index dimensions, i.e.
            # [batch, len, ...] for a 1-d index and [len, ...] for a scalar
            stack_axis = np.ndim(indice)
            sub_keys = self._meta.get(key)
            if sub_keys is not None:
                frames = {k: [] for k in sub_keys}
            else:
                frames = []
            for _ in range(stack_num):
                if sub_keys is not None:
                    for k in sub_keys:
                        frames[k].append(
                            self.__dict__['_' + key + '@' + k][indice])
                else:
                    frames.append(self.__dict__[key][indice])
                # step back one frame, but stay put if the previous frame
                # ended an episode (done) so frames never cross episodes
                pre_indice = indice - 1
                pre_indice = np.where(
                    pre_indice == -1, self._size - 1, pre_indice)
                indice = pre_indice + self.done[pre_indice].astype(int)
                indice = np.where(indice == self._size, 0, indice)
        finally:
            self.done[last_index] = last_done
        # frames were collected newest-first; reverse into time order
        if sub_keys is not None:
            return Batch(**{k: np.stack(v[::-1], axis=stack_axis)
                            for k, v in frames.items()})
        return np.stack(frames[::-1], axis=stack_axis)

    def __getitem__(self, index: Union[slice, int, np.ndarray]) -> Batch:
        """Return a data batch: self[index]. If stack_num is set to be > 0,
        return the stacked obs and obs_next with shape [batch, len, ...].

        Slices are resolved against ``len(self)`` so every field of the
        returned batch has the same length. (A stacked action, e.g. for an
        RNN, can be obtained with ``self.get(index, 'act')``.)
        """
        index = self._normalize_index(index)
        return Batch(
            obs=self.get(index, 'obs'),
            act=self.act[index],
            rew=self.rew[index],
            done=self.done[index],
            obs_next=self.get(index, 'obs_next'),
            info=self.info[index],
            policy=self.get(index, 'policy'),
        )
|
|
|
|
|
2020-03-11 09:09:56 +08:00
|
|
|
|
2020-03-28 15:14:41 +08:00
|
|
|
class ListReplayBuffer(ReplayBuffer):
    """A :class:`~tianshou.data.ReplayBuffer` whose columns are plain Python
    ``list`` objects instead of ``numpy.ndarray``.

    It grows without bound: every ``add`` appends one entry per column, and
    ``reset`` empties all of them. Dict-valued data is appended as-is rather
    than split into per-sub-key columns.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for more
        detailed explanation.
    """

    def __init__(self, **kwargs) -> None:
        # size=0 makes ReplayBuffer.add take its unbounded (append) branch;
        # obs_next is always kept
        super().__init__(size=0, ignore_obs_next=False, **kwargs)

    def _add_to_buffer(
            self, name: str,
            inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None:
        """Append ``inst`` to column ``name``, creating the list on demand.

        ``None`` values are skipped entirely; ``info`` dicts are deep-copied
        before being stored.
        """
        if inst is not None:
            store = self.__dict__
            if store.get(name) is None:
                store[name] = []
            entry = deepcopy(inst) if name == 'info' else inst
            store[name].append(entry)

    def reset(self) -> None:
        """Drop every stored entry and rewind the counters to zero."""
        self._index = self._size = 0
        for attr_name, value in list(self.__dict__.items()):
            if isinstance(value, list):
                self.__dict__[attr_name] = []
|
|
|
|
|
|
|
|
|
2020-03-11 09:09:56 +08:00
|
|
|
class PrioritizedReplayBuffer(ReplayBuffer):
    """Prioritized replay buffer implementation.

    :param float alpha: the prioritization exponent.
    :param float beta: the importance sample soft coefficient.
    :param str mode: defaults to ``weight``.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for more
        detailed explanation.
    """

    def __init__(self, size: int, alpha: float, beta: float,
                 mode: str = 'weight', **kwargs) -> None:
        if mode != 'weight':
            raise NotImplementedError
        super().__init__(size, **kwargs)
        self._alpha = alpha
        self._beta = beta
        # running sum of self.weight, resynchronised every
        # _amortization_freq calls to _check_weight_sum
        self._weight_sum = 0.0
        self.weight = np.zeros(size, dtype=np.float64)
        self._amortization_freq = 50
        self._amortization_counter = 0

    def add(self,
            obs: Union[dict, np.ndarray],
            act: Union[np.ndarray, float],
            rew: float,
            done: bool,
            obs_next: Optional[Union[dict, np.ndarray]] = None,
            info: Optional[dict] = None,
            policy: Optional[Union[dict, Batch]] = None,
            weight: float = 1.0,
            **kwargs) -> None:
        """Add a batch of data into replay buffer.

        The stored priority is ``abs(weight) ** alpha``. ``info`` and
        ``policy`` default to empty dicts.
        """
        if info is None:
            info = {}
        if policy is None:
            policy = {}
        priority = np.abs(weight) ** self._alpha
        # the slot about to be overwritten loses its old priority
        self._weight_sum += priority - self.weight[self._index]
        # we have to sacrifice some convenience for speed :(
        self._add_to_buffer('weight', priority)
        super().add(obs, act, rew, done, obs_next, info, policy)
        self._check_weight_sum()

    def sample(self, batch_size: int = 0,
               importance_sample: bool = True
               ) -> Tuple[Batch, np.ndarray]:
        """Get a random sample from buffer with priority probability. \
        Return all the data in the buffer if batch_size is ``0``.

        :return: Sample data and its corresponding index inside the buffer.
        :raise ValueError: if ``batch_size`` is negative or larger than
            ``len(self)``.
        """
        if 0 < batch_size <= self._size:
            # Sampling without replacement: the same index drawn twice
            # would cause conflicting weight updates in update_weight.
            # Normalise over the filled slots only so the probabilities
            # always sum to 1. (self._weight_sum is not used here because of
            # its accumulated floating-point error.)
            weight = self.weight[:self._size]
            indice = np.random.choice(
                self._size, batch_size,
                p=weight / weight.sum(),
                replace=False)
        elif batch_size == 0:
            indice = np.concatenate([
                np.arange(self._index, self._size),
                np.arange(0, self._index),
            ])
        else:
            # if batch_size larger than len(self),
            # it will lead to a bug in update weight
            raise ValueError("batch_size should be less than len(self)")
        batch = self[indice]
        if importance_sample:
            impt_weight = Batch(
                impt_weight=1 / np.power(
                    self._size * (batch.weight / self._weight_sum),
                    self._beta))
            batch.append(impt_weight)
        self._check_weight_sum()
        return batch, indice

    def reset(self) -> None:
        """Clear all data, including the stored priorities."""
        self._amortization_counter = 0
        self._weight_sum = 0.0
        # ReplayBuffer.__init__ calls reset() before self.weight exists
        weight = self.__dict__.get('weight')
        if weight is not None:
            weight[:] = 0.0
        super().reset()

    def update_weight(self, indice: Union[slice, np.ndarray],
                      new_weight: np.ndarray) -> None:
        """Update priority weight by indice in this buffer.

        :param np.ndarray indice: indice you want to update weight
        :param np.ndarray new_weight: new priority weight you want to update
        """
        priority = np.power(np.abs(new_weight), self._alpha)
        self._weight_sum += priority.sum() - self.weight[indice].sum()
        self.weight[indice] = priority

    def __getitem__(self, index: Union[slice, int, np.ndarray]) -> Batch:
        """Return self[index], including the stored priority ``weight``.

        Slices are resolved against ``len(self)`` so all fields have the
        same length; other indices are copied so the caller's array (e.g.
        the one returned by :meth:`sample`) is never modified.
        """
        if isinstance(index, slice):
            index = np.arange(*index.indices(self._size))
        else:
            index = np.array(index)
        return Batch(
            obs=self.get(index, 'obs'),
            act=self.act[index],
            rew=self.rew[index],
            done=self.done[index],
            obs_next=self.get(index, 'obs_next'),
            info=self.info[index],
            weight=self.weight[index],
            policy=self.get(index, 'policy'),
        )

    def _check_weight_sum(self) -> None:
        """Periodically recompute ``_weight_sum`` exactly from the weights.

        The incremental updates accumulate floating-point error; every
        ``_amortization_freq`` calls the sum is rebuilt from scratch.
        """
        self._amortization_counter += 1
        if self._amortization_counter % self._amortization_freq == 0:
            self._weight_sum = np.sum(self.weight)
            self._amortization_counter = 0
|