reserve 'policy' keyword in replay buffer

Trinkle23897 2020-04-29 17:48:48 +08:00
parent e58fc78546
commit 134f787e24
5 changed files with 63 additions and 22 deletions

@@ -24,7 +24,7 @@ class Batch(object):
)
In short, you can define a :class:`Batch` with any key-value pair. The
- current implementation of Tianshou typically use 6 reserved keys in
+ current implementation of Tianshou typically uses 7 reserved keys in
:class:`~tianshou.data.Batch`:
* ``obs`` the observation of step :math:`t` ;
@@ -34,6 +34,7 @@ class Batch(object):
* ``obs_next`` the observation of step :math:`t+1` ;
* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()``\
function returns 4 arguments, and the last one is ``info``);
+ * ``policy`` the data computed by policy in step :math:`t`;
:class:`~tianshou.data.Batch` has other methods, including
:meth:`~tianshou.data.Batch.__getitem__`,
@@ -129,6 +130,14 @@ class Batch(object):
return sorted([i for i in self.__dict__ if i[0] != '_'] +
list(self._meta))
+ def to_numpy(self):
+     """Change all torch.Tensor to numpy.ndarray. This is an inplace
+     operation.
+     """
+     for k in self.__dict__:
+         if isinstance(self.__dict__[k], torch.Tensor):
+             self.__dict__[k] = self.__dict__[k].cpu().numpy()
def append(self, batch):
"""Append a :class:`~tianshou.data.Batch` object to current batch."""
assert isinstance(batch, Batch), 'Only append Batch is allowed!'
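For reference, a minimal sketch of how the seven reserved keys and the new ``to_numpy`` helper fit together; the values below are made up for illustration, and only the key names and the in-place conversion come from the docstring and method above:
::

    import numpy as np
    import torch
    from tianshou.data import Batch

    # the 7 reserved keys; 'policy' is the one reserved by this commit
    step = Batch(obs=np.array([0]), act=np.array([1]), rew=np.array([0.5]),
                 done=np.array([False]), obs_next=np.array([1]),
                 info=np.array([{}], dtype=object),
                 policy=Batch(log_prob=np.array([-0.7])))

    # to_numpy() converts every torch.Tensor member to numpy.ndarray in place
    out = Batch(act=torch.tensor([1, 0]), logits=torch.zeros(2, 4))
    out.to_numpy()
    print(type(out.act))  # <class 'numpy.ndarray'>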

@@ -42,26 +42,30 @@ class ReplayBuffer(object):
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
- Since version v0.2.2, :class:`~tianshou.data.ReplayBuffer` supports
- frame_stack sampling (typically for RNN usage) and ignoring storing the
- next observation (save memory):
+ :class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
+ (typically for RNN usage, see issue#19), not storing the next observation
+ (to save memory in Atari tasks), and multi-modal observation (see
+ issue#38, needs version >= 0.2.3):
::
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
>>> for i in range(16):
... done = i % 5 == 0
- ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1)
- >>> print(buf)
+ ...     buf.add(obs={'id': i}, act=i, rew=i, done=done,
+ ...             obs_next={'id': i + 1})
+ >>> print(buf)  # you can see obs_next is not saved in buf
ReplayBuffer(
-     obs: [ 9. 10. 11. 12. 13. 14. 15. 7. 8.],
-     act: [ 9. 10. 11. 12. 13. 14. 15. 7. 8.],
-     rew: [ 9. 10. 11. 12. 13. 14. 15. 7. 8.],
-     done: [0. 1. 0. 0. 0. 0. 1. 0. 0.],
-     obs_next: [0. 0. 0. 0. 0. 0. 0. 0. 0.],
-     info: [{} {} {} {} {} {} {} {} {}],
+     act: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
+     done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
+     info: array([{}, {}, {}, {}, {}, {}, {}, {}, {}], dtype=object),
+     obs: Batch(
+         id: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
+     ),
+     policy: Batch(),
+     rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
)
>>> index = np.arange(len(buf))
- >>> print(buf.get(index, 'obs'))
+ >>> print(buf.get(index, 'obs').id)
[[ 7. 7. 8. 9.]
[ 7. 8. 9. 10.]
[11. 11. 11. 11.]
@@ -73,10 +77,10 @@ class ReplayBuffer(object):
[ 7. 7. 7. 8.]]
>>> # here is another way to get the stacked data
>>> # (stack only for obs and obs_next)
- >>> abs(buf.get(index, 'obs') - buf[index].obs).sum().sum()
+ >>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
0.0
- >>> # we can get obs_next through __getitem__, even if it doesn't store
- >>> print(buf[:].obs_next)
+ >>> # we can get obs_next through __getitem__, even if it doesn't exist
+ >>> print(buf[:].obs_next.id)
[[ 7. 8. 9. 10.]
[ 7. 8. 9. 10.]
[11. 11. 11. 12.]
@@ -169,7 +173,7 @@ class ReplayBuffer(object):
self.add(
buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i],
buffer.obs_next[i] if self._save_s_ else None,
- buffer.info[i])
+ buffer.info[i], buffer.policy[i])
i = (i + 1) % len(buffer)
if i == begin:
break
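For completeness, a short sketch of the new key going through the buffer; the keyword form of ``add`` is inferred from the doctest and the ``update`` call above, so treat the exact signature as an assumption:
::

    import numpy as np
    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=8)
    for i in range(4):
        # 'policy' is stored alongside the other reserved keys
        buf.add(obs=i, act=i % 2, rew=float(i), done=(i == 3),
                obs_next=i + 1, info={}, policy={'log_prob': -0.1 * i})

    batch_data, indice = buf.sample(batch_size=2)
    print(batch_data.policy.log_prob)  # per-step policy data, back as numpy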

@@ -175,6 +175,20 @@ class Collector(object):
isinstance(self.state, np.ndarray):
self.state[id] = 0
+ def _to_numpy(self, x):
+     """Return an object without torch.Tensor."""
+     if isinstance(x, torch.Tensor):
+         return x.cpu().numpy()
+     elif isinstance(x, dict):
+         for k in x:
+             if isinstance(x[k], torch.Tensor):
+                 x[k] = x[k].cpu().numpy()
+         return x
+     elif isinstance(x, Batch):
+         x.to_numpy()
+         return x
+     return x
def collect(self, n_step=0, n_episode=0, render=None, log_fn=None):
"""Collect a specified number of step or episode.
@@ -232,8 +246,11 @@ class Collector(object):
with torch.no_grad():
result = self.policy(batch_data, self.state)
self.state = result.state if hasattr(result, 'state') else None
+ self._policy = self._to_numpy(result.policy) \
+     if hasattr(result, 'policy') \
+     else [{}] * self.env_num if self._multi_env else {}
if isinstance(result.act, torch.Tensor):
-     self._act = result.act.detach().cpu().numpy()
+     self._act = self._to_numpy(result.act)
elif not isinstance(self._act, np.ndarray):
self._act = np.array(result.act)
else:
@@ -253,7 +270,8 @@ class Collector(object):
data = {
'obs': self._obs[i], 'act': self._act[i],
'rew': self._rew[i], 'done': self._done[i],
- 'obs_next': obs_next[i], 'info': self._info[i]}
+ 'obs_next': obs_next[i], 'info': self._info[i],
+     'policy': self._policy[i]}
if self._cached_buf:
warning_count += 1
self._cached_buf[i].add(**data)
@@ -292,7 +310,7 @@ class Collector(object):
if self.buffer is not None:
self.buffer.add(
self._obs, self._act[0], self._rew,
- self._done, obs_next, self._info)
+ self._done, obs_next, self._info, self._policy)
cur_step += 1
if self._done:
cur_episode += 1
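To make the new branch easier to follow, here is a standalone paraphrase (not the real ``Collector`` method) of how the per-step ``policy`` data is normalized before it reaches the buffer; the dict branch of ``_to_numpy`` is omitted for brevity:
::

    import torch
    from tianshou.data import Batch

    def extract_policy_data(result, multi_env, env_num):
        """Pick what the collector stores under 'policy' for this step."""
        if hasattr(result, 'policy'):
            data = result.policy
            if isinstance(data, torch.Tensor):
                return data.cpu().numpy()
            if isinstance(data, Batch):
                data.to_numpy()  # in-place torch -> numpy conversion
            return data
        # the policy returned no 'policy' field: fall back to empty dicts
        return [{}] * env_num if multi_env else {}

    result = Batch(act=torch.tensor([0, 1]),
                   policy=Batch(log_prob=torch.tensor([-0.3, -0.9])))
    print(extract_policy_data(result, multi_env=True, env_num=2).log_prob)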

@@ -65,6 +65,17 @@ class BasePolicy(ABC, nn.Module):
# some code
return Batch(logits=..., act=..., state=None, dist=...)
+ After version >= 0.2.3, the keyword "policy" is reserved and the
+ corresponding data will be stored into the replay buffer in numpy. For
+ instance,
+ ::
+     # some code
+     return Batch(..., policy=Batch(log_prob=dist.log_prob(act)))
+     # and in the sampled data batch, you can directly call
+     # batch.policy.log_prob to get your data, although it is stored in
+     # np.ndarray.
"""
pass
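As a concrete illustration of that docstring, a hypothetical policy could save the log-probability at collection time and read it back in ``learn``; the actor network returning ``(logits, hidden_state)`` is an assumption for this sketch, not part of the commit:
::

    import torch
    from tianshou.data import Batch
    from tianshou.policy import BasePolicy

    class MyPolicy(BasePolicy):
        def __init__(self, actor, **kwargs):
            super().__init__(**kwargs)
            self.actor = actor  # assumed to return (logits, hidden_state)

        def forward(self, batch, state=None, **kwargs):
            logits, h = self.actor(batch.obs, state=state)
            dist = torch.distributions.Categorical(logits=logits)
            act = dist.sample()
            # everything under 'policy' ends up in the replay buffer as numpy
            return Batch(logits=logits, act=act, state=h, dist=dist,
                         policy=Batch(log_prob=dist.log_prob(act)))

        def learn(self, batch, **kwargs):
            # the data saved during collection comes back as batch.policy.log_prob
            old_log_prob = torch.as_tensor(batch.policy.log_prob)
            # ... use old_log_prob in the loss, e.g. for an importance ratio ...
            return {}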

@@ -92,8 +92,7 @@ class A2CPolicy(PGPolicy):
for _ in range(repeat):
for b in batch.split(batch_size):
self.optim.zero_grad()
- result = self(b)
- dist = result.dist
+ dist = self(b).dist
v = self.critic(b.obs)
a = torch.tensor(b.act, device=v.device)
r = torch.tensor(b.returns, device=v.device)
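The hunk ends right after these tensors are prepared; for context, the standard A2C objective they feed into looks roughly like the sketch below (a generic formulation with illustrative weights, not necessarily the exact Tianshou code):
::

    import torch
    import torch.nn.functional as F

    def a2c_loss(dist, v, a, r, w_vf=0.5, w_ent=0.01):
        adv = (r - v.flatten()).detach()                  # advantage estimate
        actor_loss = -(dist.log_prob(a) * adv).mean()     # policy-gradient term
        critic_loss = F.mse_loss(v.flatten(), r)
        entropy_bonus = dist.entropy().mean()             # encourages exploration
        return actor_loss + w_vf * critic_loss - w_ent * entropy_bonus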