diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index ba8a4d5..5c34352 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -40,7 +40,7 @@ def test_stack(size=5, bufsize=9, stack_num=4):
         if done:
             obs = env.reset(1)
     indice = np.arange(len(buf))
-    assert abs(buf.get_stack(indice, 'obs') - np.array([
+    assert abs(buf.get(indice, 'obs') - np.array([
         [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
         [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
         [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]])).sum() < 1e-6
diff --git a/test/base/test_collector.py b/test/base/test_collector.py
index 483b01b..5328d1d 100644
--- a/test/base/test_collector.py
+++ b/test/base/test_collector.py
@@ -1,7 +1,7 @@
 import numpy as np
 from tianshou.policy import BasePolicy
 from tianshou.env import SubprocVectorEnv
-from tianshou.data import Collector, Batch
+from tianshou.data import Collector, Batch, ReplayBuffer
 
 if __name__ == '__main__':
     from env import MyTestEnv
@@ -36,21 +36,21 @@ def test_collector():
     venv = SubprocVectorEnv(env_fns)
     policy = MyPolicy()
     env = env_fns[0]()
-    c0 = Collector(policy, env)
+    c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False))
     c0.collect(n_step=3)
     assert equal(c0.buffer.obs[:3], [0, 1, 0])
-    assert equal(c0.buffer.obs_next[:3], [1, 2, 1])
+    assert equal(c0.buffer[:3].obs_next, [1, 2, 1])
     c0.collect(n_episode=3)
     assert equal(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
-    assert equal(c0.buffer.obs_next[:8], [1, 2, 1, 2, 1, 2, 1, 2])
-    c1 = Collector(policy, venv)
+    assert equal(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
+    c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
     c1.collect(n_step=6)
     assert equal(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3])
-    assert equal(c1.buffer.obs_next[:11], [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
+    assert equal(c1.buffer[:11].obs_next, [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
     c1.collect(n_episode=2)
     assert equal(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
-    assert equal(c1.buffer.obs_next[11:21], [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
-    c2 = Collector(policy, venv)
+    assert equal(c1.buffer[11:21].obs_next, [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
+    c2 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
     c2.collect(n_episode=[1, 2, 2, 2])
     assert equal(c2.buffer.obs_next[:26], [
         1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,
diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py
index f851d4e..6bbc04b 100644
--- a/test/discrete/test_drqn.py
+++ b/test/discrete/test_drqn.py
@@ -72,7 +72,7 @@ def test_drqn(args=get_args()):
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(
-            args.buffer_size, stack_num=args.stack_num))
+            args.buffer_size, stack_num=args.stack_num, ignore_obs_next=True))
     # the stack_num is for RNN training: sample framestack obs
     test_collector = Collector(policy, test_envs)
     # policy.set_eps(1)
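Note: the test changes above pin down both behaviour changes in this patch:
``get_stack`` is renamed to ``get``, and ``obs_next`` is now read through
``buffer[indice]``, which rebuilds it on demand, instead of from a stored
array. A minimal usage sketch, assuming only the patched API from this diff:
::

    import numpy as np
    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
    for i in range(16):
        # dummy episodes that terminate on every fifth step
        buf.add(obs=i, act=i, rew=i, done=i % 5 == 0, obs_next=i, info={})
    indice = np.arange(len(buf))
    stacked = buf.get(indice, 'obs')   # frame-stacked obs, shape (9, 4)
    # __getitem__ applies the same stacking, so both paths agree
    assert np.allclose(buf[indice].obs, stacked)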
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index c5edbfe..f0e5d9b 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -41,14 +41,15 @@ class ReplayBuffer(object):
         >>> batch_data.obs == buf[indice].obs
         array([ True,  True,  True,  True])
 
-    From version v0.2.2, :class:`~tianshou.data.ReplayBuffer` supports
-    frame_stack sampling, typically for RNN usage:
+    Since version v0.2.2, :class:`~tianshou.data.ReplayBuffer` supports
+    frame_stack sampling (typically for RNN usage) and skipping the storage
+    of the next observation (to save memory):
     ::
 
-        >>> buf = ReplayBuffer(size=9, stack_num=4)
+        >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
         >>> for i in range(16):
         ...     done = i % 5 == 0
-        ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=0, info={})
+        ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=i, info={})
         >>> print(buf)
         ReplayBuffer(
             obs: [ 9. 10. 11. 12. 13. 14. 15.  7.  8.],
@@ -59,7 +60,7 @@ class ReplayBuffer(object):
             info: [{} {} {} {} {} {} {} {} {}],
         )
         >>> index = np.arange(len(buf))
-        >>> print(buf.get_stack(index, 'obs'))
+        >>> print(buf.get(index, 'obs'))
         [[ 7.  7.  8.  9.]
          [ 7.  8.  9. 10.]
          [11. 11. 11. 11.]
@@ -71,14 +72,15 @@ class ReplayBuffer(object):
          [ 7.  7.  7.  8.]]
         >>> # here is another way to get the stacked data
         >>> # (stack only for obs and obs_next)
-        >>> abs(buf.get_stack(index, 'obs') - buf[index].obs).sum().sum()
+        >>> abs(buf.get(index, 'obs') - buf[index].obs).sum().sum()
         0.0
     """
 
-    def __init__(self, size, stack_num=0):
+    def __init__(self, size, stack_num=0, ignore_obs_next=False, **kwargs):
         super().__init__()
         self._maxsize = size
         self._stack = stack_num
+        self._save_s_ = not ignore_obs_next
         self.reset()
 
     def __len__(self):
@@ -125,7 +127,7 @@ class ReplayBuffer(object):
         while True:
             self.add(
                 buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i],
-                None if buffer.obs_next is None else buffer.obs_next[i],
+                buffer.obs_next[i] if self._save_s_ else None,
                 buffer.info[i])
             i = (i + 1) % len(buffer)
             if i == begin:
@@ -139,7 +141,8 @@ class ReplayBuffer(object):
         self._add_to_buffer('act', act)
         self._add_to_buffer('rew', rew)
         self._add_to_buffer('done', done)
-        self._add_to_buffer('obs_next', obs_next)
+        if self._save_s_:
+            self._add_to_buffer('obs_next', obs_next)
         self._add_to_buffer('info', info)
         if self._maxsize > 0:
             self._size = min(self._size + 1, self._maxsize)
@@ -166,19 +169,30 @@ class ReplayBuffer(object):
         ])
         return self[indice], indice
 
-    def get_stack(self, indice, key):
+    def get(self, indice, key):
         """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
         where s is self.key, t is indice. The stack_num (here equals to 4) is
         given from buffer initialization procedure.
         """
-        if self.__dict__.get(key, None) is None:
-            return None
-        if self._stack == 0:
-            return self.__dict__[key][indice]
-        stack = []
+        if not isinstance(indice, np.ndarray):
+            if np.isscalar(indice):
+                indice = np.array(indice)
+            elif isinstance(indice, slice):
+                indice = np.arange(
+                    0 if indice.start is None else indice.start,
+                    self._size if indice.stop is None else indice.stop,
+                    1 if indice.step is None else indice.step)
         # set last frame done to True
         last_index = (self._index - 1 + self._size) % self._size
         last_done, self.done[last_index] = self.done[last_index], True
+        if key == 'obs_next' and not self._save_s_:
+            indice = indice + 1 - self.done[indice].astype(np.int)
+            indice[indice == self._size] = 0
+            key = 'obs'
+        if self._stack == 0:
+            self.done[last_index] = last_done
+            return self.__dict__[key][indice]
+        stack = []
         for i in range(self._stack):
             stack = [self.__dict__[key][indice]] + stack
             pre_indice = indice - 1
@@ -193,11 +207,11 @@ class ReplayBuffer(object):
         return the stacked obs and obs_next with shape [batch, len, ...].
         """
         return Batch(
-            obs=self.get_stack(index, 'obs'),
+            obs=self.get(index, 'obs'),
             act=self.act[index],
             rew=self.rew[index],
             done=self.done[index],
-            obs_next=self.get_stack(index, 'obs_next'),
+            obs_next=self.get(index, 'obs_next'),
             info=self.info[index]
         )
 
@@ -213,8 +227,8 @@ class ListReplayBuffer(ReplayBuffer):
     detailed explanation.
     """
 
-    def __init__(self):
-        super().__init__(size=0)
+    def __init__(self, **kwargs):
+        super().__init__(size=0, ignore_obs_next=False, **kwargs)
 
     def _add_to_buffer(self, name, inst):
         if inst is None:
@@ -233,8 +247,8 @@
 class PrioritizedReplayBuffer(ReplayBuffer):
     """docstring for PrioritizedReplayBuffer"""
 
-    def __init__(self, size):
-        super().__init__(size)
+    def __init__(self, size, **kwargs):
+        super().__init__(size, **kwargs)
 
     def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
         raise NotImplementedError
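Note: the reconstruction in ``get`` rests on a single invariant of sequential
storage: for a non-terminal step ``t``, ``obs_next[t]`` is exactly
``obs[t + 1]``, while a terminal step repeats ``obs[t]`` (reading ``t + 1``
would leak the reset observation of the next episode, and the value is masked
by ``done`` during learning anyway). ``get`` also temporarily marks the most
recently written frame as done, so the lookahead can never cross the write
head. A standalone sketch of the index arithmetic (the helper
``next_indices`` is hypothetical, not part of tianshou):
::

    import numpy as np

    def next_indices(indice, done, size):
        # hypothetical helper: map step t to where obs_next[t] lives in obs
        indice = np.asarray(indice)
        nxt = indice + 1 - done[indice].astype(int)  # stay put on terminals
        nxt[nxt == size] = 0  # wrap around the circular buffer
        return nxt

    obs = np.array([0, 1, 2, 3, 4])
    done = np.array([0, 0, 0, 0, 1])
    print(obs[next_indices(np.arange(5), done, 5)])  # [1 2 3 4 4]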
""" - def __init__(self): - super().__init__(size=0) + def __init__(self, **kwargs): + super().__init__(size=0, ignore_obs_next=False, **kwargs) def _add_to_buffer(self, name, inst): if inst is None: @@ -233,8 +247,8 @@ class ListReplayBuffer(ReplayBuffer): class PrioritizedReplayBuffer(ReplayBuffer): """docstring for PrioritizedReplayBuffer""" - def __init__(self, size): - super().__init__(size) + def __init__(self, size, **kwargs): + super().__init__(size, **kwargs) def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None): raise NotImplementedError diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 43c5103..8849921 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -22,8 +22,6 @@ class Collector(object): :class:`~tianshou.data.ReplayBuffer`. :param int stat_size: for the moving average of recording speed, defaults to 100. - :param bool store_obs_next: store the next observation to replay buffer or - not, defaults to ``True``. Example: :: @@ -70,8 +68,7 @@ class Collector(object): Please make sure the given environment has a time limitation. """ - def __init__(self, policy, env, buffer=None, stat_size=100, - store_obs_next=True, **kwargs): + def __init__(self, policy, env, buffer=None, stat_size=100, **kwargs): super().__init__() self.env = env self.env_num = 1 @@ -106,7 +103,6 @@ class Collector(object): self.state = None self.step_speed = MovAvg(stat_size) self.episode_speed = MovAvg(stat_size) - self._save_s_ = store_obs_next def reset_buffer(self): """Reset the main data buffer.""" @@ -247,8 +243,7 @@ class Collector(object): data = { 'obs': self._obs[i], 'act': self._act[i], 'rew': self._rew[i], 'done': self._done[i], - 'obs_next': obs_next[i] if self._save_s_ else None, - 'info': self._info[i]} + 'obs_next': obs_next[i], 'info': self._info[i]} if self._cached_buf: warning_count += 1 self._cached_buf[i].add(**data) @@ -284,8 +279,7 @@ class Collector(object): else: self.buffer.add( self._obs, self._act[0], self._rew, - self._done, obs_next if self._save_s_ else None, - self._info) + self._done, obs_next, self._info) cur_step += 1 if self._done: cur_episode += 1 diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index b6e7bbf..08e2341 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -81,16 +81,17 @@ class DQNPolicy(BasePolicy): returns[buffer.done[now] > 0] = 0 returns = buffer.rew[now] + self._gamma * returns terminal = (indice + self._n_step - 1) % len(buffer) + terminal_data = buffer[terminal] if self._target: # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - a = self(buffer[terminal], input='obs_next', eps=0).act + a = self(terminal_data, input='obs_next', eps=0).act target_q = self( - buffer[terminal], model='model_old', input='obs_next').logits + terminal_data, model='model_old', input='obs_next').logits if isinstance(target_q, torch.Tensor): target_q = target_q.detach().cpu().numpy() target_q = target_q[np.arange(len(a)), a] else: - target_q = self(buffer[terminal], input='obs_next').logits + target_q = self(terminal_data, input='obs_next').logits if isinstance(target_q, torch.Tensor): target_q = target_q.detach().cpu().numpy() target_q = target_q.max(axis=1) diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 9f82d5a..75f3023 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -40,8 +40,6 @@ class PGPolicy(BasePolicy): discount factor, :math:`\gamma \in [0, 1]`. 
""" batch.returns = self._vanilla_returns(batch) - if getattr(batch, 'obs_next', None) is None: - batch.obs_next = buffer[(indice + 1) % len(buffer)].obs # batch.returns = self._vectorized_returns(batch) return batch