From 134f787e24d4fdb562d6c1ab81160d72db7a14ae Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Wed, 29 Apr 2020 17:48:48 +0800
Subject: [PATCH] reserve 'policy' keyword in replay buffer

---
 tianshou/data/batch.py           | 11 +++++++++-
 tianshou/data/buffer.py          | 36 ++++++++++++++++++--------------
 tianshou/data/collector.py       | 24 ++++++++++++++++++---
 tianshou/policy/base.py          | 11 ++++++++++
 tianshou/policy/modelfree/a2c.py |  3 +--
 5 files changed, 63 insertions(+), 22 deletions(-)

diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index 728c827..88169f9 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -24,7 +24,7 @@ class Batch(object):
     )
 
     In short, you can define a :class:`Batch` with any key-value pair. The
-    current implementation of Tianshou typically use 6 reserved keys in
+    current implementation of Tianshou typically uses 7 reserved keys in
     :class:`~tianshou.data.Batch`:
 
     * ``obs`` the observation of step :math:`t` ;
@@ -34,6 +34,7 @@ class Batch(object):
     * ``obs_next`` the observation of step :math:`t+1` ;
     * ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()``\
         function return 4 arguments, and the last one is ``info``);
+    * ``policy`` the data computed by the policy at step :math:`t`;
 
     :class:`~tianshou.data.Batch` has other methods, including
     :meth:`~tianshou.data.Batch.__getitem__`,
@@ -129,6 +130,14 @@ class Batch(object):
         return sorted([i for i in self.__dict__ if i[0] != '_'] +
                       list(self._meta))
 
+    def to_numpy(self):
+        """Change all torch.Tensor to numpy.ndarray. This is an in-place
+        operation.
+        """
+        for k in self.__dict__:
+            if isinstance(self.__dict__[k], torch.Tensor):
+                self.__dict__[k] = self.__dict__[k].cpu().numpy()
+
     def append(self, batch):
         """Append a :class:`~tianshou.data.Batch` object to current batch."""
         assert isinstance(batch, Batch), 'Only append Batch is allowed!'
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 511d440..89e15ad 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -42,26 +42,30 @@ class ReplayBuffer(object):
         >>> batch_data.obs == buf[indice].obs
         array([ True,  True,  True,  True])
 
-    Since version v0.2.2, :class:`~tianshou.data.ReplayBuffer` supports
-    frame_stack sampling (typically for RNN usage) and ignoring storing the
-    next observation (save memory):
+    :class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
+    (typically for RNN usage, see issue#19), ignoring storing the next
+    observation (to save memory in Atari tasks), and multi-modal observations
+    (see issue#38, requires version >= 0.2.3):
     ::
 
         >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
        >>> for i in range(16):
         ...     done = i % 5 == 0
-        ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1)
-        >>> print(buf)
+        ...     buf.add(obs={'id': i}, act=i, rew=i, done=done,
+        ...             obs_next={'id': i + 1})
+        >>> print(buf)  # you can see obs_next is not saved in buf
         ReplayBuffer(
-            obs: [ 9. 10. 11. 12. 13. 14. 15.  7.  8.],
-            act: [ 9. 10. 11. 12. 13. 14. 15.  7.  8.],
-            rew: [ 9. 10. 11. 12. 13. 14. 15.  7.  8.],
-            done: [0. 1. 0. 0. 0. 0. 1. 0. 0.],
-            obs_next: [0. 0. 0. 0. 0. 0. 0. 0. 0.],
-            info: [{} {} {} {} {} {} {} {} {}],
+            act: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
+            done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
+            info: array([{}, {}, {}, {}, {}, {}, {}, {}, {}], dtype=object),
+            obs: Batch(
+                id: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
+            ),
+            policy: Batch(),
+            rew: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
         )
         >>> index = np.arange(len(buf))
-        >>> print(buf.get(index, 'obs'))
+        >>> print(buf.get(index, 'obs').id)
         [[ 7.  7.  8.  9.]
          [ 7.  8.  9. 10.]
          [11. 11. 11. 11.]
@@ -73,10 +77,10 @@ class ReplayBuffer(object):
          [ 7.  7.  7.  8.]]
         >>> # here is another way to get the stacked data
         >>> # (stack only for obs and obs_next)
-        >>> abs(buf.get(index, 'obs') - buf[index].obs).sum().sum()
+        >>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
         0.0
-        >>> # we can get obs_next through __getitem__, even if it doesn't store
-        >>> print(buf[:].obs_next)
+        >>> # we can get obs_next through __getitem__, even if it doesn't exist
+        >>> print(buf[:].obs_next.id)
         [[ 7.  8.  9. 10.]
          [ 7.  8.  9. 10.]
          [11. 11. 11. 12.]
@@ -169,7 +173,7 @@ class ReplayBuffer(object):
                 self.add(
                     buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i],
                     buffer.obs_next[i] if self._save_s_ else None,
-                    buffer.info[i])
+                    buffer.info[i], buffer.policy[i])
                 i = (i + 1) % len(buffer)
                 if i == begin:
                     break
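
Taken together, the two hunks above can be sanity-checked in a few lines.
The snippet below is a minimal sketch, not part of the patch, assuming it is
applied on top of tianshou v0.2.3; the ``log_prob`` key is just an example
payload:
::

    import numpy as np
    import torch
    from tianshou.data import Batch, ReplayBuffer

    # Batch.to_numpy() converts tensor attributes in place
    b = Batch(log_prob=torch.zeros(3))
    b.to_numpy()
    assert isinstance(b.log_prob, np.ndarray)

    # 'policy' is now a reserved, first-class argument of ReplayBuffer.add
    buf = ReplayBuffer(size=4)
    buf.add(obs=1, act=1, rew=1., done=False, obs_next=2,
            info={}, policy={'log_prob': -0.5})
    print(buf)  # the repr should now contain a 'policy' entry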
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 92a5d99..9b5478b 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -175,6 +175,20 @@ class Collector(object):
                     isinstance(self.state, np.ndarray):
                 self.state[id] = 0
 
+    def _to_numpy(self, x):
+        """Return an object without torch.Tensor."""
+        if isinstance(x, torch.Tensor):
+            return x.cpu().numpy()
+        elif isinstance(x, dict):
+            for k in x:
+                if isinstance(x[k], torch.Tensor):
+                    x[k] = x[k].cpu().numpy()
+            return x
+        elif isinstance(x, Batch):
+            x.to_numpy()
+            return x
+        return x
+
     def collect(self, n_step=0, n_episode=0, render=None, log_fn=None):
         """Collect a specified number of step or episode.
 
@@ -232,8 +246,11 @@ class Collector(object):
             with torch.no_grad():
                 result = self.policy(batch_data, self.state)
             self.state = result.state if hasattr(result, 'state') else None
+            self._policy = self._to_numpy(result.policy) \
+                if hasattr(result, 'policy') \
+                else [{}] * self.env_num if self._multi_env else {}
             if isinstance(result.act, torch.Tensor):
-                self._act = result.act.detach().cpu().numpy()
+                self._act = self._to_numpy(result.act)
             elif not isinstance(self._act, np.ndarray):
                 self._act = np.array(result.act)
             else:
@@ -253,7 +270,8 @@ class Collector(object):
                 data = {
                     'obs': self._obs[i], 'act': self._act[i],
                     'rew': self._rew[i], 'done': self._done[i],
-                    'obs_next': obs_next[i], 'info': self._info[i]}
+                    'obs_next': obs_next[i], 'info': self._info[i],
+                    'policy': self._policy[i]}
                 if self._cached_buf:
                     warning_count += 1
                     self._cached_buf[i].add(**data)
@@ -292,7 +310,7 @@ class Collector(object):
             if self.buffer is not None:
                 self.buffer.add(
                     self._obs, self._act[0], self._rew,
-                    self._done, obs_next, self._info)
+                    self._done, obs_next, self._info, self._policy)
             cur_step += 1
             if self._done:
                 cur_episode += 1
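
Since ``_to_numpy`` is the one piece of the collector change that runs
standalone, here is a mirror of it (``self`` dropped, purely for
illustration), exercising the two conversions the collector performs on
``result.act`` and ``result.policy``:
::

    import numpy as np
    import torch
    from tianshou.data import Batch

    def to_numpy(x):
        """Mirror of Collector._to_numpy: strip torch.Tensor from x."""
        if isinstance(x, torch.Tensor):
            return x.cpu().numpy()
        elif isinstance(x, dict):
            for k in x:
                if isinstance(x[k], torch.Tensor):
                    x[k] = x[k].cpu().numpy()
            return x
        elif isinstance(x, Batch):
            x.to_numpy()  # in-place, see the batch.py hunk
            return x
        return x

    act = to_numpy(torch.tensor([0, 1]))                # tensor -> ndarray
    pol = to_numpy({'log_prob': torch.tensor([-0.3])})  # dict values -> ndarray
    assert isinstance(act, np.ndarray)
    assert isinstance(pol['log_prob'], np.ndarray)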
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index a11a04c..ff0b599 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -65,6 +65,17 @@ class BasePolicy(ABC, nn.Module):
 
             # some code
             return Batch(logits=..., act=..., state=None, dist=...)
+
+        After version >= 0.2.3, the keyword "policy" is reserved and the
+        corresponding data will be stored into the replay buffer as numpy.
+        For instance,
+        ::
+
+            # some code
+            return Batch(..., policy=Batch(log_prob=dist.log_prob(act)))
+            # and in the sampled data batch, you can directly call
+            # batch.policy.log_prob to get your data, although it is stored
+            # as np.ndarray.
         """
         pass
 
diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py
index 9a50712..b954cfb 100644
--- a/tianshou/policy/modelfree/a2c.py
+++ b/tianshou/policy/modelfree/a2c.py
@@ -92,8 +92,7 @@ class A2CPolicy(PGPolicy):
         for _ in range(repeat):
             for b in batch.split(batch_size):
                 self.optim.zero_grad()
-                result = self(b)
-                dist = result.dist
+                dist = self(b).dist
                 v = self.critic(b.obs)
                 a = torch.tensor(b.act, device=v.device)
                 r = torch.tensor(b.returns, device=v.device)
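
Finally, the round trip described in the base-policy docstring, end to end:
a hypothetical ``MyPolicy`` (not part of this patch) records the
log-probability of each action under the reserved ``policy`` key at collect
time and reads it back in ``learn``; the random logits merely stand in for a
real network:
::

    import torch
    from tianshou.data import Batch
    from tianshou.policy import BasePolicy

    class MyPolicy(BasePolicy):
        def forward(self, batch, state=None, **kwargs):
            logits = torch.randn(len(batch.obs), 2)  # stand-in for a network
            dist = torch.distributions.Categorical(logits=logits)
            act = dist.sample()
            # everything under 'policy' reaches the buffer as numpy
            return Batch(act=act, state=None,
                         policy=Batch(log_prob=dist.log_prob(act)))

        def learn(self, batch, **kwargs):
            # written at collect time, sampled back as np.ndarray
            old_log_prob = batch.policy.log_prob
            return {'log_prob': old_log_prob.mean()}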