add ignore_obs_next in buffer

Trinkle23897 2020-04-10 09:01:17 +08:00
parent 19f2cce294
commit 13086b7f64
7 changed files with 52 additions and 45 deletions
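In short: ReplayBuffer gains an ignore_obs_next flag so that obs_next is no longer stored but reconstructed from obs on access, ReplayBuffer.get_stack is renamed to get, and the Collector's store_obs_next option is removed in favour of the buffer flag. A minimal usage sketch, mirroring the docstring example updated below (the only assumption is the tianshou.data.ReplayBuffer API shown in this diff):
::
    import numpy as np
    from tianshou.data import ReplayBuffer

    # with ignore_obs_next=True the buffer never allocates an obs_next array;
    # buf[index].obs_next is reconstructed from obs when the buffer is indexed
    buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
    for i in range(16):
        done = i % 5 == 0
        buf.add(obs=i, act=i, rew=i, done=done, obs_next=i, info={})
    index = np.arange(len(buf))
    print(buf.get(index, 'obs'))  # frame-stacked observations, one row per index
    print(abs(buf.get(index, 'obs') - buf[index].obs).sum().sum())  # 0.0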

View File

@@ -40,7 +40,7 @@ def test_stack(size=5, bufsize=9, stack_num=4):
if done:
obs = env.reset(1)
indice = np.arange(len(buf))
assert abs(buf.get_stack(indice, 'obs') - np.array([
assert abs(buf.get(indice, 'obs') - np.array([
[1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
[3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]])).sum() < 1e-6

View File

@@ -1,7 +1,7 @@
import numpy as np
from tianshou.policy import BasePolicy
from tianshou.env import SubprocVectorEnv
from tianshou.data import Collector, Batch
from tianshou.data import Collector, Batch, ReplayBuffer
if __name__ == '__main__':
from env import MyTestEnv
@@ -36,21 +36,21 @@ def test_collector():
venv = SubprocVectorEnv(env_fns)
policy = MyPolicy()
env = env_fns[0]()
c0 = Collector(policy, env)
c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False))
c0.collect(n_step=3)
assert equal(c0.buffer.obs[:3], [0, 1, 0])
assert equal(c0.buffer.obs_next[:3], [1, 2, 1])
assert equal(c0.buffer[:3].obs_next, [1, 2, 1])
c0.collect(n_episode=3)
assert equal(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
assert equal(c0.buffer.obs_next[:8], [1, 2, 1, 2, 1, 2, 1, 2])
c1 = Collector(policy, venv)
assert equal(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
c1.collect(n_step=6)
assert equal(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3])
assert equal(c1.buffer.obs_next[:11], [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
assert equal(c1.buffer[:11].obs_next, [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
c1.collect(n_episode=2)
assert equal(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
assert equal(c1.buffer.obs_next[11:21], [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
c2 = Collector(policy, venv)
assert equal(c1.buffer[11:21].obs_next, [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
c2 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
c2.collect(n_episode=[1, 2, 2, 2])
assert equal(c2.buffer.obs_next[:26], [
1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,

View File

@@ -72,7 +72,7 @@ def test_drqn(args=get_args()):
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(
args.buffer_size, stack_num=args.stack_num))
args.buffer_size, stack_num=args.stack_num, ignore_obs_next=True))
# the stack_num is for RNN training: sample framestack obs
test_collector = Collector(policy, test_envs)
# policy.set_eps(1)
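As the comment above says, stack_num makes the buffer return frame-stacked observations, which is what the RNN consumes as a short sequence. A minimal sketch of that effect with scalar observations, assuming only the ReplayBuffer API shown in this diff:
::
    import numpy as np
    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=100, stack_num=4, ignore_obs_next=True)
    for i in range(20):
        buf.add(obs=i, act=0, rew=0.0, done=(i % 5 == 4), obs_next=i + 1, info={})
    batch, indice = buf.sample(8)
    # each sampled obs is a window [s_{t-3}, ..., s_t]; with scalar observations
    # the expected shape is (8, 4), i.e. [batch, stack_num]
    print(batch.obs.shape)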

View File

@@ -41,14 +41,15 @@ class ReplayBuffer(object):
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
From version v0.2.2, :class:`~tianshou.data.ReplayBuffer` supports
frame_stack sampling, typically for RNN usage:
Since version v0.2.2, :class:`~tianshou.data.ReplayBuffer` supports
frame_stack sampling (typically for RNN usage) and not storing the next
observation (to save memory):
::
>>> buf = ReplayBuffer(size=9, stack_num=4)
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
>>> for i in range(16):
... done = i % 5 == 0
... buf.add(obs=i, act=i, rew=i, done=done, obs_next=0, info={})
... buf.add(obs=i, act=i, rew=i, done=done, obs_next=i, info={})
>>> print(buf)
ReplayBuffer(
obs: [ 9. 10. 11. 12. 13. 14. 15. 7. 8.],
@@ -59,7 +60,7 @@ class ReplayBuffer(object):
info: [{} {} {} {} {} {} {} {} {}],
)
>>> index = np.arange(len(buf))
>>> print(buf.get_stack(index, 'obs'))
>>> print(buf.get(index, 'obs'))
[[ 7. 7. 8. 9.]
[ 7. 8. 9. 10.]
[11. 11. 11. 11.]
@@ -71,14 +72,15 @@ class ReplayBuffer(object):
[ 7. 7. 7. 8.]]
>>> # here is another way to get the stacked data
>>> # (stack only for obs and obs_next)
>>> abs(buf.get_stack(index, 'obs') - buf[index].obs).sum().sum()
>>> abs(buf.get(index, 'obs') - buf[index].obs).sum().sum()
0.0
"""
def __init__(self, size, stack_num=0):
def __init__(self, size, stack_num=0, ignore_obs_next=False, **kwargs):
super().__init__()
self._maxsize = size
self._stack = stack_num
self._save_s_ = not ignore_obs_next
self.reset()
def __len__(self):
@@ -125,7 +127,7 @@ class ReplayBuffer(object):
while True:
self.add(
buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i],
None if buffer.obs_next is None else buffer.obs_next[i],
buffer.obs_next[i] if self._save_s_ else None,
buffer.info[i])
i = (i + 1) % len(buffer)
if i == begin:
@@ -139,7 +141,8 @@ class ReplayBuffer(object):
self._add_to_buffer('act', act)
self._add_to_buffer('rew', rew)
self._add_to_buffer('done', done)
self._add_to_buffer('obs_next', obs_next)
if self._save_s_:
self._add_to_buffer('obs_next', obs_next)
self._add_to_buffer('info', info)
if self._maxsize > 0:
self._size = min(self._size + 1, self._maxsize)
@@ -166,19 +169,30 @@ class ReplayBuffer(object):
])
return self[indice], indice
def get_stack(self, indice, key):
def get(self, indice, key):
"""Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
where s is self.key and t is indice. The stack_num (4 in this example) is
set at buffer initialization.
"""
if self.__dict__.get(key, None) is None:
return None
if self._stack == 0:
return self.__dict__[key][indice]
stack = []
if not isinstance(indice, np.ndarray):
if np.isscalar(indice):
indice = np.array(indice)
elif isinstance(indice, slice):
indice = np.arange(
0 if indice.start is None else indice.start,
self._size if indice.stop is None else indice.stop,
1 if indice.step is None else indice.step)
# set last frame done to True
last_index = (self._index - 1 + self._size) % self._size
last_done, self.done[last_index] = self.done[last_index], True
if key == 'obs_next' and not self._save_s_:
indice += 1 - self.done[indice].astype(np.int)
indice[indice == self._size] = 0
key = 'obs'
if self._stack == 0:
self.done[last_index] = last_done
return self.__dict__[key][indice]
stack = []
for i in range(self._stack):
stack = [self.__dict__[key][indice]] + stack
pre_indice = indice - 1
@@ -193,11 +207,11 @@ class ReplayBuffer(object):
return the stacked obs and obs_next with shape [batch, len, ...].
"""
return Batch(
obs=self.get_stack(index, 'obs'),
obs=self.get(index, 'obs'),
act=self.act[index],
rew=self.rew[index],
done=self.done[index],
obs_next=self.get_stack(index, 'obs_next'),
obs_next=self.get(index, 'obs_next'),
info=self.info[index]
)
@@ -213,8 +227,8 @@ class ListReplayBuffer(ReplayBuffer):
detailed explanation.
"""
def __init__(self):
super().__init__(size=0)
def __init__(self, **kwargs):
super().__init__(size=0, ignore_obs_next=False, **kwargs)
def _add_to_buffer(self, name, inst):
if inst is None:
@@ -233,8 +247,8 @@ class ListReplayBuffer(ReplayBuffer):
class PrioritizedReplayBuffer(ReplayBuffer):
"""docstring for PrioritizedReplayBuffer"""
def __init__(self, size):
super().__init__(size)
def __init__(self, size, **kwargs):
super().__init__(size, **kwargs)
def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
raise NotImplementedError
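The heart of the change is in get(): when obs_next is not stored, the requested indices are shifted forward by one step, except on terminal steps (and the newest frame is temporarily marked done so the shift never leaks into another episode), and the lookup falls back to the obs array. A stand-alone NumPy illustration of that index arithmetic, not tianshou code:
::
    import numpy as np

    obs = np.array([10, 11, 12, 13, 14])
    done = np.array([False, False, True, False, True])
    indice = np.arange(len(obs))

    # same trick as `indice += 1 - self.done[indice]` above: terminal steps keep
    # their own index, all others look one step ahead, and indices wrap around
    shifted = indice + 1 - done.astype(int)
    shifted[shifted == len(obs)] = 0
    obs_next = obs[shifted]
    print(obs_next)  # [11 12 12 14 14]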

View File

@@ -22,8 +22,6 @@ class Collector(object):
:class:`~tianshou.data.ReplayBuffer`.
:param int stat_size: for the moving average of recording speed, defaults
to 100.
:param bool store_obs_next: store the next observation to replay buffer or
not, defaults to ``True``.
Example:
::
@@ -70,8 +68,7 @@ class Collector(object):
Please make sure the given environment has a time limitation.
"""
def __init__(self, policy, env, buffer=None, stat_size=100,
store_obs_next=True, **kwargs):
def __init__(self, policy, env, buffer=None, stat_size=100, **kwargs):
super().__init__()
self.env = env
self.env_num = 1
@@ -106,7 +103,6 @@ class Collector(object):
self.state = None
self.step_speed = MovAvg(stat_size)
self.episode_speed = MovAvg(stat_size)
self._save_s_ = store_obs_next
def reset_buffer(self):
"""Reset the main data buffer."""
@@ -247,8 +243,7 @@ class Collector(object):
data = {
'obs': self._obs[i], 'act': self._act[i],
'rew': self._rew[i], 'done': self._done[i],
'obs_next': obs_next[i] if self._save_s_ else None,
'info': self._info[i]}
'obs_next': obs_next[i], 'info': self._info[i]}
if self._cached_buf:
warning_count += 1
self._cached_buf[i].add(**data)
@@ -284,8 +279,7 @@ class Collector(object):
else:
self.buffer.add(
self._obs, self._act[0], self._rew,
self._done, obs_next if self._save_s_ else None,
self._info)
self._done, obs_next, self._info)
cur_step += 1
if self._done:
cur_episode += 1
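Since store_obs_next is removed from the Collector, whether next observations are kept is now decided solely by the buffer that is passed in. A before/after sketch (policy and env are placeholders standing for any tianshou policy and gym environment):
::
    from tianshou.data import Collector, ReplayBuffer

    # before this commit the flag lived on the collector:
    #     Collector(policy, env, buffer, store_obs_next=False)
    # after this commit the buffer owns the decision:
    buf = ReplayBuffer(size=20000, ignore_obs_next=True)
    # collector = Collector(policy, env, buf)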

View File

@@ -81,16 +81,17 @@ class DQNPolicy(BasePolicy):
returns[buffer.done[now] > 0] = 0
returns = buffer.rew[now] + self._gamma * returns
terminal = (indice + self._n_step - 1) % len(buffer)
terminal_data = buffer[terminal]
if self._target:
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
a = self(buffer[terminal], input='obs_next', eps=0).act
a = self(terminal_data, input='obs_next', eps=0).act
target_q = self(
buffer[terminal], model='model_old', input='obs_next').logits
terminal_data, model='model_old', input='obs_next').logits
if isinstance(target_q, torch.Tensor):
target_q = target_q.detach().cpu().numpy()
target_q = target_q[np.arange(len(a)), a]
else:
target_q = self(buffer[terminal], input='obs_next').logits
target_q = self(terminal_data, input='obs_next').logits
if isinstance(target_q, torch.Tensor):
target_q = target_q.detach().cpu().numpy()
target_q = target_q.max(axis=1)
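The comment above, target_Q = Q_old(s_, argmax(Q_new(s_, *))), is the double-DQN target: the online network picks the action and the target network evaluates it; the refactor merely caches buffer[terminal] in terminal_data instead of indexing the buffer three times. A tiny NumPy illustration of the gather step (toy Q-values, not tianshou code):
::
    import numpy as np

    rng = np.random.default_rng(0)
    q_new = rng.random((5, 3))   # online network: Q_new(s_, *)
    q_old = rng.random((5, 3))   # target network: Q_old(s_, *)

    a = q_new.argmax(axis=1)                 # argmax over the online estimate
    target_q = q_old[np.arange(len(a)), a]   # evaluate that action with Q_old
    print(target_q.shape)  # (5,)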

View File

@@ -40,8 +40,6 @@ class PGPolicy(BasePolicy):
discount factor, :math:`\gamma \in [0, 1]`.
"""
batch.returns = self._vanilla_returns(batch)
if getattr(batch, 'obs_next', None) is None:
batch.obs_next = buffer[(indice + 1) % len(buffer)].obs
# batch.returns = self._vectorized_returns(batch)
return batch
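For reference, _vanilla_returns refers to the usual Monte Carlo discounted return G_t = r_t + gamma * G_{t+1}, restarted at episode boundaries, with gamma the discount factor from the docstring above. A generic stand-alone sketch of such a backward pass (an illustration, not the exact tianshou implementation):
::
    import numpy as np

    def discounted_returns(rew, done, gamma=0.99):
        """G_t = r_t + gamma * G_{t+1}, restarting whenever an episode ends."""
        returns = np.zeros_like(rew, dtype=float)
        running = 0.0
        for t in reversed(range(len(rew))):
            running = rew[t] + gamma * running * (1.0 - done[t])
            returns[t] = running
        return returns

    rew = np.array([1.0, 1.0, 1.0, 1.0])
    done = np.array([0.0, 0.0, 0.0, 1.0])
    print(discounted_returns(rew, done, gamma=0.9))  # [3.439 2.71 1.9 1.]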