parent
86572c66d4
commit
6da80e045a
1
.github/workflows/pytest.yml
vendored
1
.github/workflows/pytest.yml
vendored
@ -11,7 +11,6 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
|
@ -8,9 +8,9 @@ class MyTestEnv(gym.Env):
|
||||
self.sleep = sleep
|
||||
self.reset()
|
||||
|
||||
def reset(self, state=0):
    """Restart the episode and return the starting index.

    :param state: value the internal ``index`` is set to, defaults to 0.
    :return: the new value of ``self.index`` (i.e. ``state``).
    """
    # Clear the termination flag and jump straight to the requested state.
    self.index, self.done = state, False
    return state
|
||||
|
||||
def step(self, action):
|
||||
|
@ -17,6 +17,7 @@ def test_batch():
|
||||
batch.obs = np.arange(5)
|
||||
for i, b in enumerate(batch.split(1, permute=False)):
|
||||
assert b.obs == batch[i].obs
|
||||
print(batch)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -1,3 +1,4 @@
|
||||
import numpy as np
|
||||
from tianshou.data import ReplayBuffer
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -28,5 +29,24 @@ def test_replaybuffer(size=10, bufsize=20):
|
||||
assert buf2[-1].obs == buf[4].obs
|
||||
|
||||
|
||||
def test_stack(size=5, bufsize=9, stack_num=4):
    """Check the frame-stacking output of ``ReplayBuffer.get_stack``.

    Steps a ``MyTestEnv(size)`` with action 1 for 15 transitions, restarting
    it from state 1 each time an episode finishes, stores every transition
    in a ``ReplayBuffer(bufsize, stack_num)`` without ``obs_next``, and then
    compares the stacked observations of all stored slots with a hard-coded
    reference matrix.
    """
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize, stack_num)
    obs = env.reset(1)
    for _ in range(15):
        obs_next, rew, done, info = env.step(1)
        buf.add(obs, 1, rew, done, None, info)
        # Start a fresh episode from state 1 once the current one is over.
        obs = env.reset(1) if done else obs_next
    expected = np.array([
        [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
        [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]])
    stacked = buf.get_stack(np.arange(len(buf)), 'obs')
    assert np.abs(stacked - expected).sum() < 1e-6
    print(buf)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_replaybuffer()
|
||||
test_stack()
|
||||
|
@ -26,7 +26,7 @@ def get_args():
|
||||
parser.add_argument('--stack-num', type=int, default=4)
|
||||
parser.add_argument('--lr', type=float, default=1e-3)
|
||||
parser.add_argument('--gamma', type=float, default=0.9)
|
||||
parser.add_argument('--n-step', type=int, default=3)
|
||||
parser.add_argument('--n-step', type=int, default=4)
|
||||
parser.add_argument('--target-update-freq', type=int, default=320)
|
||||
parser.add_argument('--epoch', type=int, default=100)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=1000)
|
||||
|
@ -31,27 +31,29 @@ def compute_return_base(batch, aa=None, bb=None, gamma=0.1):
|
||||
|
||||
def test_fn(size=2560):
|
||||
policy = PGPolicy(None, None, None, discount_factor=0.1)
|
||||
buf = ReplayBuffer(100)
|
||||
buf.add(1, 1, 1, 1, 1)
|
||||
fn = policy.process_fn
|
||||
# fn = compute_return_base
|
||||
batch = Batch(
|
||||
done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
|
||||
rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
|
||||
)
|
||||
batch = fn(batch, None, None)
|
||||
batch = fn(batch, buf, 0)
|
||||
ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
|
||||
assert abs(batch.returns - ans).sum() <= 1e-5
|
||||
batch = Batch(
|
||||
done=np.array([0, 1, 0, 1, 0, 1, 0.]),
|
||||
rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
|
||||
)
|
||||
batch = fn(batch, None, None)
|
||||
batch = fn(batch, buf, 0)
|
||||
ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
|
||||
assert abs(batch.returns - ans).sum() <= 1e-5
|
||||
batch = Batch(
|
||||
done=np.array([0, 1, 0, 1, 0, 0, 1.]),
|
||||
rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
|
||||
)
|
||||
batch = fn(batch, None, None)
|
||||
batch = fn(batch, buf, 0)
|
||||
ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
|
||||
assert abs(batch.returns - ans).sum() <= 1e-5
|
||||
if __name__ == '__main__':
|
||||
@ -66,7 +68,7 @@ def test_fn(size=2560):
|
||||
print(f'vanilla: {(time.time() - t) / cnt}')
|
||||
t = time.time()
|
||||
for _ in range(cnt):
|
||||
policy.process_fn(batch, None, None)
|
||||
policy.process_fn(batch, buf, 0)
|
||||
print(f'policy: {(time.time() - t) / cnt}')
|
||||
|
||||
|
||||
@ -147,5 +149,5 @@ def test_pg(args=get_args()):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test_fn()
|
||||
test_fn()
|
||||
test_pg()
|
||||
|
@ -73,6 +73,22 @@ class Batch(object):
|
||||
b.__dict__.update(**{k: self.__dict__[k][index]})
|
||||
return b
|
||||
|
||||
def __repr__(self):
    """Return repr(self).

    The result lists every attribute in ``self.__dict__`` whose name does
    not start with an underscore and whose value is not ``None``, one
    ``key: value,`` entry per line, wrapped in ``ClassName(`` ... ``)``.
    When no attribute qualifies the result is ``ClassName()``. The returned
    string always ends with a newline.
    """
    name = self.__class__.__name__
    entries = []
    for k, v in self.__dict__.items():
        # Private attributes and unset fields are not shown.
        if k.startswith('_') or v is None:
            continue
        head = f'    {k}: '
        # Pad continuation lines of a multi-line value to the width of the
        # "    key: " prefix so they line up under the value's first line.
        body = str(v).replace('\n', '\n' + ' ' * len(head))
        entries.append(f'{head}{body},\n')
    if not entries:
        return name + '()\n'
    return name + '(\n' + ''.join(entries) + ')\n'
|
||||
|
||||
def append(self, batch):
|
||||
"""Append a :class:`~tianshou.data.Batch` object to current batch."""
|
||||
assert isinstance(batch, Batch), 'Only append Batch is allowed!'
|
||||
|
@ -39,6 +39,34 @@ class ReplayBuffer(object):
|
||||
>>> batch_data, indice = buf.sample(batch_size=4)
|
||||
>>> batch_data.obs == buf[indice].obs
|
||||
array([ True, True, True, True])
|
||||
|
||||
From version v0.2.2, :class:`~tianshou.data.ReplayBuffer` supports
|
||||
frame_stack sampling, typically for RNN usage:
|
||||
::
|
||||
|
||||
>>> buf = ReplayBuffer(size=9, stack_num=4)
|
||||
>>> for i in range(16):
|
||||
... done = i % 5 == 0
|
||||
... buf.add(obs=i, act=i, rew=i, done=done, obs_next=0, info={})
|
||||
>>> print(buf.obs)
|
||||
[ 9. 10. 11. 12. 13. 14. 15. 7. 8.]
|
||||
>>> print(buf.done)
|
||||
[0. 1. 0. 0. 0. 0. 1. 0. 0.]
|
||||
>>> index = np.arange(len(buf))
|
||||
>>> print(buf.get_stack(index, 'obs'))
|
||||
[[ 7. 7. 8. 9.]
|
||||
[ 7. 8. 9. 10.]
|
||||
[11. 11. 11. 11.]
|
||||
[11. 11. 11. 12.]
|
||||
[11. 11. 12. 13.]
|
||||
[11. 12. 13. 14.]
|
||||
[12. 13. 14. 15.]
|
||||
[ 7. 7. 7. 7.]
|
||||
[ 7. 7. 7. 8.]]
|
||||
>>> # here is another way to get the stacked data
|
||||
>>> # (stack only for obs and obs_next)
|
||||
>>> sum(sum(buf.get_stack(index, 'obs') - buf[index].obs))
|
||||
0.0
|
||||
"""
|
||||
|
||||
def __init__(self, size, stack_num=0):
|
||||
@ -51,8 +79,26 @@ class ReplayBuffer(object):
|
||||
"""Return len(self)."""
|
||||
return self._size
|
||||
|
||||
def __repr__(self):
    """Return repr(self).

    Produces ``ClassName(`` followed by one ``key: value,`` line for each
    entry of ``self.__dict__`` that is public (name does not start with an
    underscore) and not ``None``, then ``)``. If there is no such entry the
    result is ``ClassName()``. The string is terminated by a newline.
    """
    cls_name = self.__class__.__name__
    rows = []
    for key, value in self.__dict__.items():
        # Internal bookkeeping (underscore names) and empty fields are
        # left out of the listing.
        if key.startswith('_') or value is None:
            continue
        prefix = f'    {key}: '
        # Indent wrapped lines of the value by the prefix width so that a
        # multi-line value stays aligned with its first line.
        indent = '\n' + ' ' * len(prefix)
        rows.append(prefix + str(value).replace('\n', indent) + ',\n')
    if rows:
        return cls_name + '(\n' + ''.join(rows) + ')\n'
    return cls_name + '()\n'
|
||||
|
||||
def _add_to_buffer(self, name, inst):
|
||||
if inst is None:
|
||||
if getattr(self, name, None) is None:
|
||||
self.__dict__[name] = None
|
||||
return
|
||||
if self.__dict__.get(name, None) is None:
|
||||
if isinstance(inst, np.ndarray):
|
||||
@ -72,13 +118,14 @@ class ReplayBuffer(object):
|
||||
i = begin = buffer._index % len(buffer)
|
||||
while True:
|
||||
self.add(
|
||||
buffer.obs[i], buffer.act[i], buffer.rew[i],
|
||||
buffer.done[i], buffer.obs_next[i], buffer.info[i])
|
||||
buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i],
|
||||
None if buffer.obs_next is None else buffer.obs_next[i],
|
||||
buffer.info[i])
|
||||
i = (i + 1) % len(buffer)
|
||||
if i == begin:
|
||||
break
|
||||
|
||||
def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
|
||||
def add(self, obs, act, rew, done, obs_next=None, info={}, weight=None):
|
||||
"""Add a batch of data into replay buffer."""
|
||||
assert isinstance(info, dict), \
|
||||
'You should return a dict in the last argument of env.step().'
|
||||
@ -97,7 +144,6 @@ class ReplayBuffer(object):
|
||||
def reset(self):
    """Clear all the data in replay buffer."""
    # Only the bookkeeping is rewound: both the element count and the
    # write cursor go back to zero.
    self._size = 0
    self._index = 0
|
||||
|
||||
def sample(self, batch_size):
|
||||
"""Get a random sample from buffer with size equal to batch_size. \
|
||||
@ -114,16 +160,26 @@ class ReplayBuffer(object):
|
||||
])
|
||||
return self[indice], indice
|
||||
|
||||
def get_stack(self, indice, key):
    """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
    where s is ``self.key`` and t is ``indice``. The stack length (4 in
    this example) is the ``stack_num`` given when the buffer was created.

    The walk backwards from ``indice`` never crosses a slot whose ``done``
    flag is set: once the preceding slot is flagged, the current frame is
    repeated for the remaining (older) positions of the stack.

    :param indice: integer numpy array of buffer positions (the array
        arithmetic and masking below require an ndarray); the caller's
        array is not modified.
    :param str key: name of the stored attribute to stack, e.g. ``'obs'``.
    :return: ``None`` if nothing is stored under ``key``;
        ``self.key[indice]`` unchanged if stacking is disabled
        (``stack_num`` of 0); otherwise an array with the stacked frames
        along axis 1, oldest first.
    """
    data = self.__dict__.get(key, None)
    if data is None:
        return None
    if self._stack == 0:
        return data[indice]
    stack = []
    # Temporarily flag the slot just before the write cursor as done, so
    # that stacking stops there instead of wrapping around the ring.
    last_index = (self._index - 1 + self._size) % self._size
    last_done = self.done[last_index]
    self.done[last_index] = True
    try:
        for _ in range(self._stack):
            stack = [data[indice]] + stack
            # Step one slot back, wrapping from the front to the end.
            pre_indice = indice - 1
            pre_indice[pre_indice == -1] = self._size - 1
            # If that slot ends an episode, stay on the current frame:
            # adding the done flag (0 or 1) undoes the step back. The
            # builtin ``int`` replaces the removed ``np.int`` alias.
            indice = pre_indice + self.done[pre_indice].astype(int)
            indice[indice == self._size] = 0
    finally:
        # Always put the original flag back, even if indexing failed.
        self.done[last_index] = last_done
    return np.stack(stack, axis=1)
|
||||
|
||||
def __getitem__(self, index):
|
||||
@ -131,11 +187,11 @@ class ReplayBuffer(object):
|
||||
return the stacked obs and obs_next with shape [batch, len, ...].
|
||||
"""
|
||||
return Batch(
|
||||
obs=self._get_stack(index, 'obs'),
|
||||
obs=self.get_stack(index, 'obs'),
|
||||
act=self.act[index],
|
||||
rew=self.rew[index],
|
||||
done=self.done[index],
|
||||
obs_next=self._get_stack(index, 'obs_next'),
|
||||
obs_next=self.get_stack(index, 'obs_next'),
|
||||
info=self.info[index]
|
||||
)
|
||||
|
||||
|
@ -2,10 +2,10 @@ import time
|
||||
import torch
|
||||
import warnings
|
||||
import numpy as np
|
||||
from tianshou.env import BaseVectorEnv
|
||||
from tianshou.data import Batch, ReplayBuffer, \
|
||||
ListReplayBuffer
|
||||
|
||||
from tianshou.utils import MovAvg
|
||||
from tianshou.env import BaseVectorEnv
|
||||
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer
|
||||
|
||||
|
||||
class Collector(object):
|
||||
@ -22,8 +22,8 @@ class Collector(object):
|
||||
:class:`~tianshou.data.ReplayBuffer`.
|
||||
:param int stat_size: for the moving average of recording speed, defaults
|
||||
to 100.
|
||||
:param bool store_obs_next: whether to store the obs_next to replay
|
||||
buffer, defaults to ``True``.
|
||||
:param bool store_obs_next: store the next observation to replay buffer or
|
||||
not, defaults to ``True``.
|
||||
|
||||
Example:
|
||||
::
|
||||
@ -302,7 +302,7 @@ class Collector(object):
|
||||
self._obs = obs_next
|
||||
if self._multi_env:
|
||||
cur_episode = sum(cur_episode)
|
||||
duration = time.time() - start_time
|
||||
duration = max(time.time() - start_time, 1e-9)
|
||||
self.step_speed.add(cur_step / duration)
|
||||
self.episode_speed.add(cur_episode / duration)
|
||||
self.collect_step += cur_step
|
||||
|
@ -35,6 +35,8 @@ class PGPolicy(BasePolicy):
|
||||
discount factor, :math:`\gamma \in [0, 1]`.
|
||||
"""
|
||||
batch.returns = self._vanilla_returns(batch)
|
||||
if getattr(batch, 'obs_next', None) is None:
|
||||
batch.obs_next = buffer[(indice + 1) % len(buffer)].obs
|
||||
# batch.returns = self._vectorized_returns(batch)
|
||||
return batch
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user