Tianshou/tianshou/data/buffer.py
yingchengyang 5b49192a48
DQN Atari examples (#187)
This PR provides scripts for the Atari DQN setting:
- A speedrun of PongNoFrameskip-v4 (finished; about half an hour on an i7-8750 + GTX 1060 with 1M environment steps)
- A general script for all Atari games
Since we use multiple environments for simulation, the results differ slightly from the original paper, but are considered acceptable.

It also adds a save_only_last_obs parameter to the replay buffer in order to save memory.
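For reference, a hedged sketch of how this flag combines with frame stacking on Atari (parameter names are taken from the buffer below; a wrapper producing observations of shape (stack, 84, 84) is assumed):

>>> from tianshou.data import ReplayBuffer
>>> # store only the newest frame of each stacked observation and rebuild
>>> # the 4-frame stack at sampling time via stack_num
>>> buf = ReplayBuffer(size=100000, stack_num=4,
...                    ignore_obs_next=True, save_only_last_obs=True)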

Co-authored-by: Trinkle23897 <463003665@qq.com>
2020-08-30 05:48:09 +08:00


import torch
import numpy as np
from typing import Any, Tuple, Union, Optional
from tianshou.data import Batch, SegmentTree, to_numpy
from tianshou.data.batch import _create_value
class ReplayBuffer:
""":class:`~tianshou.data.ReplayBuffer` stores data generated from
interaction between the policy and environment. The current implementation
of Tianshou typically uses 7 reserved keys in :class:`~tianshou.data.Batch`:
* ``obs`` the observation of step :math:`t` ;
* ``act`` the action of step :math:`t` ;
* ``rew`` the reward of step :math:`t` ;
* ``done`` the done flag of step :math:`t` ;
* ``obs_next`` the observation of step :math:`t+1` ;
* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` \
function returns 4 arguments, and the last one is ``info``);
* ``policy`` the data computed by policy in step :math:`t`;
The following code snippet illustrates its usage:
::
>>> import pickle, numpy as np
>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
... buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> buf.obs
# since we set size = 20, len(buf.obs) == 20.
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.])
>>> # but there are only three valid items, so len(buf) == 3.
>>> len(buf)
3
>>> pickle.dump(buf, open('buf.pkl', 'wb')) # save to file "buf.pkl"
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf2)
10
>>> buf2.obs
# since its size = 10, it only stores the last 10 steps' result.
array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])
>>> # move buf2's result into buf (meanwhile keep it chronologically)
>>> buf.update(buf2)
>>> buf.obs
array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
0., 0., 0., 0., 0., 0., 0.])
>>> # get a random sample from buffer
>>> # the batch_data is equal to buf[indice].
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
>>> len(buf)
13
>>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl"
>>> len(buf)
3
:class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
(typically for RNN usage, see issue#19), not storing the next observation
(to save memory in Atari tasks), and multi-modal observation (see
issue#38):
::
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
>>> for i in range(16):
... done = i % 5 == 0
... buf.add(obs={'id': i}, act=i, rew=i, done=done,
... obs_next={'id': i + 1})
>>> print(buf) # you can see obs_next is not saved in buf
ReplayBuffer(
act: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
info: Batch(),
obs: Batch(
id: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
),
policy: Batch(),
rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
)
>>> index = np.arange(len(buf))
>>> print(buf.get(index, 'obs').id)
[[ 7. 7. 8. 9.]
[ 7. 8. 9. 10.]
[11. 11. 11. 11.]
[11. 11. 11. 12.]
[11. 11. 12. 13.]
[11. 12. 13. 14.]
[12. 13. 14. 15.]
[ 7. 7. 7. 7.]
[ 7. 7. 7. 8.]]
>>> # here is another way to get the stacked data
>>> # (stack only for obs and obs_next)
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
0.0
>>> # we can get obs_next through __getitem__, even if it doesn't exist
>>> print(buf[:].obs_next.id)
[[ 7. 8. 9. 10.]
[ 7. 8. 9. 10.]
[11. 11. 11. 12.]
[11. 11. 12. 13.]
[11. 12. 13. 14.]
[12. 13. 14. 15.]
[12. 13. 14. 15.]
[ 7. 7. 7. 8.]
[ 7. 7. 8. 9.]]
:param int size: the size of replay buffer.
:param int stack_num: the frame-stack sampling argument, should be greater
than or equal to 1, defaults to 1 (no stacking).
:param bool ignore_obs_next: whether to store obs_next, defaults to
``False``.
:param bool save_only_last_obs: only save the last obs/obs_next when it has
a shape of (timestep, ...) because of temporal stacking, defaults to
``False``.
:param bool sample_avail: whether to sample only available indices when
using the frame-stack sampling method, defaults to ``False``. This
feature is currently not supported in the Prioritized Replay Buffer.
"""
def __init__(self, size: int, stack_num: int = 1,
ignore_obs_next: bool = False,
save_only_last_obs: bool = False,
sample_avail: bool = False) -> None:
super().__init__()
self._maxsize = size
self._indices = np.arange(size)
self._stack = None
self.stack_num = stack_num
self._avail = sample_avail and stack_num > 1
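# indices that can start a full frame stack without crossing an episode
# boundary; maintained in add() and used by sample() when sample_avail is set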
self._avail_index = []
self._save_s_ = not ignore_obs_next
self._last_obs = save_only_last_obs
self._index = 0
self._size = 0
self._meta = Batch()
self.reset()
def __len__(self) -> int:
"""Return len(self)."""
return self._size
def __repr__(self) -> str:
"""Return str(self)."""
return self.__class__.__name__ + self._meta.__repr__()[5:]
def __getattr__(self, key: str) -> Any:
"""Return self.key"""
try:
return self._meta[key]
except KeyError as e:
raise AttributeError from e
def __setstate__(self, state):
"""Unpickling interface. We need it because pickling buffer does not
work out-of-the-box (``buffer.__getattr__`` is customized).
"""
self.__dict__.update(state)
def _add_to_buffer(self, name: str, inst: Any) -> None:
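# look up the preallocated storage for this key; on first use, allocate a
# full-size (maxsize-length) container shaped like the incoming value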
try:
value = self._meta.__dict__[name]
except KeyError:
self._meta.__dict__[name] = _create_value(inst, self._maxsize)
value = self._meta.__dict__[name]
if isinstance(inst, (np.ndarray, torch.Tensor)) \
and value.shape[1:] != inst.shape:
raise ValueError(
"Cannot add data to a buffer with different shape, with key "
f"{name}, expect {value.shape[1:]}, given {inst.shape}.")
try:
value[self._index] = inst
except KeyError:
for key in set(inst.keys()).difference(value.__dict__.keys()):
value.__dict__[key] = _create_value(inst[key], self._maxsize)
value[self._index] = inst
@property
def stack_num(self):
return self._stack
@stack_num.setter
def stack_num(self, num):
assert num > 0, 'stack_num should be greater than 0'
self._stack = num
def update(self, buffer: 'ReplayBuffer') -> None:
"""Move the data from the given buffer to self."""
if len(buffer) == 0:
return
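# copy transitions starting from the oldest element in the source buffer,
# wrapping around once, with stacking temporarily disabled so that
# buffer[i] yields raw (unstacked) transitions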
i = begin = buffer._index % len(buffer)
stack_num_orig = buffer.stack_num
buffer.stack_num = 1
while True:
self.add(**buffer[i])
i = (i + 1) % len(buffer)
if i == begin:
break
buffer.stack_num = stack_num_orig
def add(self,
obs: Union[dict, Batch, np.ndarray, float],
act: Union[dict, Batch, np.ndarray, float],
rew: Union[int, float],
done: Union[bool, int],
obs_next: Optional[Union[dict, Batch, np.ndarray, float]] = None,
info: Optional[Union[dict, Batch]] = {},
policy: Optional[Union[dict, Batch]] = {},
**kwargs) -> None:
"""Add a batch of data into replay buffer."""
assert isinstance(info, (dict, Batch)), \
'You should return a dict in the last argument of env.step().'
if self._last_obs:
obs = obs[-1]
self._add_to_buffer('obs', obs)
self._add_to_buffer('act', act)
self._add_to_buffer('rew', rew)
self._add_to_buffer('done', done)
if self._save_s_:
if obs_next is None:
obs_next = Batch()
elif self._last_obs:
obs_next = obs_next[-1]
self._add_to_buffer('obs_next', obs_next)
self._add_to_buffer('info', info)
self._add_to_buffer('policy', policy)
# maintain available index for frame-stack sampling
if self._avail:
# update current frame
avail = sum(self.done[i] for i in range(
self._index - self.stack_num + 1, self._index)) == 0
if self._size < self.stack_num - 1:
avail = False
if avail and self._index not in self._avail_index:
self._avail_index.append(self._index)
elif not avail and self._index in self._avail_index:
self._avail_index.remove(self._index)
# remove the later available frame because of broken storage
t = (self._index + self.stack_num - 1) % self._maxsize
if t in self._avail_index:
self._avail_index.remove(t)
if self._maxsize > 0:
self._size = min(self._size + 1, self._maxsize)
self._index = (self._index + 1) % self._maxsize
else:
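# _maxsize == 0 means an unbounded buffer (e.g. ListReplayBuffer):
# keep growing without wrapping around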
self._size = self._index = self._index + 1
def reset(self) -> None:
"""Clear all the data in replay buffer."""
self._index = 0
self._size = 0
self._avail_index = []
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
"""Get a random sample from buffer with size equal to batch_size. \
Return all the data in the buffer if batch_size is ``0``.
:return: Sample data and its corresponding index inside the buffer.
"""
if batch_size > 0:
_all = self._avail_index if self._avail else self._size
indice = np.random.choice(_all, batch_size)
else:
if self._avail:
indice = np.array(self._avail_index)
else:
indice = np.concatenate([
np.arange(self._index, self._size),
np.arange(0, self._index),
])
assert len(indice) > 0, 'No available indice can be sampled.'
return self[indice], indice
def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str,
stack_num: Optional[int] = None) -> Union[Batch, np.ndarray]:
"""Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
where s is self.key, t is indice. The stack_num (here equals to 4) is
given from buffer initialization procedure.
"""
if stack_num is None:
stack_num = self.stack_num
if stack_num == 1: # the most often case
if key != 'obs_next' or self._save_s_:
val = self._meta.__dict__[key]
try:
return val[indice]
except IndexError as e:
if not (isinstance(val, Batch) and val.is_empty()):
raise e # val != Batch()
return Batch()
indice = self._indices[:self._size][indice]
done = self._meta.__dict__['done']
if key == 'obs_next' and not self._save_s_:
indice += 1 - done[indice].astype(np.int)
indice[indice == self._size] = 0
key = 'obs'
val = self._meta.__dict__[key]
try:
if stack_num == 1:
return val[indice]
stack = []
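# walk backwards stack_num steps; adding done[pre_indice] prevents the walk
# from stepping back across an episode boundary (the episode's first frame
# is repeated instead)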
for _ in range(stack_num):
stack = [val[indice]] + stack
pre_indice = np.asarray(indice - 1)
pre_indice[pre_indice == -1] = self._size - 1
indice = np.asarray(
pre_indice + done[pre_indice].astype(np.int))
indice[indice == self._size] = 0
if isinstance(val, Batch):
stack = Batch.stack(stack, axis=indice.ndim)
else:
stack = np.stack(stack, axis=indice.ndim)
return stack
except IndexError as e:
if not (isinstance(val, Batch) and val.is_empty()):
raise e # val != Batch()
return Batch()
def __getitem__(self, index: Union[
slice, int, np.integer, np.ndarray]) -> Batch:
"""Return a data batch: self[index]. If stack_num is larger than 1,
return the stacked obs and obs_next with shape [batch, len, ...].
"""
return Batch(
obs=self.get(index, 'obs'),
act=self.act[index],
rew=self.rew[index],
done=self.done[index],
obs_next=self.get(index, 'obs_next'),
info=self.get(index, 'info'),
policy=self.get(index, 'policy'),
)
class ListReplayBuffer(ReplayBuffer):
"""The function of :class:`~tianshou.data.ListReplayBuffer` is almost the
same as :class:`~tianshou.data.ReplayBuffer`. The only difference is that
:class:`~tianshou.data.ListReplayBuffer` is based on ``list``. Therefore,
it does not support advanced indexing, which means you cannot sample a
batch of data out of it. It is typically used for storing data.
.. seealso::
Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed
explanation.
"""
def __init__(self, **kwargs) -> None:
super().__init__(size=0, ignore_obs_next=False, **kwargs)
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
raise NotImplementedError("ListReplayBuffer cannot be sampled!")
def _add_to_buffer(
self, name: str,
inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None:
if self._meta.__dict__.get(name) is None:
self._meta.__dict__[name] = []
self._meta.__dict__[name].append(inst)
def reset(self) -> None:
self._index = self._size = 0
for k in list(self._meta.__dict__.keys()):
if isinstance(self._meta.__dict__[k], list):
self._meta.__dict__[k] = []
class PrioritizedReplayBuffer(ReplayBuffer):
"""Implementation of Prioritized Experience Replay. arXiv:1511.05952
:param float alpha: the prioritization exponent.
:param float beta: the importance sample soft coefficient.
.. seealso::
Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed
explanation.
"""
def __init__(self, size: int, alpha: float, beta: float, **kwargs) -> None:
super().__init__(size, **kwargs)
assert alpha > 0. and beta >= 0.
self._alpha, self._beta = alpha, beta
self._max_prio = 1.
self._min_prio = 1.
# store the priority weights in a SegmentTree outside self._meta
# (bypassing the shape check in _add_to_buffer)
self._weight = SegmentTree(size)
self.__eps = np.finfo(np.float32).eps.item()
def __getattr__(self, key: str) -> Union['Batch', Any]:
"""Return self.key"""
if key == 'weight':
return self._weight
return super().__getattr__(key)
def add(self,
obs: Union[dict, Batch, np.ndarray, float],
act: Union[dict, Batch, np.ndarray, float],
rew: Union[int, float],
done: Union[bool, int],
obs_next: Optional[Union[dict, Batch, np.ndarray, float]] = None,
info: Optional[Union[dict, Batch]] = {},
policy: Optional[Union[dict, Batch]] = {},
weight: Optional[float] = None,
**kwargs) -> None:
"""Add a batch of data into replay buffer."""
if weight is None:
weight = self._max_prio
else:
weight = np.abs(weight)
self._max_prio = max(self._max_prio, weight)
self._min_prio = min(self._min_prio, weight)
self.weight[self._index] = weight ** self._alpha
super().add(obs, act, rew, done, obs_next, info, policy)
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
"""Get a random sample from buffer with priority probability. Return
all the data in the buffer if batch_size is ``0``.
:return: Sample data and its corresponding index inside the buffer.
The ``weight`` in the returned Batch is the weight on loss function
to de-bias the sampling process (some transition tuples are sampled
more often so their losses are weighted less).
"""
assert self._size > 0, 'Cannot sample a buffer with 0 size!'
if batch_size == 0:
indice = np.concatenate([
np.arange(self._index, self._size),
np.arange(0, self._index),
])
else:
scalar = np.random.rand(batch_size) * self.weight.reduce()
indice = self.weight.get_prefix_sum_idx(scalar)
batch = self[indice]
# importance-sampling weight
# original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
# simplified formula: (p_j/p_min)**(-beta)
batch.weight = (batch.weight / self._min_prio) ** (-self._beta)
return batch, indice
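# A hedged sketch of the usual prioritized-replay training loop; the
# per_sample_loss / td_error values are illustrative placeholders computed
# by the learner, not part of this file:
#
#   buf = PrioritizedReplayBuffer(size=10000, alpha=0.6, beta=0.4)
#   batch, indice = buf.sample(batch_size=64)
#   loss = (batch.weight * per_sample_loss).mean()   # de-bias with IS weights
#   buf.update_weight(indice, new_weight=td_error)   # refresh priorities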
def update_weight(self, indice: np.ndarray,
new_weight: Union[np.ndarray, torch.Tensor]) -> None:
"""Update priority weight by indice in this buffer.
:param np.ndarray indice: indice you want to update weight.
:param np.ndarray new_weight: new priority weight you want to update.
"""
weight = np.abs(to_numpy(new_weight)) + self.__eps
self.weight[indice] = weight ** self._alpha
self._max_prio = max(self._max_prio, weight.max())
self._min_prio = min(self._min_prio, weight.min())
def __getitem__(self, index: Union[
slice, int, np.integer, np.ndarray]) -> Batch:
return Batch(
obs=self.get(index, 'obs'),
act=self.act[index],
rew=self.rew[index],
done=self.done[index],
obs_next=self.get(index, 'obs_next'),
info=self.get(index, 'info'),
policy=self.get(index, 'policy'),
weight=self.weight[index],
)