store RNN hidden states in policy._state and add sample_avail in buffer (#19)
This commit is contained in:
parent
60cfc373f8
commit
e0f4862d01
@ -61,11 +61,13 @@ def test_ignore_obs_next(size=10):
|
||||
|
||||
def test_stack(size=5, bufsize=9, stack_num=4):
|
||||
env = MyTestEnv(size)
|
||||
buf = ReplayBuffer(bufsize, stack_num)
|
||||
buf = ReplayBuffer(bufsize, stack_num=stack_num)
|
||||
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
|
||||
obs = env.reset(1)
|
||||
for i in range(15):
|
||||
obs_next, rew, done, info = env.step(1)
|
||||
buf.add(obs, 1, rew, done, None, info)
|
||||
buf2.add(obs, 1, rew, done, None, info)
|
||||
obs = obs_next
|
||||
if done:
|
||||
obs = env.reset(1)
|
||||
@ -75,6 +77,10 @@ def test_stack(size=5, bufsize=9, stack_num=4):
|
||||
[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
|
||||
[3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]]))
|
||||
print(buf)
|
||||
_, indice = buf2.sample(0)
|
||||
assert indice == [2]
|
||||
_, indice = buf2.sample(1)
|
||||
assert indice.sum() == 2
|
||||
|
||||
|
||||
def test_priortized_replaybuffer(size=32, bufsize=15):
|
||||
|
@ -97,10 +97,11 @@ class Batch:
|
||||
function return 4 arguments, and the last one is ``info``);
|
||||
* ``policy`` the data computed by policy in step :math:`t`;
|
||||
|
||||
:class:`Batch` object can be initialized using wide variety of arguments,
|
||||
starting with the key/value pairs or dictionary, but also list and Numpy
|
||||
arrays of :class:`dict` or Batch instances. In which case, each element
|
||||
is considered as an individual sample and get stacked together:
|
||||
:class:`~tianshou.data.Batch` object can be initialized using wide variety
|
||||
of arguments, starting with the key/value pairs or dictionary, but also
|
||||
list and Numpy arrays of :class:`dict` or Batch instances. In which case,
|
||||
each element is considered as an individual sample and get stacked
|
||||
together:
|
||||
::
|
||||
|
||||
>>> import numpy as np
|
||||
@ -113,9 +114,9 @@ class Batch:
|
||||
),
|
||||
)
|
||||
|
||||
:class:`Batch` has the same API as a native Python :class:`dict`. In this
|
||||
regard, one can access to stored data using string key, or iterate over
|
||||
stored data:
|
||||
:class:`~tianshou.data.Batch` has the same API as a native Python
|
||||
:class:`dict`. In this regard, one can access to stored data using string
|
||||
key, or iterate over stored data:
|
||||
::
|
||||
|
||||
>>> from tianshou.data import Batch
|
||||
@ -128,8 +129,8 @@ class Batch:
|
||||
b: [5, 5]
|
||||
|
||||
|
||||
:class:`Batch` is also reproduce partially the Numpy API for arrays. You
|
||||
can access or iterate over the individual samples, if any:
|
||||
:class:`~tianshou.data.Batch` is also reproduce partially the Numpy API for
|
||||
arrays. You can access or iterate over the individual samples, if any:
|
||||
::
|
||||
|
||||
>>> import numpy as np
|
||||
@ -219,11 +220,12 @@ class Batch:
|
||||
>>> len(data[0])
|
||||
TypeError: Object of type 'Batch' has no len()
|
||||
|
||||
Convenience helpers are available to convert in-place the
|
||||
stored data into Numpy arrays or Torch tensors.
|
||||
Convenience helpers are available to convert in-place the stored data into
|
||||
Numpy arrays or Torch tensors.
|
||||
|
||||
Finally, note that Batch instance are serializable and therefore Pickle
|
||||
compatible. This is especially important for distributed sampling.
|
||||
Finally, note that :class:`~tianshou.data.Batch` instance are serializable
|
||||
and therefore Pickle compatible. This is especially important for
|
||||
distributed sampling.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -1,7 +1,7 @@
|
||||
import numpy as np
|
||||
from typing import Any, Tuple, Union, Optional
|
||||
|
||||
from .batch import Batch, _create_value
|
||||
from tianshou.data.batch import Batch, _create_value
|
||||
|
||||
|
||||
class ReplayBuffer:
|
||||
@ -91,12 +91,27 @@ class ReplayBuffer:
|
||||
[12. 13. 14. 15.]
|
||||
[ 7. 7. 7. 8.]
|
||||
[ 7. 7. 8. 9.]]
|
||||
|
||||
:param int size: the size of replay buffer.
|
||||
:param int stack_num: the frame-stack sampling argument, should be greater
|
||||
than 1, defaults to 0 (no stacking).
|
||||
:param bool ignore_obs_next: whether to store obs_next, defaults to
|
||||
``False``.
|
||||
:param bool sample_avail: the parameter indicating sampling only available
|
||||
index when using frame-stack sampling method, defaults to ``False``.
|
||||
This feature is not supported in Prioritized Replay Buffer currently.
|
||||
"""
|
||||
|
||||
def __init__(self, size: int, stack_num: Optional[int] = 0,
|
||||
ignore_obs_next: bool = False, **kwargs) -> None:
|
||||
ignore_obs_next: bool = False,
|
||||
sample_avail: bool = False, **kwargs) -> None:
|
||||
super().__init__()
|
||||
self._maxsize = size
|
||||
self._stack = stack_num
|
||||
assert stack_num != 1, \
|
||||
'stack_num should greater than 1'
|
||||
self._avail = sample_avail and stack_num > 1
|
||||
self._avail_index = []
|
||||
self._save_s_ = not ignore_obs_next
|
||||
self._index = 0
|
||||
self._size = 0
|
||||
@ -146,7 +161,7 @@ class ReplayBuffer:
|
||||
def add(self,
|
||||
obs: Union[dict, Batch, np.ndarray],
|
||||
act: Union[np.ndarray, float],
|
||||
rew: float,
|
||||
rew: Union[int, float],
|
||||
done: bool,
|
||||
obs_next: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
info: dict = {},
|
||||
@ -165,6 +180,23 @@ class ReplayBuffer:
|
||||
self._add_to_buffer('obs_next', obs_next)
|
||||
self._add_to_buffer('info', info)
|
||||
self._add_to_buffer('policy', policy)
|
||||
|
||||
# maintain available index for frame-stack sampling
|
||||
if self._avail:
|
||||
# update current frame
|
||||
avail = sum(self.done[i] for i in range(
|
||||
self._index - self._stack + 1, self._index)) == 0
|
||||
if self._size < self._stack - 1:
|
||||
avail = False
|
||||
if avail and self._index not in self._avail_index:
|
||||
self._avail_index.append(self._index)
|
||||
elif not avail and self._index in self._avail_index:
|
||||
self._avail_index.remove(self._index)
|
||||
# remove the later available frame because of broken storage
|
||||
t = (self._index + self._stack - 1) % self._maxsize
|
||||
if t in self._avail_index:
|
||||
self._avail_index.remove(t)
|
||||
|
||||
if self._maxsize > 0:
|
||||
self._size = min(self._size + 1, self._maxsize)
|
||||
self._index = (self._index + 1) % self._maxsize
|
||||
@ -175,6 +207,7 @@ class ReplayBuffer:
|
||||
"""Clear all the data in replay buffer."""
|
||||
self._index = 0
|
||||
self._size = 0
|
||||
self._avail_index = []
|
||||
|
||||
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
|
||||
"""Get a random sample from buffer with size equal to batch_size. \
|
||||
@ -183,12 +216,17 @@ class ReplayBuffer:
|
||||
:return: Sample data and its corresponding index inside the buffer.
|
||||
"""
|
||||
if batch_size > 0:
|
||||
indice = np.random.choice(self._size, batch_size)
|
||||
_all = self._avail_index if self._avail else self._size
|
||||
indice = np.random.choice(_all, batch_size)
|
||||
else:
|
||||
indice = np.concatenate([
|
||||
np.arange(self._index, self._size),
|
||||
np.arange(0, self._index),
|
||||
])
|
||||
if self._avail:
|
||||
indice = np.array(self._avail_index)
|
||||
else:
|
||||
indice = np.concatenate([
|
||||
np.arange(self._index, self._size),
|
||||
np.arange(0, self._index),
|
||||
])
|
||||
assert len(indice) > 0, 'No available indice can be sampled.'
|
||||
return self[indice], indice
|
||||
|
||||
def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str,
|
||||
@ -247,11 +285,10 @@ class ReplayBuffer:
|
||||
return Batch(
|
||||
obs=self.get(index, 'obs'),
|
||||
act=self.act[index],
|
||||
# act_=self.get(index, 'act'), # stacked action, for RNN
|
||||
rew=self.rew[index],
|
||||
done=self.done[index],
|
||||
obs_next=self.get(index, 'obs_next'),
|
||||
info=self.get(index, 'info', stack_num=0),
|
||||
info=self.get(index, 'info'),
|
||||
policy=self.get(index, 'policy')
|
||||
)
|
||||
|
||||
@ -317,7 +354,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
def add(self,
|
||||
obs: Union[dict, np.ndarray],
|
||||
act: Union[np.ndarray, float],
|
||||
rew: float,
|
||||
rew: Union[int, float],
|
||||
done: bool,
|
||||
obs_next: Optional[Union[dict, np.ndarray]] = None,
|
||||
info: dict = {},
|
||||
@ -401,11 +438,11 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
- self.weight[indice].sum()
|
||||
self.weight[indice] = np.power(np.abs(new_weight), self._alpha)
|
||||
|
||||
def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch:
|
||||
def __getitem__(self, index: Union[
|
||||
slice, int, np.integer, np.ndarray]) -> Batch:
|
||||
return Batch(
|
||||
obs=self.get(index, 'obs'),
|
||||
act=self.act[index],
|
||||
# act_=self.get(index, 'act'), # stacked action, for RNN
|
||||
rew=self.rew[index],
|
||||
done=self.done[index],
|
||||
obs_next=self.get(index, 'obs_next'),
|
||||
|
@ -200,14 +200,8 @@ class Collector(object):
|
||||
return
|
||||
if isinstance(self.state, list):
|
||||
self.state[id] = None
|
||||
elif isinstance(self.state, (dict, Batch)):
|
||||
for k in self.state.keys():
|
||||
if isinstance(self.state[k], list):
|
||||
self.state[k][id] = None
|
||||
elif isinstance(self.state[k], (torch.Tensor, np.ndarray)):
|
||||
self.state[k][id] = 0
|
||||
elif isinstance(self.state, (torch.Tensor, np.ndarray)):
|
||||
self.state[id] = 0
|
||||
elif isinstance(self.state, (Batch, torch.Tensor, np.ndarray)):
|
||||
self.state[id] *= 0
|
||||
|
||||
def collect(self,
|
||||
n_step: int = 0,
|
||||
@ -272,9 +266,18 @@ class Collector(object):
|
||||
else:
|
||||
with torch.no_grad():
|
||||
result = self.policy(batch, self.state)
|
||||
|
||||
# save hidden state to policy._state, in order to save into buffer
|
||||
self.state = result.get('state', None)
|
||||
self._policy = to_numpy(result.policy) \
|
||||
if hasattr(result, 'policy') else [{}] * self.env_num
|
||||
if hasattr(result, 'policy'):
|
||||
self._policy = to_numpy(result.policy)
|
||||
if self.state is not None:
|
||||
self._policy._state = self.state
|
||||
elif self.state is not None:
|
||||
self._policy = Batch(_state=self.state)
|
||||
else:
|
||||
self._policy = [{}] * self.env_num
|
||||
|
||||
self._act = to_numpy(result.act)
|
||||
if self._action_noise is not None:
|
||||
self._act += self._action_noise(self._act.shape)
|
||||
|
Loading…
x
Reference in New Issue
Block a user