add ignore_obs_next in buffer

Trinkle23897 2020-04-10 09:01:17 +08:00
parent 19f2cce294
commit 13086b7f64
7 changed files with 52 additions and 45 deletions
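In short: `ReplayBuffer` gains an `ignore_obs_next` flag. With it set, `obs_next` is never stored; sampling rebuilds it from `obs` shifted one step forward, and a terminal step points at its own observation rather than at the next episode's reset frame. The per-file changes below thread the flag through the buffer, the collector, and the DQN/PG policies. A hedged sketch of the resulting semantics, written against the API this commit introduces:

    import numpy as np
    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=10, ignore_obs_next=True)
    for i in range(6):
        # two episodes of length 3; obs_next is passed but not stored
        buf.add(obs=i, act=0, rew=0, done=(i % 3 == 2), obs_next=i + 1, info={})

    idx = np.arange(len(buf))
    # rebuilt on demand: the next obs, except terminal steps repeat their own
    print(buf[idx].obs_next)  # [1. 2. 2. 4. 5. 5.]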

View File

@@ -40,7 +40,7 @@ def test_stack(size=5, bufsize=9, stack_num=4):
         if done:
             obs = env.reset(1)
     indice = np.arange(len(buf))
-    assert abs(buf.get_stack(indice, 'obs') - np.array([
+    assert abs(buf.get(indice, 'obs') - np.array([
         [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
         [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
         [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]])).sum() < 1e-6

View File

@@ -1,7 +1,7 @@
 import numpy as np
 from tianshou.policy import BasePolicy
 from tianshou.env import SubprocVectorEnv
-from tianshou.data import Collector, Batch
+from tianshou.data import Collector, Batch, ReplayBuffer

 if __name__ == '__main__':
     from env import MyTestEnv
@@ -36,21 +36,21 @@ def test_collector():
     venv = SubprocVectorEnv(env_fns)
     policy = MyPolicy()
     env = env_fns[0]()
-    c0 = Collector(policy, env)
+    c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False))
     c0.collect(n_step=3)
     assert equal(c0.buffer.obs[:3], [0, 1, 0])
-    assert equal(c0.buffer.obs_next[:3], [1, 2, 1])
+    assert equal(c0.buffer[:3].obs_next, [1, 2, 1])
     c0.collect(n_episode=3)
     assert equal(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
-    assert equal(c0.buffer.obs_next[:8], [1, 2, 1, 2, 1, 2, 1, 2])
+    assert equal(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
-    c1 = Collector(policy, venv)
+    c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
     c1.collect(n_step=6)
     assert equal(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3])
-    assert equal(c1.buffer.obs_next[:11], [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
+    assert equal(c1.buffer[:11].obs_next, [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
     c1.collect(n_episode=2)
     assert equal(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
-    assert equal(c1.buffer.obs_next[11:21], [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
+    assert equal(c1.buffer[11:21].obs_next, [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
-    c2 = Collector(policy, venv)
+    c2 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
     c2.collect(n_episode=[1, 2, 2, 2])
     assert equal(c2.buffer.obs_next[:26], [
         1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,
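Note the changed access pattern above: `c0.buffer.obs_next[:3]` becomes `c0.buffer[:3].obs_next`. Reading the raw attribute only works when the array is actually stored, whereas indexing the buffer goes through `__getitem__` and `get`, which reconstruct `obs_next` when it was ignored. A small sketch of the distinction, assuming this commit's API:

    import numpy as np
    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=8, ignore_obs_next=True)
    for i in range(5):
        buf.add(obs=i, act=0, rew=0, done=(i == 4), obs_next=i + 1, info={})

    # no buf.obs_next array exists here, so index through the buffer:
    print(buf[:5].obs_next)  # [1. 2. 3. 4. 4.], rebuilt from obs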

View File

@@ -72,7 +72,7 @@ def test_drqn(args=get_args()):
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(
-            args.buffer_size, stack_num=args.stack_num))
+            args.buffer_size, stack_num=args.stack_num, ignore_obs_next=True))
     # the stack_num is for RNN training: sample framestack obs
     test_collector = Collector(policy, test_envs)
     # policy.set_eps(1)
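The payoff for DRQN-style training is concrete: within an episode `obs_next[i]` equals `obs[i + 1]`, so storing both copies of every frame roughly doubles observation memory for no extra information. Back-of-the-envelope arithmetic (illustrative numbers, not a measurement):

    # one 84x84 grayscale uint8 frame per transition, Atari-sized buffer
    frame_bytes = 84 * 84
    buffer_size = 1_000_000
    per_copy_gb = frame_bytes * buffer_size / 1024 ** 3
    print(f"{per_copy_gb:.1f} GB")  # ~6.6 GB saved by not storing obs_next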

View File

@@ -41,14 +41,15 @@ class ReplayBuffer(object):
         >>> batch_data.obs == buf[indice].obs
         array([ True,  True,  True,  True])

-    From version v0.2.2, :class:`~tianshou.data.ReplayBuffer` supports
-    frame_stack sampling, typically for RNN usage:
+    Since version v0.2.2, :class:`~tianshou.data.ReplayBuffer` supports
+    frame_stack sampling (typically for RNN usage) and can skip storing
+    the next observation (to save memory):

     ::

-        >>> buf = ReplayBuffer(size=9, stack_num=4)
+        >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
         >>> for i in range(16):
         ...     done = i % 5 == 0
-        ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=0, info={})
+        ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=i, info={})
         >>> print(buf)
         ReplayBuffer(
             obs: [ 9. 10. 11. 12. 13. 14. 15.  7.  8.],
@@ -59,7 +60,7 @@ class ReplayBuffer(object):
             info: [{} {} {} {} {} {} {} {} {}],
         )
         >>> index = np.arange(len(buf))
-        >>> print(buf.get_stack(index, 'obs'))
+        >>> print(buf.get(index, 'obs'))
         [[ 7.  7.  8.  9.]
          [ 7.  8.  9. 10.]
          [11. 11. 11. 11.]
@@ -71,14 +72,15 @@ class ReplayBuffer(object):
          [ 7.  7.  7.  8.]]
         >>> # here is another way to get the stacked data
         >>> # (stack only for obs and obs_next)
-        >>> abs(buf.get_stack(index, 'obs') - buf[index].obs).sum().sum()
+        >>> abs(buf.get(index, 'obs') - buf[index].obs).sum().sum()
         0.0
     """

-    def __init__(self, size, stack_num=0):
+    def __init__(self, size, stack_num=0, ignore_obs_next=False, **kwargs):
         super().__init__()
         self._maxsize = size
         self._stack = stack_num
+        self._save_s_ = not ignore_obs_next
         self.reset()

     def __len__(self):
@@ -125,7 +127,7 @@ class ReplayBuffer(object):
         while True:
             self.add(
                 buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i],
-                None if buffer.obs_next is None else buffer.obs_next[i],
+                buffer.obs_next[i] if self._save_s_ else None,
                 buffer.info[i])
             i = (i + 1) % len(buffer)
             if i == begin:
@@ -139,7 +141,8 @@ class ReplayBuffer(object):
         self._add_to_buffer('act', act)
         self._add_to_buffer('rew', rew)
         self._add_to_buffer('done', done)
-        self._add_to_buffer('obs_next', obs_next)
+        if self._save_s_:
+            self._add_to_buffer('obs_next', obs_next)
         self._add_to_buffer('info', info)
         if self._maxsize > 0:
             self._size = min(self._size + 1, self._maxsize)
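With `_save_s_` false, `add` simply drops the `obs_next` argument, so callers (the `Collector` below) can keep passing it unconditionally. A hedged check against this commit's API:

    import numpy as np
    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=4, ignore_obs_next=True)
    buf.add(obs=0, act=0, rew=0, done=False, obs_next=1, info={})
    buf.add(obs=1, act=0, rew=0, done=True, obs_next=2, info={})
    assert buf.__dict__.get('obs_next') is None  # nothing was stored
    print(buf.get(np.arange(2), 'obs_next'))     # [1. 1.], rebuilt from obs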
@@ -166,19 +169,30 @@ class ReplayBuffer(object):
         ])
         return self[indice], indice

-    def get_stack(self, indice, key):
+    def get(self, indice, key):
         """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
         where s is self.key, t is indice. The stack_num (here equals to 4) is
         given from buffer initialization procedure.
         """
-        if self.__dict__.get(key, None) is None:
-            return None
-        if self._stack == 0:
-            return self.__dict__[key][indice]
-        stack = []
+        if not isinstance(indice, np.ndarray):
+            if np.isscalar(indice):
+                indice = np.array(indice)
+            elif isinstance(indice, slice):
+                indice = np.arange(
+                    0 if indice.start is None else indice.start,
+                    self._size if indice.stop is None else indice.stop,
+                    1 if indice.step is None else indice.step)
         # set last frame done to True
         last_index = (self._index - 1 + self._size) % self._size
         last_done, self.done[last_index] = self.done[last_index], True
+        if key == 'obs_next' and not self._save_s_:
+            indice += 1 - self.done[indice].astype(np.int)
+            indice[indice == self._size] = 0
+            key = 'obs'
+        if self._stack == 0:
+            self.done[last_index] = last_done
+            return self.__dict__[key][indice]
+        stack = []
         for i in range(self._stack):
             stack = [self.__dict__[key][indice]] + stack
             pre_indice = indice - 1
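The interesting branch is `key == 'obs_next'` with storage disabled: the indices are shifted one step forward except where `done` is set (a terminal frame is its own next frame), and temporarily forcing `done[last_index] = True` keeps the newest frame from reading past the write head. A worked trace of the index arithmetic in plain numpy, mirroring the logic rather than calling the class:

    import numpy as np

    size = 5
    done = np.array([0., 0., 1., 0., 1.])  # episodes end at t=2 and t=4
    indice = np.arange(size)

    shifted = indice + 1 - done[indice].astype(int)
    shifted[shifted == size] = 0  # wrap around the circular buffer
    print(shifted)  # [1 2 2 4 4]: t=2 and t=4 point at themselves

The wrap matters once the buffer is full and the write head has cycled, so position `_size - 1` is logically followed by position 0.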
@@ -193,11 +207,11 @@ class ReplayBuffer(object):
         return the stacked obs and obs_next with shape [batch, len, ...].
         """
         return Batch(
-            obs=self.get_stack(index, 'obs'),
+            obs=self.get(index, 'obs'),
             act=self.act[index],
             rew=self.rew[index],
             done=self.done[index],
-            obs_next=self.get_stack(index, 'obs_next'),
+            obs_next=self.get(index, 'obs_next'),
             info=self.info[index]
         )
@@ -213,8 +227,8 @@ class ListReplayBuffer(ReplayBuffer):
         detailed explanation.
     """

-    def __init__(self):
-        super().__init__(size=0)
+    def __init__(self, **kwargs):
+        super().__init__(size=0, ignore_obs_next=False, **kwargs)

     def _add_to_buffer(self, name, inst):
         if inst is None:
@@ -233,8 +247,8 @@ class ListReplayBuffer(ReplayBuffer):
 class PrioritizedReplayBuffer(ReplayBuffer):
     """docstring for PrioritizedReplayBuffer"""

-    def __init__(self, size):
-        super().__init__(size)
+    def __init__(self, size, **kwargs):
+        super().__init__(size, **kwargs)

     def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
         raise NotImplementedError

View File

@@ -22,8 +22,6 @@ class Collector(object):
         :class:`~tianshou.data.ReplayBuffer`.
     :param int stat_size: for the moving average of recording speed, defaults
         to 100.
-    :param bool store_obs_next: store the next observation to replay buffer or
-        not, defaults to ``True``.

     Example:
     ::
@@ -70,8 +68,7 @@ class Collector(object):
     Please make sure the given environment has a time limitation.
     """

-    def __init__(self, policy, env, buffer=None, stat_size=100,
-                 store_obs_next=True, **kwargs):
+    def __init__(self, policy, env, buffer=None, stat_size=100, **kwargs):
         super().__init__()
         self.env = env
         self.env_num = 1
@@ -106,7 +103,6 @@ class Collector(object):
         self.state = None
         self.step_speed = MovAvg(stat_size)
         self.episode_speed = MovAvg(stat_size)
-        self._save_s_ = store_obs_next

     def reset_buffer(self):
         """Reset the main data buffer."""
@@ -247,8 +243,7 @@ class Collector(object):
                 data = {
                     'obs': self._obs[i], 'act': self._act[i],
                     'rew': self._rew[i], 'done': self._done[i],
-                    'obs_next': obs_next[i] if self._save_s_ else None,
-                    'info': self._info[i]}
+                    'obs_next': obs_next[i], 'info': self._info[i]}
                 if self._cached_buf:
                     warning_count += 1
                     self._cached_buf[i].add(**data)
@@ -284,8 +279,7 @@ class Collector(object):
             else:
                 self.buffer.add(
                     self._obs, self._act[0], self._rew,
-                    self._done, obs_next if self._save_s_ else None,
-                    self._info)
+                    self._done, obs_next, self._info)
             cur_step += 1
             if self._done:
                 cur_episode += 1
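With the flag gone from `Collector`, the collector always hands `obs_next` to `buffer.add` and the buffer alone decides whether to keep it, so the storage policy lives in exactly one place. A hedged sketch of the replacement pattern (the removed `store_obs_next` keyword is documented in the deleted docstring lines above):

    from tianshou.data import Collector, ReplayBuffer

    def make_collector(policy, env):
        # before: Collector(policy, env, store_obs_next=False)
        # after: the buffer, not the collector, owns the decision
        return Collector(policy, env,
                         ReplayBuffer(size=20000, ignore_obs_next=True))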

View File

@@ -81,16 +81,17 @@ class DQNPolicy(BasePolicy):
         returns[buffer.done[now] > 0] = 0
         returns = buffer.rew[now] + self._gamma * returns
         terminal = (indice + self._n_step - 1) % len(buffer)
+        terminal_data = buffer[terminal]
         if self._target:
             # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
-            a = self(buffer[terminal], input='obs_next', eps=0).act
+            a = self(terminal_data, input='obs_next', eps=0).act
             target_q = self(
-                buffer[terminal], model='model_old', input='obs_next').logits
+                terminal_data, model='model_old', input='obs_next').logits
             if isinstance(target_q, torch.Tensor):
                 target_q = target_q.detach().cpu().numpy()
             target_q = target_q[np.arange(len(a)), a]
         else:
-            target_q = self(buffer[terminal], input='obs_next').logits
+            target_q = self(terminal_data, input='obs_next').logits
             if isinstance(target_q, torch.Tensor):
                 target_q = target_q.detach().cpu().numpy()
             target_q = target_q.max(axis=1)
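Hoisting `buffer[terminal]` into `terminal_data` is more than style: after this commit `ReplayBuffer.__getitem__` routes `obs` and `obs_next` through `get`, which may shift indices and stack frames, so every repeated `buffer[terminal]` redid that work. A toy stand-in (not tianshou code) showing the effect of fetching once:

    import numpy as np

    class CountingBuffer:
        """Toy stand-in that counts batch rebuilds in __getitem__."""
        def __init__(self, obs):
            self.obs, self.calls = obs, 0
        def __getitem__(self, index):
            self.calls += 1  # a real buffer stacks frames here
            return self.obs[index]

    buf = CountingBuffer(np.arange(10))
    terminal = np.array([2, 5, 7])
    terminal_data = buf[terminal]       # fetch once ...
    _ = terminal_data.sum() + terminal_data.max()  # ... reuse freely
    assert buf.calls == 1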

View File

@@ -40,8 +40,6 @@ class PGPolicy(BasePolicy):
         discount factor, :math:`\gamma \in [0, 1]`.
         """
         batch.returns = self._vanilla_returns(batch)
-        if getattr(batch, 'obs_next', None) is None:
-            batch.obs_next = buffer[(indice + 1) % len(buffer)].obs
         # batch.returns = self._vectorized_returns(batch)
         return batch