add ignore_obs_next in buffer
parent 19f2cce294
commit 13086b7f64
@@ -40,7 +40,7 @@ def test_stack(size=5, bufsize=9, stack_num=4):
         if done:
             obs = env.reset(1)
     indice = np.arange(len(buf))
-    assert abs(buf.get_stack(indice, 'obs') - np.array([
+    assert abs(buf.get(indice, 'obs') - np.array([
         [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
         [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
         [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]])).sum() < 1e-6

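The semantics asserted here, in runnable form: a minimal sketch of frame-stack sampling through the renamed get(), assuming only the public API shown in this commit (the counting loop stands in for MyTestEnv, which is defined elsewhere in the test suite):

import numpy as np
from tianshou.data import ReplayBuffer

# toy episodes 1 -> 2 -> 3 -> 4 -> done, mimicking the test env above
buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
obs = 1
for i in range(15):
    obs_next = obs + 1
    done = obs_next == 5
    buf.add(obs=obs, act=1, rew=0, done=done, obs_next=obs_next, info={})
    obs = 1 if done else obs_next

indice = np.arange(len(buf))
# each row is [s_{t-3}, s_{t-2}, s_{t-1}, s_t]; frames from before the
# episode start are padded by repeating the episode's first frame
print(buf.get(indice, 'obs'))
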
@@ -1,7 +1,7 @@
 import numpy as np
 from tianshou.policy import BasePolicy
 from tianshou.env import SubprocVectorEnv
-from tianshou.data import Collector, Batch
+from tianshou.data import Collector, Batch, ReplayBuffer

 if __name__ == '__main__':
     from env import MyTestEnv

@@ -36,21 +36,21 @@ def test_collector():
     venv = SubprocVectorEnv(env_fns)
     policy = MyPolicy()
     env = env_fns[0]()
-    c0 = Collector(policy, env)
+    c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False))
     c0.collect(n_step=3)
     assert equal(c0.buffer.obs[:3], [0, 1, 0])
-    assert equal(c0.buffer.obs_next[:3], [1, 2, 1])
+    assert equal(c0.buffer[:3].obs_next, [1, 2, 1])
     c0.collect(n_episode=3)
     assert equal(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
-    assert equal(c0.buffer.obs_next[:8], [1, 2, 1, 2, 1, 2, 1, 2])
-    c1 = Collector(policy, venv)
+    assert equal(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
+    c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
     c1.collect(n_step=6)
     assert equal(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3])
-    assert equal(c1.buffer.obs_next[:11], [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
+    assert equal(c1.buffer[:11].obs_next, [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
     c1.collect(n_episode=2)
     assert equal(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
-    assert equal(c1.buffer.obs_next[11:21], [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
-    c2 = Collector(policy, venv)
+    assert equal(c1.buffer[11:21].obs_next, [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
+    c2 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
     c2.collect(n_episode=[1, 2, 2, 2])
     assert equal(c2.buffer.obs_next[:26], [
         1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,

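Note the pattern change in these asserts: since the buffer now owns obs_next reconstruction, the raw storage access c0.buffer.obs_next[:3] is replaced by c0.buffer[:3].obs_next, which routes through __getitem__/get() and works whether or not obs_next was physically stored. The equal helper is defined earlier in the test file and is not part of this diff; a plausible sketch of it:

import numpy as np

def equal(a, b):
    # elementwise comparison with a small tolerance, as the asserts suggest
    return abs(np.array(a) - np.array(b)).sum() < 1e-6
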
@@ -72,7 +72,7 @@ def test_drqn(args=get_args()):
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(
-            args.buffer_size, stack_num=args.stack_num))
+            args.buffer_size, stack_num=args.stack_num, ignore_obs_next=True))
     # the stack_num is for RNN training: sample framestack obs
     test_collector = Collector(policy, test_envs)
     # policy.set_eps(1)

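The motivation for ignore_obs_next=True in this DRQN setup is memory: obs_next[t] duplicates obs[t+1] for every non-terminal step, so storing both roughly doubles the observation footprint. A back-of-envelope sketch (illustrative numbers, not from this commit):

# one 84x84 uint8 frame per step, 1e6 transitions
frame_bytes = 84 * 84
size = 1_000_000
both = 2 * size * frame_bytes / 2**30    # obs + obs_next stored
obs_only = size * frame_bytes / 2**30    # ignore_obs_next=True
print(f'{both:.1f} GiB -> {obs_only:.1f} GiB')   # ~13.1 -> ~6.6
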
@@ -41,14 +41,15 @@ class ReplayBuffer(object):
     >>> batch_data.obs == buf[indice].obs
     array([ True,  True,  True,  True])

-    From version v0.2.2, :class:`~tianshou.data.ReplayBuffer` supports
-    frame_stack sampling, typically for RNN usage:
+    Since version v0.2.2, :class:`~tianshou.data.ReplayBuffer` supports
+    frame_stack sampling (typically for RNN usage) and ignoring storing the
+    next observation (save memory):
     ::

-        >>> buf = ReplayBuffer(size=9, stack_num=4)
+        >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
         >>> for i in range(16):
         ...     done = i % 5 == 0
-        ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=0, info={})
+        ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=i, info={})
         >>> print(buf)
         ReplayBuffer(
             obs: [ 9. 10. 11. 12. 13. 14. 15.  7.  8.],

@@ -59,7 +60,7 @@ class ReplayBuffer(object):
             info: [{} {} {} {} {} {} {} {} {}],
         )
         >>> index = np.arange(len(buf))
-        >>> print(buf.get_stack(index, 'obs'))
+        >>> print(buf.get(index, 'obs'))
         [[ 7.  7.  8.  9.]
          [ 7.  8.  9. 10.]
         [11. 11. 11. 11.]

@@ -71,14 +72,15 @@ class ReplayBuffer(object):
          [ 7.  7.  7.  8.]]
         >>> # here is another way to get the stacked data
         >>> # (stack only for obs and obs_next)
-        >>> abs(buf.get_stack(index, 'obs') - buf[index].obs).sum().sum()
+        >>> abs(buf.get(index, 'obs') - buf[index].obs).sum().sum()
         0.0
     """

-    def __init__(self, size, stack_num=0):
+    def __init__(self, size, stack_num=0, ignore_obs_next=False, **kwargs):
         super().__init__()
         self._maxsize = size
         self._stack = stack_num
+        self._save_s_ = not ignore_obs_next
         self.reset()

     def __len__(self):

@@ -125,7 +127,7 @@ class ReplayBuffer(object):
         while True:
             self.add(
                 buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i],
-                None if buffer.obs_next is None else buffer.obs_next[i],
+                buffer.obs_next[i] if self._save_s_ else None,
                 buffer.info[i])
             i = (i + 1) % len(buffer)
             if i == begin:

@@ -139,7 +141,8 @@ class ReplayBuffer(object):
         self._add_to_buffer('act', act)
         self._add_to_buffer('rew', rew)
         self._add_to_buffer('done', done)
-        self._add_to_buffer('obs_next', obs_next)
+        if self._save_s_:
+            self._add_to_buffer('obs_next', obs_next)
         self._add_to_buffer('info', info)
         if self._maxsize > 0:
             self._size = min(self._size + 1, self._maxsize)

@@ -166,19 +169,30 @@ class ReplayBuffer(object):
         ])
         return self[indice], indice

-    def get_stack(self, indice, key):
+    def get(self, indice, key):
         """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
         where s is self.key, t is indice. The stack_num (here equals to 4) is
         given from buffer initialization procedure.
         """
-        if self.__dict__.get(key, None) is None:
-            return None
-        if self._stack == 0:
-            return self.__dict__[key][indice]
-        stack = []
+        if not isinstance(indice, np.ndarray):
+            if np.isscalar(indice):
+                indice = np.array(indice)
+            elif isinstance(indice, slice):
+                indice = np.arange(
+                    0 if indice.start is None else indice.start,
+                    self._size if indice.stop is None else indice.stop,
+                    1 if indice.step is None else indice.step)
         # set last frame done to True
         last_index = (self._index - 1 + self._size) % self._size
         last_done, self.done[last_index] = self.done[last_index], True
+        if key == 'obs_next' and not self._save_s_:
+            indice += 1 - self.done[indice].astype(np.int)
+            indice[indice == self._size] = 0
+            key = 'obs'
+        if self._stack == 0:
+            self.done[last_index] = last_done
+            return self.__dict__[key][indice]
+        stack = []
         for i in range(self._stack):
             stack = [self.__dict__[key][indice]] + stack
             pre_indice = indice - 1

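The core trick added above: when obs_next was never stored, a query for it is remapped to obs at indice + 1, except at terminal steps, where the index stays in place so the episode's last frame is returned instead of the next episode's first; temporarily marking the most recently written frame done guards the buffer head the same way. A standalone numpy sketch of that index arithmetic (mirroring, not reusing, the code above):

import numpy as np

obs = np.array([0, 1, 2, 0, 1, 2])    # two 3-step episodes
done = np.array([0, 0, 1, 0, 0, 1])

indice = np.arange(len(obs))
shifted = indice + 1 - done[indice]    # advance, except at terminal steps
shifted[shifted == len(obs)] = 0       # wrap around the ring buffer
print(obs[shifted])                    # [1 2 2 1 2 2]: reconstructed obs_next
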
@@ -193,11 +207,11 @@ class ReplayBuffer(object):
         return the stacked obs and obs_next with shape [batch, len, ...].
         """
         return Batch(
-            obs=self.get_stack(index, 'obs'),
+            obs=self.get(index, 'obs'),
             act=self.act[index],
             rew=self.rew[index],
             done=self.done[index],
-            obs_next=self.get_stack(index, 'obs_next'),
+            obs_next=self.get(index, 'obs_next'),
             info=self.info[index]
         )

@@ -213,8 +227,8 @@ class ListReplayBuffer(ReplayBuffer):
     detailed explanation.
     """

-    def __init__(self):
-        super().__init__(size=0)
+    def __init__(self, **kwargs):
+        super().__init__(size=0, ignore_obs_next=False, **kwargs)

     def _add_to_buffer(self, name, inst):
         if inst is None:

@@ -233,8 +247,8 @@ class ListReplayBuffer(ReplayBuffer):
 class PrioritizedReplayBuffer(ReplayBuffer):
     """docstring for PrioritizedReplayBuffer"""

-    def __init__(self, size):
-        super().__init__(size)
+    def __init__(self, size, **kwargs):
+        super().__init__(size, **kwargs)

     def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
         raise NotImplementedError

@@ -22,8 +22,6 @@ class Collector(object):
         :class:`~tianshou.data.ReplayBuffer`.
     :param int stat_size: for the moving average of recording speed, defaults
         to 100.
-    :param bool store_obs_next: store the next observation to replay buffer or
-        not, defaults to ``True``.

     Example:
     ::

@@ -70,8 +68,7 @@ class Collector(object):
     Please make sure the given environment has a time limitation.
     """

-    def __init__(self, policy, env, buffer=None, stat_size=100,
-                 store_obs_next=True, **kwargs):
+    def __init__(self, policy, env, buffer=None, stat_size=100, **kwargs):
         super().__init__()
         self.env = env
         self.env_num = 1

@@ -106,7 +103,6 @@ class Collector(object):
         self.state = None
         self.step_speed = MovAvg(stat_size)
         self.episode_speed = MovAvg(stat_size)
-        self._save_s_ = store_obs_next

     def reset_buffer(self):
         """Reset the main data buffer."""

@@ -247,8 +243,7 @@ class Collector(object):
                 data = {
                     'obs': self._obs[i], 'act': self._act[i],
                     'rew': self._rew[i], 'done': self._done[i],
-                    'obs_next': obs_next[i] if self._save_s_ else None,
-                    'info': self._info[i]}
+                    'obs_next': obs_next[i], 'info': self._info[i]}
                 if self._cached_buf:
                     warning_count += 1
                     self._cached_buf[i].add(**data)

@@ -284,8 +279,7 @@ class Collector(object):
             else:
                 self.buffer.add(
                     self._obs, self._act[0], self._rew,
-                    self._done, obs_next if self._save_s_ else None,
-                    self._info)
+                    self._done, obs_next, self._info)
             cur_step += 1
             if self._done:
                 cur_episode += 1

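Taken together, the collector changes drop the store_obs_next flag entirely: the collector always passes obs_next along, and the buffer alone decides whether to persist it. A migration sketch (policy and env as constructed in the surrounding setup; the size is an arbitrary example value):

from tianshou.data import Collector, ReplayBuffer

# before: Collector(policy, env, ReplayBuffer(20000), store_obs_next=False)
collector = Collector(policy, env, ReplayBuffer(20000, ignore_obs_next=True))
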
@@ -81,16 +81,17 @@ class DQNPolicy(BasePolicy):
         returns[buffer.done[now] > 0] = 0
         returns = buffer.rew[now] + self._gamma * returns
         terminal = (indice + self._n_step - 1) % len(buffer)
+        terminal_data = buffer[terminal]
         if self._target:
             # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
-            a = self(buffer[terminal], input='obs_next', eps=0).act
+            a = self(terminal_data, input='obs_next', eps=0).act
             target_q = self(
-                buffer[terminal], model='model_old', input='obs_next').logits
+                terminal_data, model='model_old', input='obs_next').logits
             if isinstance(target_q, torch.Tensor):
                 target_q = target_q.detach().cpu().numpy()
             target_q = target_q[np.arange(len(a)), a]
         else:
-            target_q = self(buffer[terminal], input='obs_next').logits
+            target_q = self(terminal_data, input='obs_next').logits
             if isinstance(target_q, torch.Tensor):
                 target_q = target_q.detach().cpu().numpy()
             target_q = target_q.max(axis=1)

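Hoisting buffer[terminal] into terminal_data is a small but real win: each subscript runs __getitem__, which rebuilds the whole Batch and, when stack_num > 0, re-stacks frames for both obs and obs_next. A rough way to see the per-subscript cost (hypothetical harness against the commit-era API, not from the commit):

import numpy as np
import timeit
from tianshou.data import ReplayBuffer

buf = ReplayBuffer(size=1000, stack_num=4)
for i in range(1000):
    buf.add(obs=np.zeros(84), act=0, rew=0., done=i % 100 == 99,
            obs_next=np.zeros(84), info={})
terminal = np.arange(64)
print(timeit.timeit(lambda: buf[terminal], number=100))  # paid per subscript
cached = buf[terminal]  # the commit's approach: pay once, reuse
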
@@ -40,8 +40,6 @@ class PGPolicy(BasePolicy):
         discount factor, :math:`\gamma \in [0, 1]`.
         """
         batch.returns = self._vanilla_returns(batch)
-        if getattr(batch, 'obs_next', None) is None:
-            batch.obs_next = buffer[(indice + 1) % len(buffer)].obs
         # batch.returns = self._vectorized_returns(batch)
         return batch

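This fallback is obsolete because buffer[indice] now always materializes obs_next: get() reconstructs it from obs on the fly when it was never stored, so batch.obs_next can no longer be missing on the sampling path. A quick check of that invariant (assuming the buffer API in this diff):

import numpy as np
from tianshou.data import ReplayBuffer

buf = ReplayBuffer(size=4, ignore_obs_next=True)
for i in range(4):
    buf.add(obs=i, act=0, rew=0, done=i == 3, obs_next=i, info={})
batch, indice = buf.sample(batch_size=2)
assert batch.obs_next is not None   # reconstructed, never stored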