reserve 'policy' keyword in replay buffer
commit 134f787e24
parent e58fc78546
@@ -24,7 +24,7 @@ class Batch(object):
         )

     In short, you can define a :class:`Batch` with any key-value pair. The
-    current implementation of Tianshou typically uses 6 reserved keys in
+    current implementation of Tianshou typically uses 7 reserved keys in
     :class:`~tianshou.data.Batch`:

     * ``obs`` the observation of step :math:`t` ;
@@ -34,6 +34,7 @@ class Batch(object):
     * ``obs_next`` the observation of step :math:`t+1` ;
     * ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()``\
       function returns 4 arguments, and the last one is ``info``);
+    * ``policy`` the data computed by policy in step :math:`t`;

     :class:`~tianshou.data.Batch` has other methods, including
     :meth:`~tianshou.data.Batch.__getitem__`,
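For orientation, here is a minimal sketch of a single transition expressed with the seven reserved keys (all values below are made up for illustration; nothing here is prescribed by the diff beyond the key names):

    import numpy as np

    from tianshou.data import Batch

    # one hypothetical transition using the 7 reserved keys
    b = Batch(
        obs=np.array([0.1, 0.2]),       # observation at step t
        act=1,                          # action taken at step t
        rew=0.5,                        # reward of step t
        done=False,                     # whether the episode ends at step t
        obs_next=np.array([0.2, 0.3]),  # observation at step t + 1
        info={},                        # the info dict returned by env.step()
        policy=Batch(log_prob=-0.7),    # data computed by the policy at step t
    )
    print(b.obs, b.policy)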
@@ -129,6 +130,14 @@ class Batch(object):
         return sorted([i for i in self.__dict__ if i[0] != '_'] +
                       list(self._meta))

+    def to_numpy(self):
+        """Change all torch.Tensor to numpy.ndarray. This is an inplace
+        operation.
+        """
+        for k in self.__dict__:
+            if isinstance(self.__dict__[k], torch.Tensor):
+                self.__dict__[k] = self.__dict__[k].cpu().numpy()
+
     def append(self, batch):
         """Append a :class:`~tianshou.data.Batch` object to current batch."""
         assert isinstance(batch, Batch), 'Only append Batch is allowed!'
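A short sketch of how the new ``to_numpy`` helper is used (assuming torch is installed; the tensors are placeholders):

    import torch

    from tianshou.data import Batch

    b = Batch(obs=torch.zeros(3), act=torch.tensor([1, 2, 3]))
    b.to_numpy()        # in-place: every torch.Tensor becomes a numpy.ndarray
    print(type(b.obs))  # <class 'numpy.ndarray'>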
@@ -42,26 +42,30 @@ class ReplayBuffer(object):
     >>> batch_data.obs == buf[indice].obs
     array([ True,  True,  True,  True])

-    Since version v0.2.2, :class:`~tianshou.data.ReplayBuffer` supports
-    frame_stack sampling (typically for RNN usage) and ignoring storing the
-    next observation (save memory):
+    :class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
+    (typically for RNN usage, see issue#19), ignoring storing the next
+    observation (save memory in atari tasks), and multi-modal observation (see
+    issue#38, need version >= 0.2.3):
     ::

         >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
         >>> for i in range(16):
         ...     done = i % 5 == 0
-        ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1)
-        >>> print(buf)
+        ...     buf.add(obs={'id': i}, act=i, rew=i, done=done,
+        ...             obs_next={'id': i + 1})
+        >>> print(buf)  # you can see obs_next is not saved in buf
         ReplayBuffer(
-            obs: [ 9. 10. 11. 12. 13. 14. 15.  7.  8.],
-            act: [ 9. 10. 11. 12. 13. 14. 15.  7.  8.],
-            rew: [ 9. 10. 11. 12. 13. 14. 15.  7.  8.],
-            done: [0. 1. 0. 0. 0. 0. 1. 0. 0.],
-            obs_next: [0. 0. 0. 0. 0. 0. 0. 0. 0.],
-            info: [{} {} {} {} {} {} {} {} {}],
+            act: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
+            done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
+            info: array([{}, {}, {}, {}, {}, {}, {}, {}, {}], dtype=object),
+            obs: Batch(
+                id: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
+            ),
+            policy: Batch(),
+            rew: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
         )
         >>> index = np.arange(len(buf))
-        >>> print(buf.get(index, 'obs'))
+        >>> print(buf.get(index, 'obs').id)
         [[ 7.  7.  8.  9.]
          [ 7.  8.  9. 10.]
          [11. 11. 11. 11.]
@@ -73,10 +77,10 @@ class ReplayBuffer(object):
          [ 7.  7.  7.  8.]]
         >>> # here is another way to get the stacked data
         >>> # (stack only for obs and obs_next)
-        >>> abs(buf.get(index, 'obs') - buf[index].obs).sum().sum()
+        >>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
         0.0
-        >>> # we can get obs_next through __getitem__, even if it doesn't store
-        >>> print(buf[:].obs_next)
+        >>> # we can get obs_next through __getitem__, even if it doesn't exist
+        >>> print(buf[:].obs_next.id)
         [[ 7.  8.  9. 10.]
          [ 7.  8.  9. 10.]
          [11. 11. 11. 12.]
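As a standalone script, the docstring example above looks roughly like this (the ``policy`` values are illustrative; exact attribute access on the returned data may vary slightly between versions):

    import numpy as np

    from tianshou.data import ReplayBuffer

    # dict observations, frame stacking, and no storage of obs_next
    buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
    for i in range(16):
        done = i % 5 == 0
        buf.add(obs={'id': i}, act=i, rew=i, done=done,
                obs_next={'id': i + 1},
                policy={'log_prob': 0.0})  # per-step policy data (made up)

    index = np.arange(len(buf))
    print(buf.get(index, 'obs').id)  # stacked frames, shape (9, 4)
    print(buf[index].policy)         # stored policy data comes back as well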
@@ -169,7 +173,7 @@ class ReplayBuffer(object):
             self.add(
                 buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i],
                 buffer.obs_next[i] if self._save_s_ else None,
-                buffer.info[i])
+                buffer.info[i], buffer.policy[i])
             i = (i + 1) % len(buffer)
             if i == begin:
                 break
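This loop is the body of the buffer-merging routine; with the change, the ``policy`` field travels along with each copied transition. A hedged usage sketch, assuming the public entry point is ``ReplayBuffer.update`` as used elsewhere in the code base:

    from tianshou.data import ReplayBuffer

    main_buf = ReplayBuffer(size=100)
    cache = ReplayBuffer(size=10)
    for i in range(5):
        cache.add(obs=i, act=i, rew=0.0, done=False, obs_next=i + 1,
                  info={}, policy={'log_prob': 0.0})

    # assumption: update() copies every transition of `cache` into `main_buf`,
    # now including the per-step `policy` field
    main_buf.update(cache)
    print(len(main_buf))  # 5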
@@ -175,6 +175,20 @@ class Collector(object):
                 isinstance(self.state, np.ndarray):
             self.state[id] = 0

+    def _to_numpy(self, x):
+        """Return an object without torch.Tensor."""
+        if isinstance(x, torch.Tensor):
+            return x.cpu().numpy()
+        elif isinstance(x, dict):
+            for k in x:
+                if isinstance(x[k], torch.Tensor):
+                    x[k] = x[k].cpu().numpy()
+            return x
+        elif isinstance(x, Batch):
+            x.to_numpy()
+            return x
+        return x
+
     def collect(self, n_step=0, n_episode=0, render=None, log_fn=None):
         """Collect a specified number of step or episode.

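Restated outside the class for illustration, the helper handles the three container types like this (behaviour copied from the diff above; the inputs are placeholders):

    import torch

    from tianshou.data import Batch

    def to_numpy(x):
        """Standalone restatement of Collector._to_numpy for illustration."""
        if isinstance(x, torch.Tensor):
            return x.cpu().numpy()
        elif isinstance(x, dict):
            for k in x:
                if isinstance(x[k], torch.Tensor):
                    x[k] = x[k].cpu().numpy()
            return x
        elif isinstance(x, Batch):
            x.to_numpy()
            return x
        return x

    print(type(to_numpy(torch.zeros(2))))              # numpy.ndarray
    print(type(to_numpy({'a': torch.zeros(2)})['a']))  # numpy.ndarray
    print(type(to_numpy(Batch(a=torch.zeros(2))).a))   # numpy.ndarray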
@@ -232,8 +246,11 @@ class Collector(object):
             with torch.no_grad():
                 result = self.policy(batch_data, self.state)
             self.state = result.state if hasattr(result, 'state') else None
+            self._policy = self._to_numpy(result.policy) \
+                if hasattr(result, 'policy') \
+                else [{}] * self.env_num if self._multi_env else {}
             if isinstance(result.act, torch.Tensor):
-                self._act = result.act.detach().cpu().numpy()
+                self._act = self._to_numpy(result.act)
             elif not isinstance(self._act, np.ndarray):
                 self._act = np.array(result.act)
             else:
@@ -253,7 +270,8 @@ class Collector(object):
                 data = {
                     'obs': self._obs[i], 'act': self._act[i],
                     'rew': self._rew[i], 'done': self._done[i],
-                    'obs_next': obs_next[i], 'info': self._info[i]}
+                    'obs_next': obs_next[i], 'info': self._info[i],
+                    'policy': self._policy[i]}
                 if self._cached_buf:
                     warning_count += 1
                     self._cached_buf[i].add(**data)
@@ -292,7 +310,7 @@ class Collector(object):
             if self.buffer is not None:
                 self.buffer.add(
                     self._obs, self._act[0], self._rew,
-                    self._done, obs_next, self._info)
+                    self._done, obs_next, self._info, self._policy)
             cur_step += 1
             if self._done:
                 cur_episode += 1
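Taken together, the collector-side changes mean: if the forward pass attaches a ``policy`` field to its output Batch, the collector converts it to numpy and passes it to ``buffer.add`` with the rest of the transition. A toy sketch of that data flow (the forward function below is invented, not a real tianshou policy):

    import numpy as np
    import torch

    from tianshou.data import Batch

    def forward(obs):
        """Toy stand-in for a policy forward pass."""
        logits = torch.randn(len(obs), 2)
        dist = torch.distributions.Categorical(logits=logits)
        act = dist.sample()
        # per-step policy data goes under the reserved 'policy' key
        return Batch(act=act, state=None,
                     policy=Batch(log_prob=dist.log_prob(act)))

    result = forward(np.zeros((4, 8), dtype=np.float32))
    policy_data = result.policy
    policy_data.to_numpy()          # stored into the buffer as numpy
    act = result.act.cpu().numpy()  # actions are converted the same way
    print(act, policy_data.log_prob)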
@@ -65,6 +65,17 @@ class BasePolicy(ABC, nn.Module):

             # some code
             return Batch(logits=..., act=..., state=None, dist=...)
+
+        After version >= 0.2.3, the keyword "policy" is reserved and the
+        corresponding data will be stored into the replay buffer in numpy. For
+        instance,
+        ::
+
+            # some code
+            return Batch(..., policy=Batch(log_prob=dist.log_prob(act)))
+            # and in the sampled data batch, you can directly call
+            # batch.policy.log_prob to get your data, although it is stored in
+            # np.ndarray.
         """
         pass

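To complement the docstring, here is a hedged sketch of a toy policy that writes ``log_prob`` under the reserved key in ``forward`` and reads it back from the sampled batch in ``learn`` (the class and its network are invented for illustration):

    import torch
    import torch.nn as nn

    from tianshou.data import Batch
    from tianshou.policy import BasePolicy

    class ToyPolicy(BasePolicy):
        """Invented example; only the 'policy' keyword handling matters."""

        def __init__(self, actor: nn.Module):
            super().__init__()
            self.actor = actor

        def forward(self, batch, state=None, **kwargs):
            logits = self.actor(torch.as_tensor(batch.obs, dtype=torch.float32))
            dist = torch.distributions.Categorical(logits=logits)
            act = dist.sample()
            # reserved key: this ends up in the replay buffer as numpy
            return Batch(act=act, state=None,
                         policy=Batch(log_prob=dist.log_prob(act)))

        def learn(self, batch, **kwargs):
            # batch.policy.log_prob comes back from the buffer as np.ndarray
            old_log_prob = torch.as_tensor(batch.policy.log_prob)
            # ... use old_log_prob, e.g. as part of an importance weight
            return {}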
@@ -92,8 +92,7 @@ class A2CPolicy(PGPolicy):
         for _ in range(repeat):
             for b in batch.split(batch_size):
                 self.optim.zero_grad()
-                result = self(b)
-                dist = result.dist
+                dist = self(b).dist
                 v = self.critic(b.obs)
                 a = torch.tensor(b.act, device=v.device)
                 r = torch.tensor(b.returns, device=v.device)