diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index bc374b8..b4eac93 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -61,11 +61,13 @@ def test_ignore_obs_next(size=10):
 
 def test_stack(size=5, bufsize=9, stack_num=4):
     env = MyTestEnv(size)
-    buf = ReplayBuffer(bufsize, stack_num)
+    buf = ReplayBuffer(bufsize, stack_num=stack_num)
+    buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
     obs = env.reset(1)
     for i in range(15):
         obs_next, rew, done, info = env.step(1)
         buf.add(obs, 1, rew, done, None, info)
+        buf2.add(obs, 1, rew, done, None, info)
         obs = obs_next
         if done:
             obs = env.reset(1)
@@ -75,6 +77,10 @@ def test_stack(size=5, bufsize=9, stack_num=4):
         [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], [3, 3, 3, 3],
         [3, 3, 3, 4], [1, 1, 1, 1]]))
     print(buf)
+    _, indice = buf2.sample(0)
+    assert indice == [2]
+    _, indice = buf2.sample(1)
+    assert indice.sum() == 2
 
 
 def test_priortized_replaybuffer(size=32, bufsize=15):
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index 2d9b11b..65c638a 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -97,10 +97,11 @@ class Batch:
       function return 4 arguments, and the last one is ``info``);
     * ``policy`` the data computed by policy in step :math:`t`;
 
-    :class:`Batch` object can be initialized using wide variety of arguments,
-    starting with the key/value pairs or dictionary, but also list and Numpy
-    arrays of :class:`dict` or Batch instances. In which case, each element
-    is considered as an individual sample and get stacked together:
+    :class:`~tianshou.data.Batch` object can be initialized using wide variety
+    of arguments, starting with the key/value pairs or dictionary, but also
+    list and Numpy arrays of :class:`dict` or Batch instances. In which case,
+    each element is considered as an individual sample and get stacked
+    together:
     ::
 
         >>> import numpy as np
@@ -113,9 +114,9 @@ class Batch:
               ),
         )
 
-    :class:`Batch` has the same API as a native Python :class:`dict`. In this
-    regard, one can access to stored data using string key, or iterate over
-    stored data:
+    :class:`~tianshou.data.Batch` has the same API as a native Python
+    :class:`dict`. In this regard, one can access to stored data using string
+    key, or iterate over stored data:
     ::
 
         >>> from tianshou.data import Batch
@@ -128,8 +129,8 @@ class Batch:
         b: [5, 5]
 
 
-    :class:`Batch` is also reproduce partially the Numpy API for arrays. You
-    can access or iterate over the individual samples, if any:
+    :class:`~tianshou.data.Batch` is also reproduce partially the Numpy API for
+    arrays. You can access or iterate over the individual samples, if any:
     ::
 
         >>> import numpy as np
@@ -219,11 +220,12 @@ class Batch:
         >>> len(data[0])
         TypeError: Object of type 'Batch' has no len()
 
-    Convenience helpers are available to convert in-place the
-    stored data into Numpy arrays or Torch tensors.
+    Convenience helpers are available to convert in-place the stored data into
+    Numpy arrays or Torch tensors.
 
-    Finally, note that Batch instance are serializable and therefore Pickle
-    compatible. This is especially important for distributed sampling.
+    Finally, note that :class:`~tianshou.data.Batch` instance are serializable
+    and therefore Pickle compatible. This is especially important for
+    distributed sampling.
     """
 
     def __init__(self,
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index c050fd5..7eefc50 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -1,7 +1,7 @@
 import numpy as np
 from typing import Any, Tuple, Union, Optional
 
-from .batch import Batch, _create_value
+from tianshou.data.batch import Batch, _create_value
 
 
 class ReplayBuffer:
@@ -91,12 +91,27 @@ class ReplayBuffer:
          [12. 13. 14. 15.]
          [ 7.  7.  7.  8.]
          [ 7.  7.  8.  9.]]
+
+    :param int size: the size of replay buffer.
+    :param int stack_num: the frame-stack sampling argument, should be greater
+        than 1, defaults to 0 (no stacking).
+    :param bool ignore_obs_next: whether to store obs_next, defaults to
+        ``False``.
+    :param bool sample_avail: the parameter indicating sampling only available
+        index when using frame-stack sampling method, defaults to ``False``.
+        This feature is not supported in Prioritized Replay Buffer currently.
     """
+
     def __init__(self, size: int, stack_num: Optional[int] = 0,
-                 ignore_obs_next: bool = False, **kwargs) -> None:
+                 ignore_obs_next: bool = False,
+                 sample_avail: bool = False, **kwargs) -> None:
         super().__init__()
         self._maxsize = size
         self._stack = stack_num
+        assert stack_num != 1, \
+            'stack_num should greater than 1'
+        self._avail = sample_avail and stack_num > 1
+        self._avail_index = []
         self._save_s_ = not ignore_obs_next
         self._index = 0
         self._size = 0
@@ -146,7 +161,7 @@ class ReplayBuffer:
     def add(self,
             obs: Union[dict, Batch, np.ndarray],
             act: Union[np.ndarray, float],
-            rew: float,
+            rew: Union[int, float],
             done: bool,
             obs_next: Optional[Union[dict, Batch, np.ndarray]] = None,
             info: dict = {},
@@ -165,6 +180,23 @@ class ReplayBuffer:
         self._add_to_buffer('obs_next', obs_next)
         self._add_to_buffer('info', info)
         self._add_to_buffer('policy', policy)
+
+        # maintain available index for frame-stack sampling
+        if self._avail:
+            # update current frame
+            avail = sum(self.done[i] for i in range(
+                self._index - self._stack + 1, self._index)) == 0
+            if self._size < self._stack - 1:
+                avail = False
+            if avail and self._index not in self._avail_index:
+                self._avail_index.append(self._index)
+            elif not avail and self._index in self._avail_index:
+                self._avail_index.remove(self._index)
+            # remove the later available frame because of broken storage
+            t = (self._index + self._stack - 1) % self._maxsize
+            if t in self._avail_index:
+                self._avail_index.remove(t)
+
         if self._maxsize > 0:
             self._size = min(self._size + 1, self._maxsize)
             self._index = (self._index + 1) % self._maxsize
@@ -175,6 +207,7 @@ class ReplayBuffer:
         """Clear all the data in replay buffer."""
         self._index = 0
         self._size = 0
+        self._avail_index = []
 
     def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
         """Get a random sample from buffer with size equal to batch_size. \
@@ -183,12 +216,17 @@ class ReplayBuffer:
         :return: Sample data and its corresponding index inside the buffer.
         """
         if batch_size > 0:
-            indice = np.random.choice(self._size, batch_size)
+            _all = self._avail_index if self._avail else self._size
+            indice = np.random.choice(_all, batch_size)
         else:
-            indice = np.concatenate([
-                np.arange(self._index, self._size),
-                np.arange(0, self._index),
-            ])
+            if self._avail:
+                indice = np.array(self._avail_index)
+            else:
+                indice = np.concatenate([
+                    np.arange(self._index, self._size),
+                    np.arange(0, self._index),
+                ])
+        assert len(indice) > 0, 'No available indice can be sampled.'
         return self[indice], indice
 
     def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str,
@@ -247,11 +285,10 @@ class ReplayBuffer:
         return Batch(
             obs=self.get(index, 'obs'),
             act=self.act[index],
-            # act_=self.get(index, 'act'),  # stacked action, for RNN
             rew=self.rew[index],
             done=self.done[index],
             obs_next=self.get(index, 'obs_next'),
-            info=self.get(index, 'info', stack_num=0),
+            info=self.get(index, 'info'),
             policy=self.get(index, 'policy')
         )
 
@@ -317,7 +354,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
     def add(self,
             obs: Union[dict, np.ndarray],
             act: Union[np.ndarray, float],
-            rew: float,
+            rew: Union[int, float],
             done: bool,
             obs_next: Optional[Union[dict, np.ndarray]] = None,
             info: dict = {},
@@ -401,11 +438,11 @@ class PrioritizedReplayBuffer(ReplayBuffer):
             - self.weight[indice].sum()
         self.weight[indice] = np.power(np.abs(new_weight), self._alpha)
 
-    def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch:
+    def __getitem__(self, index: Union[
+            slice, int, np.integer, np.ndarray]) -> Batch:
         return Batch(
             obs=self.get(index, 'obs'),
             act=self.act[index],
-            # act_=self.get(index, 'act'),  # stacked action, for RNN
             rew=self.rew[index],
            done=self.done[index],
            obs_next=self.get(index, 'obs_next'),
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index cc6c7f5..4b72308 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -200,14 +200,8 @@ class Collector(object):
             return
         if isinstance(self.state, list):
             self.state[id] = None
-        elif isinstance(self.state, (dict, Batch)):
-            for k in self.state.keys():
-                if isinstance(self.state[k], list):
-                    self.state[k][id] = None
-                elif isinstance(self.state[k], (torch.Tensor, np.ndarray)):
-                    self.state[k][id] = 0
-        elif isinstance(self.state, (torch.Tensor, np.ndarray)):
-            self.state[id] = 0
+        elif isinstance(self.state, (Batch, torch.Tensor, np.ndarray)):
+            self.state[id] *= 0
 
     def collect(self,
                 n_step: int = 0,
@@ -272,9 +266,18 @@ class Collector(object):
             else:
                 with torch.no_grad():
                     result = self.policy(batch, self.state)
+
+            # save hidden state to policy._state, in order to save into buffer
             self.state = result.get('state', None)
-            self._policy = to_numpy(result.policy) \
-                if hasattr(result, 'policy') else [{}] * self.env_num
+            if hasattr(result, 'policy'):
+                self._policy = to_numpy(result.policy)
+                if self.state is not None:
+                    self._policy._state = self.state
+            elif self.state is not None:
+                self._policy = Batch(_state=self.state)
+            else:
+                self._policy = [{}] * self.env_num
+
             self._act = to_numpy(result.act)
             if self._action_noise is not None:
                 self._act += self._action_noise(self._act.shape)
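
Usage note: a minimal sketch of what the new sample_avail switch does, assuming this patch is applied. The nine-step toy loop and the episode boundaries below are invented for illustration; the ReplayBuffer constructor and the add() and sample() calls follow the signatures introduced above.

    import numpy as np
    from tianshou.data import ReplayBuffer

    # stack_num=4 frames per sample; sample_avail=True restricts sampling to
    # indices whose 3 preceding slots belong to the same, unbroken episode.
    buf = ReplayBuffer(size=9, stack_num=4, sample_avail=True)
    done_at = {4, 8}  # two toy episodes: steps 0-4 and 5-8
    for i in range(9):
        buf.add(obs=np.array([i]), act=0, rew=0.0, done=i in done_at)

    batch, indice = buf.sample(0)  # batch_size=0 -> every available index
    print(indice)                  # here that should print [3 4 8]
    batch, indice = buf.sample(2)  # random draw restricted to those indices

Without sample_avail the same sample(0) call returns all nine stored indices; frame-stacked samples taken near an episode start then pad by repeating the episode's first frame, which is what the expectations in test_stack above encode.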
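
The bookkeeping added to add() maintains _avail_index incrementally, including evicting an index whose stack would now span a slot that has just been overwritten. For a buffer that has not wrapped around yet, the same set can be described non-incrementally; the helper below is hypothetical and only meant to make that invariant explicit, it mirrors rather than replaces the patched logic.

    def available_indices(done, stack_num):
        # index i can be frame-stacked iff it has stack_num - 1 predecessors
        # and none of those predecessors ended an episode
        return [i for i in range(len(done))
                if i >= stack_num - 1 and not any(done[i - stack_num + 1:i])]

    # two episodes in a 9-slot buffer, ending at indices 4 and 8, stack_num=4
    done = [False] * 9
    done[4] = done[8] = True
    assert available_indices(done, 4) == [3, 4, 8]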
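
Finally, a schematic of the Collector change for recurrent policies; the shapes and the log_prob key are made up, and this is an editor's sketch rather than code taken from the patch. When the forward pass returns a hidden state, that state is now attached to the per-step policy Batch under the key _state, so it is stored in the buffer's policy field together with the rest of the policy output.

    import numpy as np
    from tianshou.data import Batch

    # pretend the policy forward pass returned this for 4 parallel envs
    state = np.zeros((4, 2))      # hypothetical per-env hidden states
    result = Batch(act=np.zeros(4), state=state,
                   policy=Batch(log_prob=np.zeros(4)))

    _policy = result.policy       # per-step data computed by the policy
    if result.get('state', None) is not None:
        _policy._state = result.state  # hidden state rides along as '_state'
    # _policy is then what the collector stores in the buffer's 'policy' field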