reserve 'policy' keyword in replay buffer

parent e58fc78546
commit 134f787e24
@@ -24,7 +24,7 @@ class Batch(object):
         )

     In short, you can define a :class:`Batch` with any key-value pair. The
-    current implementation of Tianshou typically use 6 reserved keys in
+    current implementation of Tianshou typically use 7 reserved keys in
     :class:`~tianshou.data.Batch`:

     * ``obs`` the observation of step :math:`t` ;
@@ -34,6 +34,7 @@ class Batch(object):
     * ``obs_next`` the observation of step :math:`t+1` ;
     * ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()``\
       function return 4 arguments, and the last one is ``info``);
+    * ``policy`` the data computed by policy in step :math:`t`;

     :class:`~tianshou.data.Batch` has other methods, including
     :meth:`~tianshou.data.Batch.__getitem__`,
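For orientation, the seven reserved keys laid out in one placeholder transition (an illustrative sketch, not part of this commit; all values below are made up):
::

    from tianshou.data import Batch

    # one transition using the seven reserved keys; all values are placeholders
    transition = Batch(obs=0, act=1, rew=0.5, done=False,
                       obs_next=1, info={}, policy=Batch(log_prob=-0.3))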
@@ -129,6 +130,14 @@ class Batch(object):
         return sorted([i for i in self.__dict__ if i[0] != '_'] +
                       list(self._meta))

+    def to_numpy(self):
+        """Change all torch.Tensor to numpy.ndarray. This is an inplace
+        operation.
+        """
+        for k in self.__dict__:
+            if isinstance(self.__dict__[k], torch.Tensor):
+                self.__dict__[k] = self.__dict__[k].cpu().numpy()
+
     def append(self, batch):
         """Append a :class:`~tianshou.data.Batch` object to current batch."""
         assert isinstance(batch, Batch), 'Only append Batch is allowed!'
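A minimal usage sketch of the new ``Batch.to_numpy`` added above (tensor fields are converted in place):
::

    import torch
    from tianshou.data import Batch

    b = Batch(obs=torch.zeros(3), act=torch.ones(3))
    b.to_numpy()          # every torch.Tensor field becomes a numpy.ndarray, in place
    print(type(b.obs))    # <class 'numpy.ndarray'>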
@@ -42,26 +42,30 @@ class ReplayBuffer(object):
         >>> batch_data.obs == buf[indice].obs
         array([ True, True, True, True])

-    Since version v0.2.2, :class:`~tianshou.data.ReplayBuffer` supports
-    frame_stack sampling (typically for RNN usage) and ignoring storing the
-    next observation (save memory):
+    :class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
+    (typically for RNN usage, see issue#19), ignoring storing the next
+    observation (save memory in atari tasks), and multi-modal observation (see
+    issue#38, need version >= 0.2.3):
     ::

         >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
         >>> for i in range(16):
         ...     done = i % 5 == 0
-        ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1)
-        >>> print(buf)
+        ...     buf.add(obs={'id': i}, act=i, rew=i, done=done,
+        ...             obs_next={'id': i + 1})
+        >>> print(buf)  # you can see obs_next is not saved in buf
         ReplayBuffer(
-            obs: [ 9. 10. 11. 12. 13. 14. 15. 7. 8.],
-            act: [ 9. 10. 11. 12. 13. 14. 15. 7. 8.],
-            rew: [ 9. 10. 11. 12. 13. 14. 15. 7. 8.],
-            done: [0. 1. 0. 0. 0. 0. 1. 0. 0.],
-            obs_next: [0. 0. 0. 0. 0. 0. 0. 0. 0.],
-            info: [{} {} {} {} {} {} {} {} {}],
+            act: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
+            done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
+            info: array([{}, {}, {}, {}, {}, {}, {}, {}, {}], dtype=object),
+            obs: Batch(
+                     id: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
+                 ),
+            policy: Batch(),
+            rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
         )
         >>> index = np.arange(len(buf))
-        >>> print(buf.get(index, 'obs'))
+        >>> print(buf.get(index, 'obs').id)
         [[ 7. 7. 8. 9.]
         [ 7. 8. 9. 10.]
         [11. 11. 11. 11.]
@@ -73,10 +77,10 @@ class ReplayBuffer(object):
         [ 7. 7. 7. 8.]]
         >>> # here is another way to get the stacked data
         >>> # (stack only for obs and obs_next)
-        >>> abs(buf.get(index, 'obs') - buf[index].obs).sum().sum()
+        >>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
         0.0
-        >>> # we can get obs_next through __getitem__, even if it doesn't store
-        >>> print(buf[:].obs_next)
+        >>> # we can get obs_next through __getitem__, even if it doesn't exist
+        >>> print(buf[:].obs_next.id)
         [[ 7. 8. 9. 10.]
         [ 7. 8. 9. 10.]
         [11. 11. 11. 12.]
@@ -169,7 +173,7 @@ class ReplayBuffer(object):
             self.add(
                 buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i],
                 buffer.obs_next[i] if self._save_s_ else None,
-                buffer.info[i])
+                buffer.info[i], buffer.policy[i])
             i = (i + 1) % len(buffer)
             if i == begin:
                 break
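With the reserved key in place, per-step policy data can be passed to ``ReplayBuffer.add`` and comes back with sampled batches. A minimal sketch (the ``log_prob`` field is illustrative; any key-value pair may be stored under ``policy``):
::

    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=8)
    for i in range(4):
        # store per-step policy data alongside the transition
        buf.add(obs=i, act=i, rew=float(i), done=False, obs_next=i + 1,
                policy={'log_prob': -0.5 * i})
    batch, indice = buf.sample(2)
    print(batch.policy)   # the stored per-step policy data, kept as numpy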
@@ -175,6 +175,20 @@ class Collector(object):
                     isinstance(self.state, np.ndarray):
                 self.state[id] = 0

+    def _to_numpy(self, x):
+        """Return an object without torch.Tensor."""
+        if isinstance(x, torch.Tensor):
+            return x.cpu().numpy()
+        elif isinstance(x, dict):
+            for k in x:
+                if isinstance(x[k], torch.Tensor):
+                    x[k] = x[k].cpu().numpy()
+            return x
+        elif isinstance(x, Batch):
+            x.to_numpy()
+            return x
+        return x
+
     def collect(self, n_step=0, n_episode=0, render=None, log_fn=None):
         """Collect a specified number of step or episode.
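The helper handles three shapes of data a model may hand back: a bare tensor, a dict of tensors, or a ``Batch``. A standalone copy of the same logic, shown only to illustrate what each branch yields:
::

    import torch
    from tianshou.data import Batch

    def to_numpy(x):  # standalone copy of Collector._to_numpy, for illustration
        if isinstance(x, torch.Tensor):
            return x.cpu().numpy()
        elif isinstance(x, dict):
            for k in x:
                if isinstance(x[k], torch.Tensor):
                    x[k] = x[k].cpu().numpy()
            return x
        elif isinstance(x, Batch):
            x.to_numpy()
            return x
        return x

    print(type(to_numpy(torch.zeros(2))))               # numpy.ndarray
    print(type(to_numpy({'a': torch.zeros(2)})['a']))   # numpy.ndarray
    print(type(to_numpy(Batch(a=torch.zeros(2))).a))    # numpy.ndarray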
@@ -232,8 +246,11 @@ class Collector(object):
             with torch.no_grad():
                 result = self.policy(batch_data, self.state)
             self.state = result.state if hasattr(result, 'state') else None
+            self._policy = self._to_numpy(result.policy) \
+                if hasattr(result, 'policy') \
+                else [{}] * self.env_num if self._multi_env else {}
             if isinstance(result.act, torch.Tensor):
-                self._act = result.act.detach().cpu().numpy()
+                self._act = self._to_numpy(result.act)
             elif not isinstance(self._act, np.ndarray):
                 self._act = np.array(result.act)
             else:
@@ -253,7 +270,8 @@ class Collector(object):
                     data = {
                         'obs': self._obs[i], 'act': self._act[i],
                         'rew': self._rew[i], 'done': self._done[i],
-                        'obs_next': obs_next[i], 'info': self._info[i]}
+                        'obs_next': obs_next[i], 'info': self._info[i],
+                        'policy': self._policy[i]}
                     if self._cached_buf:
                         warning_count += 1
                         self._cached_buf[i].add(**data)
@@ -292,7 +310,7 @@ class Collector(object):
                 if self.buffer is not None:
                     self.buffer.add(
                         self._obs, self._act[0], self._rew,
-                        self._done, obs_next, self._info)
+                        self._done, obs_next, self._info, self._policy)
                 cur_step += 1
                 if self._done:
                     cur_episode += 1
@@ -65,6 +65,17 @@ class BasePolicy(ABC, nn.Module):

            # some code
            return Batch(logits=..., act=..., state=None, dist=...)

+        After version >= 0.2.3, the keyword "policy" is reserverd and the
+        corresponding data will be stored into the replay buffer in numpy. For
+        instance,
+        ::
+
+            # some code
+            return Batch(..., policy=Batch(log_prob=dist.log_prob(act)))
+            # and in the sampled data batch, you can directly call
+            # batch.policy.log_prob to get your data, although it is stored in
+            # np.ndarray.
+
        """
        pass
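Expanding the docstring snippet above into a fuller sketch: ``MyPolicy`` below is hypothetical (the random logits stand in for a real network), but the ``policy=Batch(log_prob=...)`` round trip is exactly what the reserved keyword provides:
::

    import torch
    from tianshou.data import Batch
    from tianshou.policy import BasePolicy

    class MyPolicy(BasePolicy):  # hypothetical policy, for illustration only
        def forward(self, batch, state=None, **kwargs):
            logits = torch.randn(len(batch.obs), 2)  # stands in for a real network
            dist = torch.distributions.Categorical(logits=logits)
            act = dist.sample()
            # everything under the reserved 'policy' key is written to the buffer
            return Batch(act=act, state=None, dist=dist,
                         policy=Batch(log_prob=dist.log_prob(act)))

        def learn(self, batch, **kwargs):
            # the sampled data carries the stored values back as batch.policy.*
            old_log_prob = batch.policy.log_prob  # np.ndarray, per the note above
            ...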
@@ -92,8 +92,7 @@ class A2CPolicy(PGPolicy):
         for _ in range(repeat):
             for b in batch.split(batch_size):
                 self.optim.zero_grad()
-                result = self(b)
-                dist = result.dist
+                dist = self(b).dist
                 v = self.critic(b.obs)
                 a = torch.tensor(b.act, device=v.device)
                 r = torch.tensor(b.returns, device=v.device)