diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index 10a1551..33d166e 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -11,7 +11,6 @@ jobs:
     strategy:
       matrix:
         python-version: [3.6, 3.7, 3.8]
-
     steps:
     - uses: actions/checkout@v2
     - name: Set up Python ${{ matrix.python-version }}
diff --git a/test/base/env.py b/test/base/env.py
index d2fd671..5a2294b 100644
--- a/test/base/env.py
+++ b/test/base/env.py
@@ -8,9 +8,9 @@ class MyTestEnv(gym.Env):
         self.sleep = sleep
         self.reset()

-    def reset(self):
+    def reset(self, state=0):
         self.done = False
-        self.index = 0
+        self.index = state
         return self.index

     def step(self, action):
diff --git a/test/base/test_batch.py b/test/base/test_batch.py
index d5b22ca..53a2cdc 100644
--- a/test/base/test_batch.py
+++ b/test/base/test_batch.py
@@ -17,6 +17,7 @@ def test_batch():
     batch.obs = np.arange(5)
     for i, b in enumerate(batch.split(1, permute=False)):
         assert b.obs == batch[i].obs
+    print(batch)


 if __name__ == '__main__':
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index 38d3bc5..ba8a4d5 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -1,3 +1,4 @@
+import numpy as np
 from tianshou.data import ReplayBuffer

 if __name__ == '__main__':
@@ -28,5 +29,24 @@
     assert buf2[-1].obs == buf[4].obs


+def test_stack(size=5, bufsize=9, stack_num=4):
+    env = MyTestEnv(size)
+    buf = ReplayBuffer(bufsize, stack_num)
+    obs = env.reset(1)
+    for i in range(15):
+        obs_next, rew, done, info = env.step(1)
+        buf.add(obs, 1, rew, done, None, info)
+        obs = obs_next
+        if done:
+            obs = env.reset(1)
+    indice = np.arange(len(buf))
+    assert abs(buf.get_stack(indice, 'obs') - np.array([
+        [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
+        [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
+        [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]])).sum() < 1e-6
+    print(buf)
+
+
 if __name__ == '__main__':
     test_replaybuffer()
+    test_stack()
diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py
index 76a6106..f851d4e 100644
--- a/test/discrete/test_drqn.py
+++ b/test/discrete/test_drqn.py
@@ -26,7 +26,7 @@ def get_args():
     parser.add_argument('--stack-num', type=int, default=4)
     parser.add_argument('--lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.9)
-    parser.add_argument('--n-step', type=int, default=3)
+    parser.add_argument('--n-step', type=int, default=4)
     parser.add_argument('--target-update-freq', type=int, default=320)
     parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=1000)
diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py
index 5bcdeab..081f48a 100644
--- a/test/discrete/test_pg.py
+++ b/test/discrete/test_pg.py
@@ -31,27 +31,29 @@ def compute_return_base(batch, aa=None, bb=None, gamma=0.1):
 def test_fn(size=2560):
     policy = PGPolicy(None, None, None, discount_factor=0.1)
+    buf = ReplayBuffer(100)
+    buf.add(1, 1, 1, 1, 1)
     fn = policy.process_fn
     # fn = compute_return_base
     batch = Batch(
         done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
         rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
     )
-    batch = fn(batch, None, None)
+    batch = fn(batch, buf, 0)
     ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
     assert abs(batch.returns - ans).sum() <= 1e-5
     batch = Batch(
         done=np.array([0, 1, 0, 1, 0, 1, 0.]),
         rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
     )
-    batch = fn(batch, None, None)
+    batch = fn(batch, buf, 0)
     ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
     assert abs(batch.returns - ans).sum() <= 1e-5
     batch = Batch(
         done=np.array([0, 1, 0, 1, 0, 0, 1.]),
         rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
     )
-    batch = fn(batch, None, None)
+    batch = fn(batch, buf, 0)
     ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
     assert abs(batch.returns - ans).sum() <= 1e-5

     if __name__ == '__main__':
@@ -66,7 +68,7 @@ def test_fn(size=2560):
         print(f'vanilla: {(time.time() - t) / cnt}')
         t = time.time()
         for _ in range(cnt):
-            policy.process_fn(batch, None, None)
+            policy.process_fn(batch, buf, 0)
         print(f'policy: {(time.time() - t) / cnt}')


@@ -147,5 +149,5 @@ def test_pg(args=get_args()):


 if __name__ == '__main__':
-    # test_fn()
+    test_fn()
     test_pg()
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index 301de9a..376e7d2 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -73,6 +73,22 @@ class Batch(object):
             b.__dict__.update(**{k: self.__dict__[k][index]})
         return b

+    def __repr__(self):
+        """Return str(self)."""
+        s = self.__class__.__name__ + '(\n'
+        flag = False
+        for k in self.__dict__.keys():
+            if k[0] != '_' and self.__dict__[k] is not None:
+                rpl = '\n' + ' ' * (6 + len(k))
+                obj = str(self.__dict__[k]).replace('\n', rpl)
+                s += f'    {k}: {obj},\n'
+                flag = True
+        if flag:
+            s += ')\n'
+        else:
+            s = self.__class__.__name__ + '()\n'
+        return s
+
     def append(self, batch):
         """Append a :class:`~tianshou.data.Batch` object to current batch."""
         assert isinstance(batch, Batch), 'Only append Batch is allowed!'
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 91fc9c3..1146cd8 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -39,6 +39,34 @@ class ReplayBuffer(object):
         >>> batch_data, indice = buf.sample(batch_size=4)
         >>> batch_data.obs == buf[indice].obs
         array([ True,  True,  True,  True])
+
+    From version v0.2.2, :class:`~tianshou.data.ReplayBuffer` supports
+    frame_stack sampling, typically for RNN usage:
+    ::
+
+        >>> buf = ReplayBuffer(size=9, stack_num=4)
+        >>> for i in range(16):
+        ...     done = i % 5 == 0
+        ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=0, info={})
+        >>> print(buf.obs)
+        [ 9. 10. 11. 12. 13. 14. 15.  7.  8.]
+        >>> print(buf.done)
+        [0. 1. 0. 0. 0. 0. 1. 0. 0.]
+        >>> index = np.arange(len(buf))
+        >>> print(buf.get_stack(index, 'obs'))
+        [[ 7.  7.  8.  9.]
+         [ 7.  8.  9. 10.]
+         [11. 11. 11. 11.]
+         [11. 11. 11. 12.]
+         [11. 11. 12. 13.]
+         [11. 12. 13. 14.]
+         [12. 13. 14. 15.]
+         [ 7.  7.  7.  7.]
+         [ 7.  7.  7.  8.]]
+        >>> # here is another way to get the stacked data
+        >>> # (stack only for obs and obs_next)
+        >>> sum(sum(buf.get_stack(index, 'obs') - buf[index].obs))
+        0.0
     """

     def __init__(self, size, stack_num=0):
@@ -51,8 +79,26 @@ class ReplayBuffer(object):
         """Return len(self)."""
         return self._size

+    def __repr__(self):
+        """Return str(self)."""
+        s = self.__class__.__name__ + '(\n'
+        flag = False
+        for k in self.__dict__.keys():
+            if k[0] != '_' and self.__dict__[k] is not None:
+                rpl = '\n' + ' ' * (6 + len(k))
+                obj = str(self.__dict__[k]).replace('\n', rpl)
+                s += f'    {k}: {obj},\n'
+                flag = True
+        if flag:
+            s += ')\n'
+        else:
+            s = self.__class__.__name__ + '()\n'
+        return s
+
     def _add_to_buffer(self, name, inst):
         if inst is None:
+            if getattr(self, name, None) is None:
+                self.__dict__[name] = None
             return
         if self.__dict__.get(name, None) is None:
             if isinstance(inst, np.ndarray):
@@ -72,13 +118,14 @@ class ReplayBuffer(object):
         i = begin = buffer._index % len(buffer)
         while True:
             self.add(
-                buffer.obs[i], buffer.act[i], buffer.rew[i],
-                buffer.done[i], buffer.obs_next[i], buffer.info[i])
+                buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i],
+                None if buffer.obs_next is None else buffer.obs_next[i],
+                buffer.info[i])
             i = (i + 1) % len(buffer)
             if i == begin:
                 break

-    def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
+    def add(self, obs, act, rew, done, obs_next=None, info={}, weight=None):
         """Add a batch of data into replay buffer."""
         assert isinstance(info, dict), \
             'You should return a dict in the last argument of env.step().'
@@ -97,7 +144,6 @@ class ReplayBuffer(object):
     def reset(self):
         """Clear all the data in replay buffer."""
         self._index = self._size = 0
-        self.indice = []

     def sample(self, batch_size):
         """Get a random sample from buffer with size equal to batch_size. \
@@ -114,16 +160,26 @@ class ReplayBuffer(object):
             ])
         return self[indice], indice

-    def _get_stack(self, indice, key):
+    def get_stack(self, indice, key):
+        """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
+        where s is self.key, t is indice. The stack_num (here equals to 4) is
+        given from buffer initialization procedure.
+        """
         if self.__dict__.get(key, None) is None:
             return None
         if self._stack == 0:
             return self.__dict__[key][indice]
         stack = []
+        # set last frame done to True
+        last_index = (self._index - 1 + self._size) % self._size
+        last_done, self.done[last_index] = self.done[last_index], True
         for i in range(self._stack):
             stack = [self.__dict__[key][indice]] + stack
-            indice = indice - 1 + self.done[indice - 1].astype(np.int)
-            indice[indice == -1] = self._size - 1
+            pre_indice = indice - 1
+            pre_indice[pre_indice == -1] = self._size - 1
+            indice = pre_indice + self.done[pre_indice].astype(np.int)
+            indice[indice == self._size] = 0
+        self.done[last_index] = last_done
         return np.stack(stack, axis=1)

     def __getitem__(self, index):
@@ -131,11 +187,11 @@
         """Return a data batch: self[index]. If stack_num is set to be > 0,
         return the stacked obs and obs_next with shape [batch, len, ...].
""" return Batch( - obs=self._get_stack(index, 'obs'), + obs=self.get_stack(index, 'obs'), act=self.act[index], rew=self.rew[index], done=self.done[index], - obs_next=self._get_stack(index, 'obs_next'), + obs_next=self.get_stack(index, 'obs_next'), info=self.info[index] ) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index ca1bd19..43c5103 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -2,10 +2,10 @@ import time import torch import warnings import numpy as np -from tianshou.env import BaseVectorEnv -from tianshou.data import Batch, ReplayBuffer, \ - ListReplayBuffer + from tianshou.utils import MovAvg +from tianshou.env import BaseVectorEnv +from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer class Collector(object): @@ -22,8 +22,8 @@ class Collector(object): :class:`~tianshou.data.ReplayBuffer`. :param int stat_size: for the moving average of recording speed, defaults to 100. - :param bool store_obs_next: whether to store the obs_next to replay - buffer, defaults to ``True``. + :param bool store_obs_next: store the next observation to replay buffer or + not, defaults to ``True``. Example: :: @@ -302,7 +302,7 @@ class Collector(object): self._obs = obs_next if self._multi_env: cur_episode = sum(cur_episode) - duration = time.time() - start_time + duration = max(time.time() - start_time, 1e-9) self.step_speed.add(cur_step / duration) self.episode_speed.add(cur_episode / duration) self.collect_step += cur_step diff --git a/tianshou/policy/pg.py b/tianshou/policy/pg.py index 47b397d..63c05c7 100644 --- a/tianshou/policy/pg.py +++ b/tianshou/policy/pg.py @@ -35,6 +35,8 @@ class PGPolicy(BasePolicy): discount factor, :math:`\gamma \in [0, 1]`. """ batch.returns = self._vanilla_returns(batch) + if getattr(batch, 'obs_next', None) is None: + batch.obs_next = buffer[(indice + 1) % len(buffer)].obs # batch.returns = self._vectorized_returns(batch) return batch