fix rnn (#19), add __repr__, and fix #26

This commit is contained in:
Trinkle23897 2020-04-09 19:53:45 +08:00
parent 86572c66d4
commit 6da80e045a
10 changed files with 120 additions and 24 deletions

View File

@ -11,7 +11,6 @@ jobs:
strategy:
matrix:
python-version: [3.6, 3.7, 3.8]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}

View File

@ -8,9 +8,9 @@ class MyTestEnv(gym.Env):
self.sleep = sleep
self.reset()
def reset(self, state=0):
    """Begin a new episode at ``state`` and return that position.

    :param int state: the index the environment is placed at, defaults to 0.
    :return: the stored index after the reset.
    """
    self.index = state
    self.done = False
    return self.index
def step(self, action):

View File

@ -17,6 +17,7 @@ def test_batch():
batch.obs = np.arange(5)
for i, b in enumerate(batch.split(1, permute=False)):
assert b.obs == batch[i].obs
print(batch)
if __name__ == '__main__':

View File

@ -1,3 +1,4 @@
import numpy as np
from tianshou.data import ReplayBuffer
if __name__ == '__main__':
@ -28,5 +29,24 @@ def test_replaybuffer(size=10, bufsize=20):
assert buf2[-1].obs == buf[4].obs
def test_stack(size=5, bufsize=9, stack_num=4):
    """Check that ``ReplayBuffer.get_stack`` stacks frames per episode.

    Fifteen transitions are pushed into a buffer that holds ``bufsize``
    entries, restarting the environment at state 1 whenever an episode ends.
    The stacked ``obs`` for every stored slot is then compared against the
    hand-computed table below.
    """
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize, stack_num)
    obs = env.reset(1)
    for _ in range(15):
        obs_next, rew, done, info = env.step(1)
        buf.add(obs, 1, rew, done, None, info)
        # Only restart the environment when the episode has just finished.
        obs = env.reset(1) if done else obs_next
    expected = np.array([
        [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
        [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]])
    stacked = buf.get_stack(np.arange(len(buf)), 'obs')
    assert np.abs(stacked - expected).sum() < 1e-6
    print(buf)
if __name__ == '__main__':
test_replaybuffer()
test_stack()

View File

@ -26,7 +26,7 @@ def get_args():
parser.add_argument('--stack-num', type=int, default=4)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--n-step', type=int, default=4)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=1000)

View File

@ -31,27 +31,29 @@ def compute_return_base(batch, aa=None, bb=None, gamma=0.1):
def test_fn(size=2560):
policy = PGPolicy(None, None, None, discount_factor=0.1)
buf = ReplayBuffer(100)
buf.add(1, 1, 1, 1, 1)
fn = policy.process_fn
# fn = compute_return_base
batch = Batch(
done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
)
batch = fn(batch, None, None)
batch = fn(batch, buf, 0)
ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
assert abs(batch.returns - ans).sum() <= 1e-5
batch = Batch(
done=np.array([0, 1, 0, 1, 0, 1, 0.]),
rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
)
batch = fn(batch, None, None)
batch = fn(batch, buf, 0)
ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
assert abs(batch.returns - ans).sum() <= 1e-5
batch = Batch(
done=np.array([0, 1, 0, 1, 0, 0, 1.]),
rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
)
batch = fn(batch, None, None)
batch = fn(batch, buf, 0)
ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
assert abs(batch.returns - ans).sum() <= 1e-5
if __name__ == '__main__':
@ -66,7 +68,7 @@ def test_fn(size=2560):
print(f'vanilla: {(time.time() - t) / cnt}')
t = time.time()
for _ in range(cnt):
policy.process_fn(batch, None, None)
policy.process_fn(batch, buf, 0)
print(f'policy: {(time.time() - t) / cnt}')
@ -147,5 +149,5 @@ def test_pg(args=get_args()):
if __name__ == '__main__':
# test_fn()
test_fn()
test_pg()

View File

@ -73,6 +73,22 @@ class Batch(object):
b.__dict__.update(**{k: self.__dict__[k][index]})
return b
def __repr__(self):
    """Return a multi-line summary of the public, non-``None`` attributes.

    Every attribute whose name does not start with ``_`` and whose value is
    not ``None`` is rendered as one ``name: value,`` entry between
    ``ClassName(`` and ``)``. If nothing qualifies the result is
    ``ClassName()``. The returned string always ends with a newline.

    :return: the formatted string.
    """
    cls_name = self.__class__.__name__
    entries = []
    for key, value in self.__dict__.items():
        if key.startswith('_') or value is None:
            continue
        prefix = f'    {key}: '
        # Pad continuation lines to the width of the prefix so that
        # multi-line values (e.g. printed arrays) stay aligned under
        # their first line.
        padding = '\n' + ' ' * len(prefix)
        entries.append(prefix + str(value).replace('\n', padding) + ',\n')
    if not entries:
        return cls_name + '()\n'
    return cls_name + '(\n' + ''.join(entries) + ')\n'
def append(self, batch):
"""Append a :class:`~tianshou.data.Batch` object to current batch."""
assert isinstance(batch, Batch), 'Only append Batch is allowed!'

View File

@ -39,6 +39,34 @@ class ReplayBuffer(object):
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
Since version v0.2.2, :class:`~tianshou.data.ReplayBuffer` supports
frame-stack sampling, which is typically used with RNNs:
::
>>> buf = ReplayBuffer(size=9, stack_num=4)
>>> for i in range(16):
... done = i % 5 == 0
... buf.add(obs=i, act=i, rew=i, done=done, obs_next=0, info={})
>>> print(buf.obs)
[ 9. 10. 11. 12. 13. 14. 15. 7. 8.]
>>> print(buf.done)
[0. 1. 0. 0. 0. 0. 1. 0. 0.]
>>> index = np.arange(len(buf))
>>> print(buf.get_stack(index, 'obs'))
[[ 7. 7. 8. 9.]
[ 7. 8. 9. 10.]
[11. 11. 11. 11.]
[11. 11. 11. 12.]
[11. 11. 12. 13.]
[11. 12. 13. 14.]
[12. 13. 14. 15.]
[ 7. 7. 7. 7.]
[ 7. 7. 7. 8.]]
>>> # here is another way to get the stacked data
>>> # (stack only for obs and obs_next)
>>> sum(sum(buf.get_stack(index, 'obs') - buf[index].obs))
0.0
"""
def __init__(self, size, stack_num=0):
@ -51,8 +79,26 @@ class ReplayBuffer(object):
"""Return len(self)."""
return self._size
def __repr__(self):
    """Return a multi-line summary of the public, non-``None`` attributes.

    Every attribute whose name does not start with ``_`` and whose value is
    not ``None`` is rendered as one ``name: value,`` entry between
    ``ClassName(`` and ``)``. If nothing qualifies the result is
    ``ClassName()``. The returned string always ends with a newline.

    :return: the formatted string.
    """
    cls_name = self.__class__.__name__
    entries = []
    for key, value in self.__dict__.items():
        if key.startswith('_') or value is None:
            continue
        prefix = f'    {key}: '
        # Pad continuation lines to the width of the prefix so that
        # multi-line values (e.g. printed arrays) stay aligned under
        # their first line.
        padding = '\n' + ' ' * len(prefix)
        entries.append(prefix + str(value).replace('\n', padding) + ',\n')
    if not entries:
        return cls_name + '()\n'
    return cls_name + '(\n' + ''.join(entries) + ')\n'
def _add_to_buffer(self, name, inst):
if inst is None:
if getattr(self, name, None) is None:
self.__dict__[name] = None
return
if self.__dict__.get(name, None) is None:
if isinstance(inst, np.ndarray):
@ -72,13 +118,14 @@ class ReplayBuffer(object):
i = begin = buffer._index % len(buffer)
while True:
self.add(
buffer.obs[i], buffer.act[i], buffer.rew[i],
buffer.done[i], buffer.obs_next[i], buffer.info[i])
buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i],
None if buffer.obs_next is None else buffer.obs_next[i],
buffer.info[i])
i = (i + 1) % len(buffer)
if i == begin:
break
def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
def add(self, obs, act, rew, done, obs_next=None, info={}, weight=None):
"""Add a batch of data into replay buffer."""
assert isinstance(info, dict), \
'You should return a dict in the last argument of env.step().'
@ -97,7 +144,6 @@ class ReplayBuffer(object):
def reset(self):
"""Clear all the data in replay buffer.

Only the bookkeeping is rewound: ``_index`` and ``_size`` are both set
back to 0. Nothing in this method touches the stored arrays themselves.
"""
self._index = self._size = 0
# NOTE(review): the hunk header (-97,7 +144,6) suggests this is the line
# the commit removes -- confirm against the post-commit file.
self.indice = []
def sample(self, batch_size):
"""Get a random sample from buffer with size equal to batch_size. \
@ -114,16 +160,26 @@ class ReplayBuffer(object):
])
return self[indice], indice
def get_stack(self, indice, key):
    """Return the frames of ``self.<key>`` stacked along a new axis 1.

    For every position ``t`` in ``indice`` the result row is
    ``[s_{t-3}, s_{t-2}, s_{t-1}, s_t]`` (for a stack size of 4), where
    ``s`` is the array stored under ``key``. The stack size is the value
    given when the buffer was created. Walking backwards stops at an
    episode boundary: once the previous slot is flagged ``done``, the
    earliest frame of the current episode is repeated.

    :param indice: integer numpy array of buffer positions; expected to lie
        in ``[0, len(self))``.
    :param str key: name of the stored attribute, e.g. ``'obs'``.
    :return: ``None`` if nothing is stored under ``key``; the plain
        ``self.<key>[indice]`` if stacking is disabled (stack size 0);
        otherwise the stacked array produced by ``np.stack(..., axis=1)``.
    """
    data = self.__dict__.get(key, None)
    if data is None:
        return None
    if self._stack == 0:
        return data[indice]
    # Temporarily flag the most recently written slot as terminal so that
    # stepping backwards from the slot after it (the oldest entry once the
    # buffer has wrapped) never crosses the write cursor.
    last_index = (self._index - 1 + self._size) % self._size
    last_done = self.done[last_index]
    self.done[last_index] = True
    try:
        stack = []
        for _ in range(self._stack):
            stack.insert(0, data[indice])
            # Step one slot back with wrap-around; if that slot ended an
            # episode, step forward again so the frame is repeated.
            pre_indice = (indice - 1) % self._size
            # The builtin ``int`` replaces the removed ``np.int`` alias.
            indice = (pre_indice
                      + self.done[pre_indice].astype(int)) % self._size
    finally:
        # Always restore the real flag, even if gathering failed.
        self.done[last_index] = last_done
    return np.stack(stack, axis=1)
def __getitem__(self, index):
@ -131,11 +187,11 @@ class ReplayBuffer(object):
return the stacked obs and obs_next with shape [batch, len, ...].
"""
return Batch(
obs=self._get_stack(index, 'obs'),
obs=self.get_stack(index, 'obs'),
act=self.act[index],
rew=self.rew[index],
done=self.done[index],
obs_next=self._get_stack(index, 'obs_next'),
obs_next=self.get_stack(index, 'obs_next'),
info=self.info[index]
)

View File

@ -2,10 +2,10 @@ import time
import torch
import warnings
import numpy as np
from tianshou.env import BaseVectorEnv
from tianshou.data import Batch, ReplayBuffer, \
ListReplayBuffer
from tianshou.utils import MovAvg
from tianshou.env import BaseVectorEnv
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer
class Collector(object):
@ -22,8 +22,8 @@ class Collector(object):
:class:`~tianshou.data.ReplayBuffer`.
:param int stat_size: for the moving average of recording speed, defaults
to 100.
:param bool store_obs_next: whether to store the obs_next to replay
buffer, defaults to ``True``.
:param bool store_obs_next: whether to store the next observation
(``obs_next``) in the replay buffer, defaults to ``True``.
Example:
::
@ -302,7 +302,7 @@ class Collector(object):
self._obs = obs_next
if self._multi_env:
cur_episode = sum(cur_episode)
duration = time.time() - start_time
duration = max(time.time() - start_time, 1e-9)
self.step_speed.add(cur_step / duration)
self.episode_speed.add(cur_episode / duration)
self.collect_step += cur_step

View File

@ -35,6 +35,8 @@ class PGPolicy(BasePolicy):
discount factor, :math:`\gamma \in [0, 1]`.
"""
batch.returns = self._vanilla_returns(batch)
if getattr(batch, 'obs_next', None) is None:
batch.obs_next = buffer[(indice + 1) % len(buffer)].obs
# batch.returns = self._vectorized_returns(batch)
return batch