store RNN hidden states in policy._state and add sample_avail in buffer (#19)

This commit is contained in:
Trinkle23897 2020-06-29 12:18:52 +08:00
parent 60cfc373f8
commit e0f4862d01
4 changed files with 85 additions and 37 deletions

View File

@@ -61,11 +61,13 @@ def test_ignore_obs_next(size=10):
def test_stack(size=5, bufsize=9, stack_num=4):
env = MyTestEnv(size)
buf = ReplayBuffer(bufsize, stack_num)
buf = ReplayBuffer(bufsize, stack_num=stack_num)
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
obs = env.reset(1)
for i in range(15):
obs_next, rew, done, info = env.step(1)
buf.add(obs, 1, rew, done, None, info)
buf2.add(obs, 1, rew, done, None, info)
obs = obs_next
if done:
obs = env.reset(1)
@@ -75,6 +77,10 @@ def test_stack(size=5, bufsize=9, stack_num=4):
[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
[3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]]))
print(buf)
_, indice = buf2.sample(0)
assert indice == [2]
_, indice = buf2.sample(1)
assert indice.sum() == 2
def test_priortized_replaybuffer(size=32, bufsize=15):
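The new ``sample_avail`` flag restricts sampling to indices whose full frame stack lies inside a single episode. A minimal sketch along the lines of the test above (the exact indices follow from the bookkeeping added in the replay buffer further down):

    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=20, stack_num=4, sample_avail=True)
    # two episodes of length 5; the observation is just the step counter
    for i in range(10):
        buf.add(obs=i, act=0, rew=0.0, done=(i % 5 == 4))
    # batch_size=0 now returns exactly the available indices
    batch, indice = buf.sample(0)
    # only indices whose 3 predecessors lie in the same episode remain;
    # here that should be [3, 4, 8, 9]
    print(indice)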

View File

@@ -97,10 +97,11 @@ class Batch:
function return 4 arguments, and the last one is ``info``);
* ``policy`` the data computed by policy in step :math:`t`;
:class:`Batch` object can be initialized using wide variety of arguments,
starting with the key/value pairs or dictionary, but also list and Numpy
arrays of :class:`dict` or Batch instances. In which case, each element
is considered as an individual sample and get stacked together:
:class:`~tianshou.data.Batch` objects can be initialized using a wide
variety of arguments, starting with key/value pairs or a dictionary, but
also lists and Numpy arrays of :class:`dict` or Batch instances. In that
case, each element is considered an individual sample and gets stacked
together:
::
>>> import numpy as np
@@ -113,9 +114,9 @@ class Batch:
),
)
:class:`Batch` has the same API as a native Python :class:`dict`. In this
regard, one can access to stored data using string key, or iterate over
stored data:
:class:`~tianshou.data.Batch` has the same API as a native Python
:class:`dict`. In this regard, one can access stored data using a string
key, or iterate over stored data:
::
>>> from tianshou.data import Batch
@@ -128,8 +129,8 @@ class Batch:
b: [5, 5]
:class:`Batch` is also reproduce partially the Numpy API for arrays. You
can access or iterate over the individual samples, if any:
:class:`~tianshou.data.Batch` also partially reproduces the Numpy API for
arrays. You can access or iterate over the individual samples, if any:
::
>>> import numpy as np
@@ -219,11 +220,12 @@ class Batch:
>>> len(data[0])
TypeError: Object of type 'Batch' has no len()
Convenience helpers are available to convert in-place the
stored data into Numpy arrays or Torch tensors.
Convenience helpers are available to convert the stored data in place into
Numpy arrays or Torch tensors.
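A small sketch of the in-place converters this paragraph refers to, assuming they are the ``to_numpy()`` / ``to_torch()`` methods defined on the class:

    import torch
    from tianshou.data import Batch

    b = Batch(a=[1.0, 2.0], b={'c': [3, 4]})
    b.to_torch(dtype=torch.float32, device='cpu')  # nested values become torch tensors, in place
    b.to_numpy()                                   # and back to numpy arrays, in place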
Finally, note that Batch instance are serializable and therefore Pickle
compatible. This is especially important for distributed sampling.
Finally, note that :class:`~tianshou.data.Batch` instances are serializable
and therefore Pickle compatible. This is especially important for
distributed sampling.
"""
def __init__(self,
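And a one-liner for the pickle claim at the end of the docstring above:

    import pickle
    from tianshou.data import Batch

    b = Batch(obs=[0, 1, 2], act=[1, 1, 1])
    b2 = pickle.loads(pickle.dumps(b))  # a Batch survives a pickle round trip
    assert (b2.obs == b.obs).all()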

View File

@@ -1,7 +1,7 @@
import numpy as np
from typing import Any, Tuple, Union, Optional
from .batch import Batch, _create_value
from tianshou.data.batch import Batch, _create_value
class ReplayBuffer:
@@ -91,12 +91,27 @@ class ReplayBuffer:
[12. 13. 14. 15.]
[ 7. 7. 7. 8.]
[ 7. 7. 8. 9.]]
:param int size: the size of replay buffer.
:param int stack_num: the frame-stack sampling argument, should be greater
than 1, defaults to 0 (no stacking).
:param bool ignore_obs_next: whether to store obs_next, defaults to
``False``.
:param bool sample_avail: whether to sample only available indices when
using the frame-stack sampling method, defaults to ``False``. This
feature is currently not supported in the Prioritized Replay Buffer.
"""
def __init__(self, size: int, stack_num: Optional[int] = 0,
ignore_obs_next: bool = False, **kwargs) -> None:
ignore_obs_next: bool = False,
sample_avail: bool = False, **kwargs) -> None:
super().__init__()
self._maxsize = size
self._stack = stack_num
assert stack_num != 1, \
'stack_num should be greater than 1'
self._avail = sample_avail and stack_num > 1
self._avail_index = []
self._save_s_ = not ignore_obs_next
self._index = 0
self._size = 0
@@ -146,7 +161,7 @@ class ReplayBuffer:
def add(self,
obs: Union[dict, Batch, np.ndarray],
act: Union[np.ndarray, float],
rew: float,
rew: Union[int, float],
done: bool,
obs_next: Optional[Union[dict, Batch, np.ndarray]] = None,
info: dict = {},
@@ -165,6 +180,23 @@ class ReplayBuffer:
self._add_to_buffer('obs_next', obs_next)
self._add_to_buffer('info', info)
self._add_to_buffer('policy', policy)
# maintain available index for frame-stack sampling
if self._avail:
# update current frame
avail = sum(self.done[i] for i in range(
self._index - self._stack + 1, self._index)) == 0
if self._size < self._stack - 1:
avail = False
if avail and self._index not in self._avail_index:
self._avail_index.append(self._index)
elif not avail and self._index in self._avail_index:
self._avail_index.remove(self._index)
# remove the later available frame because of broken storage
t = (self._index + self._stack - 1) % self._maxsize
if t in self._avail_index:
self._avail_index.remove(t)
if self._maxsize > 0:
self._size = min(self._size + 1, self._maxsize)
self._index = (self._index + 1) % self._maxsize
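To make the bookkeeping above concrete, take maxsize = 9 and stack_num = 4 as in the test: writing a new frame at index 7 marks 7 as available only if done[4], done[5] and done[6] are all False (and at least stack_num - 1 frames are stored), and it additionally drops t = (7 + 4 - 1) % 9 = 1 from _avail_index, because the stack for index 1 is the frames [7, 8, 0, 1] and slot 7 now holds data from the current episode rather than the old one.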
@@ -175,6 +207,7 @@ class ReplayBuffer:
"""Clear all the data in replay buffer."""
self._index = 0
self._size = 0
self._avail_index = []
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
"""Get a random sample from buffer with size equal to batch_size. \
@@ -183,12 +216,17 @@ class ReplayBuffer:
:return: Sample data and its corresponding index inside the buffer.
"""
if batch_size > 0:
indice = np.random.choice(self._size, batch_size)
_all = self._avail_index if self._avail else self._size
indice = np.random.choice(_all, batch_size)
else:
indice = np.concatenate([
np.arange(self._index, self._size),
np.arange(0, self._index),
])
if self._avail:
indice = np.array(self._avail_index)
else:
indice = np.concatenate([
np.arange(self._index, self._size),
np.arange(0, self._index),
])
assert len(indice) > 0, 'No available indice can be sampled.'
return self[indice], indice
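One consequence of the branch above, sketched under the same assumptions: with ``sample_avail`` enabled, ``sample(0)`` returns every available index, and the new assertion fires if there is none yet.

    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=9, stack_num=4, sample_avail=True)
    buf.add(obs=0, act=0, rew=0.0, done=False)  # a single frame, no complete stack yet
    buf.sample(0)  # raises AssertionError: No available indice can be sampled.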
def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str,
@@ -247,11 +285,10 @@ class ReplayBuffer:
return Batch(
obs=self.get(index, 'obs'),
act=self.act[index],
# act_=self.get(index, 'act'), # stacked action, for RNN
rew=self.rew[index],
done=self.done[index],
obs_next=self.get(index, 'obs_next'),
info=self.get(index, 'info', stack_num=0),
info=self.get(index, 'info'),
policy=self.get(index, 'policy')
)
@@ -317,7 +354,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
def add(self,
obs: Union[dict, np.ndarray],
act: Union[np.ndarray, float],
rew: float,
rew: Union[int, float],
done: bool,
obs_next: Optional[Union[dict, np.ndarray]] = None,
info: dict = {},
@@ -401,11 +438,11 @@ class PrioritizedReplayBuffer(ReplayBuffer):
- self.weight[indice].sum()
self.weight[indice] = np.power(np.abs(new_weight), self._alpha)
def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch:
def __getitem__(self, index: Union[
slice, int, np.integer, np.ndarray]) -> Batch:
return Batch(
obs=self.get(index, 'obs'),
act=self.act[index],
# act_=self.get(index, 'act'), # stacked action, for RNN
rew=self.rew[index],
done=self.done[index],
obs_next=self.get(index, 'obs_next'),

View File

@@ -200,14 +200,8 @@ class Collector(object):
return
if isinstance(self.state, list):
self.state[id] = None
elif isinstance(self.state, (dict, Batch)):
for k in self.state.keys():
if isinstance(self.state[k], list):
self.state[k][id] = None
elif isinstance(self.state[k], (torch.Tensor, np.ndarray)):
self.state[k][id] = 0
elif isinstance(self.state, (torch.Tensor, np.ndarray)):
self.state[id] = 0
elif isinstance(self.state, (Batch, torch.Tensor, np.ndarray)):
self.state[id] *= 0
def collect(self,
n_step: int = 0,
@@ -272,9 +266,18 @@ class Collector(object):
else:
with torch.no_grad():
result = self.policy(batch, self.state)
# save hidden state to policy._state, in order to save into buffer
self.state = result.get('state', None)
self._policy = to_numpy(result.policy) \
if hasattr(result, 'policy') else [{}] * self.env_num
if hasattr(result, 'policy'):
self._policy = to_numpy(result.policy)
if self.state is not None:
self._policy._state = self.state
elif self.state is not None:
self._policy = Batch(_state=self.state)
else:
self._policy = [{}] * self.env_num
self._act = to_numpy(result.act)
if self._action_noise is not None:
self._act += self._action_noise(self._act.shape)
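Downstream, the hidden state saved into ``policy._state`` travels with each stored transition, so it can be recovered from a sampled batch; a rough sketch with placeholder ``buffer`` and ``policy`` objects (the training-side usage is an assumption, not part of this diff):

    import torch

    batch, indice = buffer.sample(64)
    # per-transition RNN state written by the Collector above, if any was produced
    hidden = getattr(batch.policy, '_state', None)
    with torch.no_grad():
        result = policy(batch, hidden)  # re-seed the recurrent policy with the stored state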