parent
86572c66d4
commit
6da80e045a
1
.github/workflows/pytest.yml
vendored
1
.github/workflows/pytest.yml
vendored
@ -11,7 +11,6 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: [3.6, 3.7, 3.8]
|
python-version: [3.6, 3.7, 3.8]
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
@ -8,9 +8,9 @@ class MyTestEnv(gym.Env):
|
|||||||
self.sleep = sleep
|
self.sleep = sleep
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self, state=0):
|
||||||
self.done = False
|
self.done = False
|
||||||
self.index = 0
|
self.index = state
|
||||||
return self.index
|
return self.index
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
|
@ -17,6 +17,7 @@ def test_batch():
|
|||||||
batch.obs = np.arange(5)
|
batch.obs = np.arange(5)
|
||||||
for i, b in enumerate(batch.split(1, permute=False)):
|
for i, b in enumerate(batch.split(1, permute=False)):
|
||||||
assert b.obs == batch[i].obs
|
assert b.obs == batch[i].obs
|
||||||
|
print(batch)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import numpy as np
|
||||||
from tianshou.data import ReplayBuffer
|
from tianshou.data import ReplayBuffer
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@ -28,5 +29,24 @@ def test_replaybuffer(size=10, bufsize=20):
|
|||||||
assert buf2[-1].obs == buf[4].obs
|
assert buf2[-1].obs == buf[4].obs
|
||||||
|
|
||||||
|
|
||||||
|
def test_stack(size=5, bufsize=9, stack_num=4):
|
||||||
|
env = MyTestEnv(size)
|
||||||
|
buf = ReplayBuffer(bufsize, stack_num)
|
||||||
|
obs = env.reset(1)
|
||||||
|
for i in range(15):
|
||||||
|
obs_next, rew, done, info = env.step(1)
|
||||||
|
buf.add(obs, 1, rew, done, None, info)
|
||||||
|
obs = obs_next
|
||||||
|
if done:
|
||||||
|
obs = env.reset(1)
|
||||||
|
indice = np.arange(len(buf))
|
||||||
|
assert abs(buf.get_stack(indice, 'obs') - np.array([
|
||||||
|
[1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
|
||||||
|
[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
|
||||||
|
[3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]])).sum() < 1e-6
|
||||||
|
print(buf)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_replaybuffer()
|
test_replaybuffer()
|
||||||
|
test_stack()
|
||||||
|
@ -26,7 +26,7 @@ def get_args():
|
|||||||
parser.add_argument('--stack-num', type=int, default=4)
|
parser.add_argument('--stack-num', type=int, default=4)
|
||||||
parser.add_argument('--lr', type=float, default=1e-3)
|
parser.add_argument('--lr', type=float, default=1e-3)
|
||||||
parser.add_argument('--gamma', type=float, default=0.9)
|
parser.add_argument('--gamma', type=float, default=0.9)
|
||||||
parser.add_argument('--n-step', type=int, default=3)
|
parser.add_argument('--n-step', type=int, default=4)
|
||||||
parser.add_argument('--target-update-freq', type=int, default=320)
|
parser.add_argument('--target-update-freq', type=int, default=320)
|
||||||
parser.add_argument('--epoch', type=int, default=100)
|
parser.add_argument('--epoch', type=int, default=100)
|
||||||
parser.add_argument('--step-per-epoch', type=int, default=1000)
|
parser.add_argument('--step-per-epoch', type=int, default=1000)
|
||||||
|
@ -31,27 +31,29 @@ def compute_return_base(batch, aa=None, bb=None, gamma=0.1):
|
|||||||
|
|
||||||
def test_fn(size=2560):
|
def test_fn(size=2560):
|
||||||
policy = PGPolicy(None, None, None, discount_factor=0.1)
|
policy = PGPolicy(None, None, None, discount_factor=0.1)
|
||||||
|
buf = ReplayBuffer(100)
|
||||||
|
buf.add(1, 1, 1, 1, 1)
|
||||||
fn = policy.process_fn
|
fn = policy.process_fn
|
||||||
# fn = compute_return_base
|
# fn = compute_return_base
|
||||||
batch = Batch(
|
batch = Batch(
|
||||||
done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
|
done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
|
||||||
rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
|
rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
|
||||||
)
|
)
|
||||||
batch = fn(batch, None, None)
|
batch = fn(batch, buf, 0)
|
||||||
ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
|
ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
|
||||||
assert abs(batch.returns - ans).sum() <= 1e-5
|
assert abs(batch.returns - ans).sum() <= 1e-5
|
||||||
batch = Batch(
|
batch = Batch(
|
||||||
done=np.array([0, 1, 0, 1, 0, 1, 0.]),
|
done=np.array([0, 1, 0, 1, 0, 1, 0.]),
|
||||||
rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
|
rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
|
||||||
)
|
)
|
||||||
batch = fn(batch, None, None)
|
batch = fn(batch, buf, 0)
|
||||||
ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
|
ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
|
||||||
assert abs(batch.returns - ans).sum() <= 1e-5
|
assert abs(batch.returns - ans).sum() <= 1e-5
|
||||||
batch = Batch(
|
batch = Batch(
|
||||||
done=np.array([0, 1, 0, 1, 0, 0, 1.]),
|
done=np.array([0, 1, 0, 1, 0, 0, 1.]),
|
||||||
rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
|
rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
|
||||||
)
|
)
|
||||||
batch = fn(batch, None, None)
|
batch = fn(batch, buf, 0)
|
||||||
ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
|
ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
|
||||||
assert abs(batch.returns - ans).sum() <= 1e-5
|
assert abs(batch.returns - ans).sum() <= 1e-5
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@ -66,7 +68,7 @@ def test_fn(size=2560):
|
|||||||
print(f'vanilla: {(time.time() - t) / cnt}')
|
print(f'vanilla: {(time.time() - t) / cnt}')
|
||||||
t = time.time()
|
t = time.time()
|
||||||
for _ in range(cnt):
|
for _ in range(cnt):
|
||||||
policy.process_fn(batch, None, None)
|
policy.process_fn(batch, buf, 0)
|
||||||
print(f'policy: {(time.time() - t) / cnt}')
|
print(f'policy: {(time.time() - t) / cnt}')
|
||||||
|
|
||||||
|
|
||||||
@ -147,5 +149,5 @@ def test_pg(args=get_args()):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# test_fn()
|
test_fn()
|
||||||
test_pg()
|
test_pg()
|
||||||
|
@ -73,6 +73,22 @@ class Batch(object):
|
|||||||
b.__dict__.update(**{k: self.__dict__[k][index]})
|
b.__dict__.update(**{k: self.__dict__[k][index]})
|
||||||
return b
|
return b
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
"""Return str(self)."""
|
||||||
|
s = self.__class__.__name__ + '(\n'
|
||||||
|
flag = False
|
||||||
|
for k in self.__dict__.keys():
|
||||||
|
if k[0] != '_' and self.__dict__[k] is not None:
|
||||||
|
rpl = '\n' + ' ' * (6 + len(k))
|
||||||
|
obj = str(self.__dict__[k]).replace('\n', rpl)
|
||||||
|
s += f' {k}: {obj},\n'
|
||||||
|
flag = True
|
||||||
|
if flag:
|
||||||
|
s += ')\n'
|
||||||
|
else:
|
||||||
|
s = self.__class__.__name__ + '()\n'
|
||||||
|
return s
|
||||||
|
|
||||||
def append(self, batch):
|
def append(self, batch):
|
||||||
"""Append a :class:`~tianshou.data.Batch` object to current batch."""
|
"""Append a :class:`~tianshou.data.Batch` object to current batch."""
|
||||||
assert isinstance(batch, Batch), 'Only append Batch is allowed!'
|
assert isinstance(batch, Batch), 'Only append Batch is allowed!'
|
||||||
|
@ -39,6 +39,34 @@ class ReplayBuffer(object):
|
|||||||
>>> batch_data, indice = buf.sample(batch_size=4)
|
>>> batch_data, indice = buf.sample(batch_size=4)
|
||||||
>>> batch_data.obs == buf[indice].obs
|
>>> batch_data.obs == buf[indice].obs
|
||||||
array([ True, True, True, True])
|
array([ True, True, True, True])
|
||||||
|
|
||||||
|
From version v0.2.2, :class:`~tianshou.data.ReplayBuffer` supports
|
||||||
|
frame_stack sampling, typically for RNN usage:
|
||||||
|
::
|
||||||
|
|
||||||
|
>>> buf = ReplayBuffer(size=9, stack_num=4)
|
||||||
|
>>> for i in range(16):
|
||||||
|
... done = i % 5 == 0
|
||||||
|
... buf.add(obs=i, act=i, rew=i, done=done, obs_next=0, info={})
|
||||||
|
>>> print(buf.obs)
|
||||||
|
[ 9. 10. 11. 12. 13. 14. 15. 7. 8.]
|
||||||
|
>>> print(buf.done)
|
||||||
|
[0. 1. 0. 0. 0. 0. 1. 0. 0.]
|
||||||
|
>>> index = np.arange(len(buf))
|
||||||
|
>>> print(buf.get_stack(index, 'obs'))
|
||||||
|
[[ 7. 7. 8. 9.]
|
||||||
|
[ 7. 8. 9. 10.]
|
||||||
|
[11. 11. 11. 11.]
|
||||||
|
[11. 11. 11. 12.]
|
||||||
|
[11. 11. 12. 13.]
|
||||||
|
[11. 12. 13. 14.]
|
||||||
|
[12. 13. 14. 15.]
|
||||||
|
[ 7. 7. 7. 7.]
|
||||||
|
[ 7. 7. 7. 8.]]
|
||||||
|
>>> # here is another way to get the stacked data
|
||||||
|
>>> # (stack only for obs and obs_next)
|
||||||
|
>>> sum(sum(buf.get_stack(index, 'obs') - buf[index].obs))
|
||||||
|
0.0
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, size, stack_num=0):
|
def __init__(self, size, stack_num=0):
|
||||||
@ -51,8 +79,26 @@ class ReplayBuffer(object):
|
|||||||
"""Return len(self)."""
|
"""Return len(self)."""
|
||||||
return self._size
|
return self._size
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
"""Return str(self)."""
|
||||||
|
s = self.__class__.__name__ + '(\n'
|
||||||
|
flag = False
|
||||||
|
for k in self.__dict__.keys():
|
||||||
|
if k[0] != '_' and self.__dict__[k] is not None:
|
||||||
|
rpl = '\n' + ' ' * (6 + len(k))
|
||||||
|
obj = str(self.__dict__[k]).replace('\n', rpl)
|
||||||
|
s += f' {k}: {obj},\n'
|
||||||
|
flag = True
|
||||||
|
if flag:
|
||||||
|
s += ')\n'
|
||||||
|
else:
|
||||||
|
s = self.__class__.__name__ + '()\n'
|
||||||
|
return s
|
||||||
|
|
||||||
def _add_to_buffer(self, name, inst):
|
def _add_to_buffer(self, name, inst):
|
||||||
if inst is None:
|
if inst is None:
|
||||||
|
if getattr(self, name, None) is None:
|
||||||
|
self.__dict__[name] = None
|
||||||
return
|
return
|
||||||
if self.__dict__.get(name, None) is None:
|
if self.__dict__.get(name, None) is None:
|
||||||
if isinstance(inst, np.ndarray):
|
if isinstance(inst, np.ndarray):
|
||||||
@ -72,13 +118,14 @@ class ReplayBuffer(object):
|
|||||||
i = begin = buffer._index % len(buffer)
|
i = begin = buffer._index % len(buffer)
|
||||||
while True:
|
while True:
|
||||||
self.add(
|
self.add(
|
||||||
buffer.obs[i], buffer.act[i], buffer.rew[i],
|
buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i],
|
||||||
buffer.done[i], buffer.obs_next[i], buffer.info[i])
|
None if buffer.obs_next is None else buffer.obs_next[i],
|
||||||
|
buffer.info[i])
|
||||||
i = (i + 1) % len(buffer)
|
i = (i + 1) % len(buffer)
|
||||||
if i == begin:
|
if i == begin:
|
||||||
break
|
break
|
||||||
|
|
||||||
def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
|
def add(self, obs, act, rew, done, obs_next=None, info={}, weight=None):
|
||||||
"""Add a batch of data into replay buffer."""
|
"""Add a batch of data into replay buffer."""
|
||||||
assert isinstance(info, dict), \
|
assert isinstance(info, dict), \
|
||||||
'You should return a dict in the last argument of env.step().'
|
'You should return a dict in the last argument of env.step().'
|
||||||
@ -97,7 +144,6 @@ class ReplayBuffer(object):
|
|||||||
def reset(self):
|
def reset(self):
|
||||||
"""Clear all the data in replay buffer."""
|
"""Clear all the data in replay buffer."""
|
||||||
self._index = self._size = 0
|
self._index = self._size = 0
|
||||||
self.indice = []
|
|
||||||
|
|
||||||
def sample(self, batch_size):
|
def sample(self, batch_size):
|
||||||
"""Get a random sample from buffer with size equal to batch_size. \
|
"""Get a random sample from buffer with size equal to batch_size. \
|
||||||
@ -114,16 +160,26 @@ class ReplayBuffer(object):
|
|||||||
])
|
])
|
||||||
return self[indice], indice
|
return self[indice], indice
|
||||||
|
|
||||||
def _get_stack(self, indice, key):
|
def get_stack(self, indice, key):
|
||||||
|
"""Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
|
||||||
|
where s is self.key, t is indice. The stack_num (here equals to 4) is
|
||||||
|
given from buffer initialization procedure.
|
||||||
|
"""
|
||||||
if self.__dict__.get(key, None) is None:
|
if self.__dict__.get(key, None) is None:
|
||||||
return None
|
return None
|
||||||
if self._stack == 0:
|
if self._stack == 0:
|
||||||
return self.__dict__[key][indice]
|
return self.__dict__[key][indice]
|
||||||
stack = []
|
stack = []
|
||||||
|
# set last frame done to True
|
||||||
|
last_index = (self._index - 1 + self._size) % self._size
|
||||||
|
last_done, self.done[last_index] = self.done[last_index], True
|
||||||
for i in range(self._stack):
|
for i in range(self._stack):
|
||||||
stack = [self.__dict__[key][indice]] + stack
|
stack = [self.__dict__[key][indice]] + stack
|
||||||
indice = indice - 1 + self.done[indice - 1].astype(np.int)
|
pre_indice = indice - 1
|
||||||
indice[indice == -1] = self._size - 1
|
pre_indice[pre_indice == -1] = self._size - 1
|
||||||
|
indice = pre_indice + self.done[pre_indice].astype(np.int)
|
||||||
|
indice[indice == self._size] = 0
|
||||||
|
self.done[last_index] = last_done
|
||||||
return np.stack(stack, axis=1)
|
return np.stack(stack, axis=1)
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
@ -131,11 +187,11 @@ class ReplayBuffer(object):
|
|||||||
return the stacked obs and obs_next with shape [batch, len, ...].
|
return the stacked obs and obs_next with shape [batch, len, ...].
|
||||||
"""
|
"""
|
||||||
return Batch(
|
return Batch(
|
||||||
obs=self._get_stack(index, 'obs'),
|
obs=self.get_stack(index, 'obs'),
|
||||||
act=self.act[index],
|
act=self.act[index],
|
||||||
rew=self.rew[index],
|
rew=self.rew[index],
|
||||||
done=self.done[index],
|
done=self.done[index],
|
||||||
obs_next=self._get_stack(index, 'obs_next'),
|
obs_next=self.get_stack(index, 'obs_next'),
|
||||||
info=self.info[index]
|
info=self.info[index]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -2,10 +2,10 @@ import time
|
|||||||
import torch
|
import torch
|
||||||
import warnings
|
import warnings
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tianshou.env import BaseVectorEnv
|
|
||||||
from tianshou.data import Batch, ReplayBuffer, \
|
|
||||||
ListReplayBuffer
|
|
||||||
from tianshou.utils import MovAvg
|
from tianshou.utils import MovAvg
|
||||||
|
from tianshou.env import BaseVectorEnv
|
||||||
|
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer
|
||||||
|
|
||||||
|
|
||||||
class Collector(object):
|
class Collector(object):
|
||||||
@ -22,8 +22,8 @@ class Collector(object):
|
|||||||
:class:`~tianshou.data.ReplayBuffer`.
|
:class:`~tianshou.data.ReplayBuffer`.
|
||||||
:param int stat_size: for the moving average of recording speed, defaults
|
:param int stat_size: for the moving average of recording speed, defaults
|
||||||
to 100.
|
to 100.
|
||||||
:param bool store_obs_next: whether to store the obs_next to replay
|
:param bool store_obs_next: store the next observation to replay buffer or
|
||||||
buffer, defaults to ``True``.
|
not, defaults to ``True``.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
::
|
::
|
||||||
@ -302,7 +302,7 @@ class Collector(object):
|
|||||||
self._obs = obs_next
|
self._obs = obs_next
|
||||||
if self._multi_env:
|
if self._multi_env:
|
||||||
cur_episode = sum(cur_episode)
|
cur_episode = sum(cur_episode)
|
||||||
duration = time.time() - start_time
|
duration = max(time.time() - start_time, 1e-9)
|
||||||
self.step_speed.add(cur_step / duration)
|
self.step_speed.add(cur_step / duration)
|
||||||
self.episode_speed.add(cur_episode / duration)
|
self.episode_speed.add(cur_episode / duration)
|
||||||
self.collect_step += cur_step
|
self.collect_step += cur_step
|
||||||
|
@ -35,6 +35,8 @@ class PGPolicy(BasePolicy):
|
|||||||
discount factor, :math:`\gamma \in [0, 1]`.
|
discount factor, :math:`\gamma \in [0, 1]`.
|
||||||
"""
|
"""
|
||||||
batch.returns = self._vanilla_returns(batch)
|
batch.returns = self._vanilla_returns(batch)
|
||||||
|
if getattr(batch, 'obs_next', None) is None:
|
||||||
|
batch.obs_next = buffer[(indice + 1) % len(buffer)].obs
|
||||||
# batch.returns = self._vectorized_returns(batch)
|
# batch.returns = self._vectorized_returns(batch)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user