Tianshou/tianshou/data/buffer.py

345 lines
13 KiB
Python
Raw Normal View History

2020-03-11 09:09:56 +08:00
import numpy as np
from tianshou.data.batch import Batch
class ReplayBuffer(object):
    """:class:`~tianshou.data.ReplayBuffer` stores data generated from
    interaction between the policy and environment. It stores basically 6 types
    of data, as mentioned in :class:`~tianshou.data.Batch`, based on
    ``numpy.ndarray``. Here is the usage:
    ::

        >>> import numpy as np
        >>> from tianshou.data import ReplayBuffer
        >>> buf = ReplayBuffer(size=20)
        >>> for i in range(3):
        ...     buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
        >>> len(buf)
        3
        >>> # since we set size = 20, len(buf.obs) == 20.
        >>> buf.obs
        array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
               0., 0., 0., 0.])
        >>> buf2 = ReplayBuffer(size=10)
        >>> for i in range(15):
        ...     buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
        >>> len(buf2)
        10
        >>> # since its size = 10, it only stores the last 10 steps' result.
        >>> buf2.obs
        array([10., 11., 12., 13., 14.,  5.,  6.,  7.,  8.,  9.])

        >>> # move buf2's result into buf (meanwhile keep it chronologically)
        >>> buf.update(buf2)
        >>> buf.obs
        array([ 0.,  1.,  2.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
                0.,  0.,  0.,  0.,  0.,  0.,  0.])

        >>> # get a random sample from buffer
        >>> # the batch_data is equal to buf[indice].
        >>> batch_data, indice = buf.sample(batch_size=4)
        >>> batch_data.obs == buf[indice].obs
        array([ True,  True,  True,  True])

    Since version v0.2.2, :class:`~tianshou.data.ReplayBuffer` supports
    frame_stack sampling (typically for RNN usage) and ignoring storing the
    next observation (save memory):
    ::

        >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
        >>> for i in range(16):
        ...     done = i % 5 == 0
        ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1)
        >>> print(buf)
        ReplayBuffer(
            obs: [ 9. 10. 11. 12. 13. 14. 15.  7.  8.],
            act: [ 9. 10. 11. 12. 13. 14. 15.  7.  8.],
            rew: [ 9. 10. 11. 12. 13. 14. 15.  7.  8.],
            done: [0. 1. 0. 0. 0. 0. 1. 0. 0.],
            info: [{} {} {} {} {} {} {} {} {}],
        )
        >>> index = np.arange(len(buf))
        >>> print(buf.get(index, 'obs'))
        [[ 7.  7.  8.  9.]
         [ 7.  8.  9. 10.]
         [11. 11. 11. 11.]
         [11. 11. 11. 12.]
         [11. 11. 12. 13.]
         [11. 12. 13. 14.]
         [12. 13. 14. 15.]
         [ 7.  7.  7.  7.]
         [ 7.  7.  7.  8.]]
        >>> # here is another way to get the stacked data
        >>> # (stack only for obs and obs_next)
        >>> abs(buf.get(index, 'obs') - buf[index].obs).sum().sum()
        0.0
        >>> # we can get obs_next through __getitem__, even if it doesn't store
        >>> print(buf[:].obs_next)
        [[ 7.  8.  9. 10.]
         [ 7.  8.  9. 10.]
         [11. 11. 11. 12.]
         [11. 11. 12. 13.]
         [11. 12. 13. 14.]
         [12. 13. 14. 15.]
         [12. 13. 14. 15.]
         [ 7.  7.  7.  8.]
         [ 7.  7.  8.  9.]]
    """

    def __init__(self, size, stack_num=0, ignore_obs_next=False, **kwargs):
        super().__init__()
        self._maxsize = size
        self._stack = stack_num
        self._save_s_ = not ignore_obs_next
        self.reset()

    def __len__(self):
        """Return len(self)."""
        return self._size

    def __repr__(self):
        """Return str(self)."""
        fields = ''
        for k, v in self.__dict__.items():
            if k[0] != '_' and v is not None:
                # Continuation lines of multi-line values are aligned under
                # the value: 4-space indent + len(k) + len(': ').
                rpl = '\n' + ' ' * (6 + len(k))
                obj = str(v).replace('\n', rpl)
                fields += f'    {k}: {obj},\n'
        if not fields:
            return self.__class__.__name__ + '()\n'
        return self.__class__.__name__ + '(\n' + fields + ')\n'

    def _add_to_buffer(self, name, inst):
        """Write ``inst`` into slot ``self._index`` of the array ``name``.

        The array is lazily allocated on first use: shaped after ``inst``
        for ndarrays, an object array of dicts for dicts, and a flat float
        array otherwise. A ``None`` value only records the key as ``None``.
        """
        if inst is None:
            if getattr(self, name, None) is None:
                self.__dict__[name] = None
            return
        if self.__dict__.get(name, None) is None:
            if isinstance(inst, np.ndarray):
                self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
            elif isinstance(inst, dict):
                self.__dict__[name] = np.array(
                    [{} for _ in range(self._maxsize)])
            else:  # assume `inst` is a number
                self.__dict__[name] = np.zeros([self._maxsize])
        # NOTE: a shape change reallocates the whole array, discarding any
        # previously stored entries for this key.
        if isinstance(inst, np.ndarray) and \
                self.__dict__[name].shape[1:] != inst.shape:
            self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
        self.__dict__[name][self._index] = inst

    def update(self, buffer):
        """Move the data from the given buffer to self, oldest entry first.

        An empty ``buffer`` is a no-op. If ``buffer`` does not store
        ``obs_next`` (e.g. it was built with ``ignore_obs_next=True``),
        ``None`` is passed for it.
        """
        size = len(buffer)
        if size == 0:
            return
        # The source may never have created an obs_next array.
        src_obs_next = getattr(buffer, 'obs_next', None)
        begin = buffer._index % size
        # Walk the ring from the oldest slot to the newest.
        for i in [*range(begin, size), *range(begin)]:
            self.add(
                buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i],
                src_obs_next[i]
                if self._save_s_ and src_obs_next is not None else None,
                buffer.info[i])

    def add(self, obs, act, rew, done, obs_next=None, info=None, weight=None):
        """Add a batch of data into replay buffer.

        :param info: a dict; ``None`` (the default) stores a fresh empty
            dict so that slots never share one mutable default object.
        :param weight: accepted for interface compatibility; unused here.
        """
        if info is None:
            info = {}
        assert isinstance(info, dict), \
            'You should return a dict in the last argument of env.step().'
        self._add_to_buffer('obs', obs)
        self._add_to_buffer('act', act)
        self._add_to_buffer('rew', rew)
        self._add_to_buffer('done', done)
        if self._save_s_:
            self._add_to_buffer('obs_next', obs_next)
        self._add_to_buffer('info', info)
        if self._maxsize > 0:
            self._size = min(self._size + 1, self._maxsize)
            self._index = (self._index + 1) % self._maxsize
        else:
            # size == 0 means "unbounded" (used by list-based subclasses).
            self._size = self._index = self._index + 1

    def reset(self):
        """Clear all the data in replay buffer."""
        self._index = self._size = 0

    def sample(self, batch_size):
        """Get a random sample from buffer with size equal to batch_size.
        Return all the data in the buffer if batch_size is ``0``.

        :return: Sample data and its corresponding index inside the buffer.
        """
        if batch_size > 0:
            indice = np.random.choice(self._size, batch_size)
        else:
            # Chronological order: oldest slot (write head) first.
            indice = np.concatenate([
                np.arange(self._index, self._size),
                np.arange(0, self._index),
            ])
        return self[indice], indice

    def get(self, indice, key):
        """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
        where s is self.key, t is indice. The stack_num (here equals to 4) is
        given from buffer initialization procedure.

        ``indice`` may be an int, a slice, a list or an ndarray; it is never
        modified. When ``key`` is ``'obs_next'`` but next observations are
        not stored, the observation of the following step is used instead
        (or the same step when that step ended an episode). With
        ``stack_num > 0`` the stack axis is inserted right after the index
        axes, e.g. ``[batch, stack_num, ...]`` for a 1-D index.
        """
        if isinstance(indice, slice):
            # Resolves None / negative bounds against the valid length.
            indice = np.arange(self._size)[indice]
        else:
            indice = np.asarray(indice)
        # All index arithmetic below is out-of-place, so the caller's
        # array is left untouched.
        stack_axis = np.ndim(indice)
        # Temporarily mark the newest frame as terminal so that neither the
        # obs_next shift nor frame stacking crosses the write head.
        last_index = (self._index - 1 + self._size) % self._size
        last_done, self.done[last_index] = self.done[last_index], True
        try:
            if key == 'obs_next' and (
                    not self._save_s_
                    or self.__dict__.get('obs_next') is None):
                indice = (indice + 1
                          - self.done[indice].astype(int)) % self._size
                key = 'obs'
            data = self.__dict__[key]
            if self._stack == 0:
                return data[indice]
            stack = []
            for _ in range(self._stack):
                stack.insert(0, data[indice])
                pre_indice = (indice - 1) % self._size
                # Do not step back into the previous episode: if the
                # previous frame was terminal, repeat the current one.
                indice = (pre_indice
                          + self.done[pre_indice].astype(int)) % self._size
            return np.stack(stack, axis=stack_axis)
        finally:
            self.done[last_index] = last_done

    def __getitem__(self, index):
        """Return a data batch: self[index]. If stack_num is set to be > 0,
        return the stacked obs and obs_next with shape [batch, len, ...].
        """
        return Batch(
            obs=self.get(index, 'obs'),
            act=self.act[index],
            rew=self.rew[index],
            done=self.done[index],
            obs_next=self.get(index, 'obs_next'),
            info=self.info[index]
        )
2020-03-11 09:09:56 +08:00
2020-03-28 15:14:41 +08:00
class ListReplayBuffer(ReplayBuffer):
    """A list-backed variant of :class:`~tianshou.data.ReplayBuffer`.

    It behaves almost exactly like :class:`~tianshou.data.ReplayBuffer`; the
    only difference is that every field is kept in a Python ``list`` that
    grows with each :meth:`add`, instead of a preallocated
    ``numpy.ndarray``. It is always constructed with ``size=0`` and
    ``ignore_obs_next=False``.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for more
        detailed explanation.
    """

    def __init__(self, **kwargs):
        super().__init__(size=0, ignore_obs_next=False, **kwargs)

    def _add_to_buffer(self, name, inst):
        # None values are skipped entirely: no list is created for them.
        if inst is None:
            return
        storage = self.__dict__.get(name, None)
        if storage is None:
            storage = self.__dict__[name] = []
        storage.append(inst)

    def reset(self):
        """Reset the counters and empty every public field."""
        self._index = self._size = 0
        self.__dict__.update(
            {key: [] for key in self.__dict__ if not key.startswith('_')})
2020-03-11 09:09:56 +08:00
class PrioritizedReplayBuffer(ReplayBuffer):
    """Prioritized experience replay buffer.

    Every stored transition carries a priority ``|weight| ** alpha`` (kept in
    ``self.weight``). :meth:`sample` draws distinct indices with probability
    proportional to that priority and can attach importance-sampling weights
    ``1 / (len(self) * priority / weight_sum) ** beta`` to the batch.

    :param int size: the maximum size of the buffer.
    :param float alpha: the prioritization exponent.
    :param float beta: the importance-sampling exponent.
    :param str mode: priority mode; only ``'weight'`` is implemented.

    Other keyword arguments are forwarded to
    :class:`~tianshou.data.ReplayBuffer`.
    """

    def __init__(self, size, alpha: float, beta: float,
                 mode: str = 'weight', **kwargs):
        if mode != 'weight':
            raise NotImplementedError
        super().__init__(size, **kwargs)
        self._alpha = alpha  # prioritization exponent
        self._beta = beta  # importance sample soft coefficient
        self._weight_sum = 0.0
        self.weight = np.zeros(size, dtype=np.float64)
        # Recompute _weight_sum exactly every this many add/sample calls.
        self._amortization_freq = 50
        self._amortization_counter = 0

    def add(self, obs, act, rew, done, obs_next=0, info=None, weight=1.0):
        """Add a batch of data into replay buffer.

        :param weight: raw priority; stored as ``|weight| ** alpha``.
        :param info: a dict; ``None`` (the default) stores a fresh empty dict.
        """
        if info is None:
            info = {}
        priority = np.abs(weight) ** self._alpha
        # Incrementally maintain the running sum: add the new priority and
        # drop the one being overwritten in the ring slot.
        self._weight_sum += priority - self.weight[self._index]
        # we have to sacrifice some convenience for speed :(
        self._add_to_buffer('weight', priority)
        super().add(obs, act, rew, done, obs_next, info)
        self._check_weight_sum()

    def sample(self, batch_size: int = 0, importance_sample: bool = True):
        """Get a random sample from buffer with priority probability.
        Return all the data in the buffer if batch_size is ``0``.

        :param batch_size: number of distinct indices to draw (``0`` = all).
        :param importance_sample: if True, append an ``impt_weight`` field.
        :return: Sample data and its corresponding index inside the buffer.
        :raises ValueError: if ``batch_size`` is negative or exceeds
            ``len(self)``.
        """
        if 0 < batch_size <= self._size:
            # Sample without replacement: duplicated indices would make the
            # _weight_sum bookkeeping in update_weight() inconsistent.
            priority = self.weight[:self._size]
            # Normalize with an exact sum over the live slots rather than
            # self._weight_sum, which accumulates floating-point drift.
            indice = np.random.choice(
                self._size, batch_size,
                p=priority / priority.sum(), replace=False)
        elif batch_size == 0:
            indice = np.concatenate([
                np.arange(self._index, self._size),
                np.arange(0, self._index),
            ])
        else:
            # sampling more than len(self) distinct indices is impossible
            raise ValueError("batch_size should be less than len(self)")
        batch = self[indice]
        if importance_sample:
            impt_weight = Batch(
                impt_weight=1 / np.power(
                    self._size * (batch.weight / self._weight_sum),
                    self._beta))
            batch.append(impt_weight)
        self._check_weight_sum()
        return batch, indice

    def reset(self):
        """Clear the buffer, including all stored priorities.

        Stale priorities left behind a reset would otherwise leak into the
        sampling distribution and the importance-sampling normalizer.
        Note that this is also invoked from the base ``__init__``.
        """
        self._amortization_counter = 0
        self._weight_sum = 0.0
        self.weight = np.zeros(self._maxsize, dtype=np.float64)
        super().reset()

    def update_weight(self, indice, new_weight: np.ndarray):
        """Update priority weight by indice in this buffer.

        :param indice: indice you want to update weight
        :param new_weight: new raw priority weight you want to update
        """
        priority = np.power(np.abs(new_weight), self._alpha)
        self._weight_sum += priority.sum() - self.weight[indice].sum()
        self.weight[indice] = priority

    def __getitem__(self, index):
        """Return self[index] as a Batch that also carries ``weight``."""
        return Batch(
            obs=self.get(index, 'obs'),
            act=self.act[index],
            rew=self.rew[index],
            done=self.done[index],
            obs_next=self.get(index, 'obs_next'),
            info=self.info[index],
            weight=self.weight[index]
        )

    def _check_weight_sum(self):
        # keep an accurate _weight_sum: periodically replace the running
        # (drift-prone) value by an exact sum over all priorities.
        self._amortization_counter += 1
        if self._amortization_counter % self._amortization_freq == 0:
            self._weight_sum = np.sum(self.weight)
            self._amortization_counter = 0