unify single-env and multi-env in collector (#157)

Unify the single-environment and multi-environment implementations (a single environment is now wrapped as a vectorized environment containing one env), which greatly simplifies the code.

This changes the behavior of single-environment collection.
Prior to this PR, for a single environment, collector.collect(n_step=n) stepped exactly n steps.
After this PR, for a single environment, collector.collect(n_step=n) runs whole episodes until the total number of collected steps reaches at least n.

That is to say, collectors now always collect full episodes.
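
For illustration, here is a minimal sketch of the new semantics (not part of this commit); ``FixedPolicy`` is only a placeholder policy and the buffer/step sizes are arbitrary:

    import gym
    import numpy as np

    from tianshou.policy import BasePolicy
    from tianshou.data import Batch, Collector, ReplayBuffer

    class FixedPolicy(BasePolicy):
        """Placeholder policy that always plays action 0 (illustration only)."""
        def forward(self, batch, state=None, **kwargs):
            return Batch(act=np.zeros(len(batch.obs), dtype=int))

        def learn(self, batch, **kwargs):
            return {}

    env = gym.make('CartPole-v0')
    collector = Collector(FixedPolicy(), env, ReplayBuffer(size=20000))

    # Before this PR: exactly 10 environment steps, possibly stopping mid-episode.
    # After this PR: whole episodes are collected until at least 10 transitions
    # are stored, so result['n/st'] may be larger than 10.
    result = collector.collect(n_step=10)
    print(result['n/ep'], result['n/st'], result['rew'], result['len'])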
Authored by youkaichao on 2020-07-23 16:40:53 +08:00, committed by GitHub
Parent: 352a518399
Commit: bfeffe1f97
6 changed files with 128 additions and 193 deletions


@@ -130,7 +130,7 @@ In short, :class:`~tianshou.data.Collector` has two main methods:

 * :meth:`~tianshou.data.Collector.collect`: let the policy perform (at least) a specified number of step ``n_step`` or episode ``n_episode`` and store the data in the replay buffer;
 * :meth:`~tianshou.data.Collector.sample`: sample a data batch from replay buffer; it will call :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the final batch data.

-Why do we mention **at least** here? For a single environment, the collector will finish exactly ``n_step`` or ``n_episode``. However, for multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically.
+Why do we mention **at least** here? For multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically.
 The solution is to add some cache buffers inside the collector. Once collecting **a full episode of trajectory**, it will move the stored data from the cache buffer to the main buffer. To satisfy this condition, the collector will interact with environments that may exceed the given step number or episode number.
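
To make the over-collection concrete, here is a hedged sketch (not part of the documentation being changed; ``FixedPolicy`` is the placeholder policy from the sketch in the PR description above)::

    import gym

    from tianshou.env import VectorEnv
    from tianshou.data import Collector, ReplayBuffer

    envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
    collector = Collector(FixedPolicy(), envs, buffer=ReplayBuffer(size=20000))

    # Each of the 3 environments writes into its own cache buffer; a cached
    # trajectory is merged into the main buffer only when that episode ends,
    # so the reported step count is at least 100 and usually somewhat larger.
    result = collector.collect(n_step=100)
    assert result['n/st'] >= 100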


@@ -179,7 +179,7 @@ Train a Policy with Customized Codes

 Tianshou supports user-defined training code. Here is the code snippet:
 ::

-    # pre-collect 5000 frames with random action before training
+    # pre-collect at least 5000 frames with random action before training
     policy.set_eps(1)
     train_collector.collect(n_step=5000)


@ -25,29 +25,39 @@ class MyPolicy(BasePolicy):
pass pass
def preprocess_fn(**kwargs): class Logger:
# modify info before adding into the buffer def __init__(self, writer):
self.cnt = 0
self.writer = writer
def preprocess_fn(self, **kwargs):
# modify info before adding into the buffer, and recorded into tfb
# if info is not provided from env, it will be a ``Batch()``. # if info is not provided from env, it will be a ``Batch()``.
if not kwargs.get('info', Batch()).is_empty(): if not kwargs.get('info', Batch()).is_empty():
n = len(kwargs['obs']) n = len(kwargs['obs'])
info = kwargs['info'] info = kwargs['info']
for i in range(n): for i in range(n):
info[i].update(rew=kwargs['rew'][i]) info[i].update(rew=kwargs['rew'][i])
return {'info': info} self.writer.add_scalar('key', np.mean(
# or: return Batch(info=info) info['key']), global_step=self.cnt)
self.cnt += 1
return Batch(info=info)
# or: return {'info': info}
else: else:
return Batch() return Batch()
@staticmethod
class Logger(object): def single_preprocess_fn(**kwargs):
def __init__(self, writer): # same as above, without tfb
self.cnt = 0 if not kwargs.get('info', Batch()).is_empty():
self.writer = writer n = len(kwargs['obs'])
info = kwargs['info']
def log(self, info): for i in range(n):
self.writer.add_scalar( info[i].update(rew=kwargs['rew'][i])
'key', np.mean(info['key']), global_step=self.cnt) return Batch(info=info)
self.cnt += 1 # or: return {'info': info}
else:
return Batch()
def test_collector(): def test_collector():
@ -60,16 +70,16 @@ def test_collector():
policy = MyPolicy() policy = MyPolicy()
env = env_fns[0]() env = env_fns[0]()
c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False), c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False),
preprocess_fn) logger.preprocess_fn)
c0.collect(n_step=3, log_fn=logger.log) c0.collect(n_step=3)
assert np.allclose(c0.buffer.obs[:3], [0, 1, 0]) assert np.allclose(c0.buffer.obs[:4], [0, 1, 0, 1])
assert np.allclose(c0.buffer[:3].obs_next, [1, 2, 1]) assert np.allclose(c0.buffer[:4].obs_next, [1, 2, 1, 2])
c0.collect(n_episode=3, log_fn=logger.log) c0.collect(n_episode=3)
assert np.allclose(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1]) assert np.allclose(c0.buffer.obs[:10], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
assert np.allclose(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2]) assert np.allclose(c0.buffer[:10].obs_next, [1, 2, 1, 2, 1, 2, 1, 2, 1, 2])
c0.collect(n_step=3, random=True) c0.collect(n_step=3, random=True)
c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False), c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
preprocess_fn) logger.preprocess_fn)
c1.collect(n_step=6) c1.collect(n_step=6)
assert np.allclose(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3]) assert np.allclose(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3])
assert np.allclose(c1.buffer[:11].obs_next, assert np.allclose(c1.buffer[:11].obs_next,
@ -80,7 +90,7 @@ def test_collector():
[1, 2, 3, 4, 5, 1, 2, 1, 2, 3]) [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
c1.collect(n_episode=3, random=True) c1.collect(n_episode=3, random=True)
c2 = Collector(policy, dum, ReplayBuffer(size=100, ignore_obs_next=False), c2 = Collector(policy, dum, ReplayBuffer(size=100, ignore_obs_next=False),
preprocess_fn) logger.preprocess_fn)
c2.collect(n_episode=[1, 2, 2, 2]) c2.collect(n_episode=[1, 2, 2, 2])
assert np.allclose(c2.buffer.obs_next[:26], [ assert np.allclose(c2.buffer.obs_next[:26], [
1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,
@ -96,13 +106,15 @@ def test_collector():
def test_collector_with_dict_state(): def test_collector_with_dict_state():
env = MyTestEnv(size=5, sleep=0, dict_state=True) env = MyTestEnv(size=5, sleep=0, dict_state=True)
policy = MyPolicy(dict_state=True) policy = MyPolicy(dict_state=True)
c0 = Collector(policy, env, ReplayBuffer(size=100), preprocess_fn) c0 = Collector(policy, env, ReplayBuffer(size=100),
Logger.single_preprocess_fn)
c0.collect(n_step=3) c0.collect(n_step=3)
c0.collect(n_episode=3) c0.collect(n_episode=2)
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True) env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True)
for i in [2, 3, 4, 5]] for i in [2, 3, 4, 5]]
envs = VectorEnv(env_fns) envs = VectorEnv(env_fns)
c1 = Collector(policy, envs, ReplayBuffer(size=100), preprocess_fn) c1 = Collector(policy, envs, ReplayBuffer(size=100),
Logger.single_preprocess_fn)
c1.collect(n_step=10) c1.collect(n_step=10)
c1.collect(n_episode=[2, 1, 1, 2]) c1.collect(n_episode=[2, 1, 1, 2])
batch = c1.sample(10) batch = c1.sample(10)
@ -113,7 +125,7 @@ def test_collector_with_dict_state():
0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0., 0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0.,
1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.]) 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.])
c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
preprocess_fn) Logger.single_preprocess_fn)
c2.collect(n_episode=[0, 0, 0, 10]) c2.collect(n_episode=[0, 0, 0, 10])
batch = c2.sample(10) batch = c2.sample(10)
print(batch['obs_next']['index']) print(batch['obs_next']['index'])
@ -125,16 +137,17 @@ def test_collector_with_ma():
env = MyTestEnv(size=5, sleep=0, ma_rew=4) env = MyTestEnv(size=5, sleep=0, ma_rew=4)
policy = MyPolicy() policy = MyPolicy()
c0 = Collector(policy, env, ReplayBuffer(size=100), c0 = Collector(policy, env, ReplayBuffer(size=100),
preprocess_fn, reward_metric=reward_metric) Logger.single_preprocess_fn, reward_metric=reward_metric)
# n_step=3 will collect a full episode
r = c0.collect(n_step=3)['rew'] r = c0.collect(n_step=3)['rew']
assert np.asanyarray(r).size == 1 and r == 0. assert np.asanyarray(r).size == 1 and r == 4.
r = c0.collect(n_episode=3)['rew'] r = c0.collect(n_episode=2)['rew']
assert np.asanyarray(r).size == 1 and r == 4. assert np.asanyarray(r).size == 1 and r == 4.
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4)
for i in [2, 3, 4, 5]] for i in [2, 3, 4, 5]]
envs = VectorEnv(env_fns) envs = VectorEnv(env_fns)
c1 = Collector(policy, envs, ReplayBuffer(size=100), c1 = Collector(policy, envs, ReplayBuffer(size=100),
preprocess_fn, reward_metric=reward_metric) Logger.single_preprocess_fn, reward_metric=reward_metric)
r = c1.collect(n_step=10)['rew'] r = c1.collect(n_step=10)['rew']
assert np.asanyarray(r).size == 1 and r == 4. assert np.asanyarray(r).size == 1 and r == 4.
r = c1.collect(n_episode=[2, 1, 1, 2])['rew'] r = c1.collect(n_episode=[2, 1, 1, 2])['rew']
@ -153,7 +166,7 @@ def test_collector_with_ma():
assert np.allclose(c0.buffer[:len(c0.buffer)].rew, assert np.allclose(c0.buffer[:len(c0.buffer)].rew,
[[x] * 4 for x in rew]) [[x] * 4 for x in rew])
c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
preprocess_fn, reward_metric=reward_metric) Logger.single_preprocess_fn, reward_metric=reward_metric)
r = c2.collect(n_episode=[0, 0, 0, 10])['rew'] r = c2.collect(n_episode=[0, 0, 0, 10])['rew']
assert np.asanyarray(r).size == 1 and r == 4. assert np.asanyarray(r).size == 1 and r == 4.
batch = c2.sample(10) batch = c2.sample(10)


@@ -5,8 +5,7 @@ import warnings
 import numpy as np
 from typing import Any, Dict, List, Union, Optional, Callable

-from tianshou.utils import MovAvg
-from tianshou.env import BaseVectorEnv
+from tianshou.env import BaseVectorEnv, VectorEnv
 from tianshou.policy import BasePolicy
 from tianshou.exploration import BaseNoise
 from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
@@ -21,14 +20,10 @@ class Collector(object):
     :param env: a ``gym.Env`` environment or an instance of the
         :class:`~tianshou.env.BaseVectorEnv` class.
     :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
-        class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to
-        ``None``, it will automatically assign a small-size
-        :class:`~tianshou.data.ReplayBuffer`.
+        class. If set to ``None`` (testing phase), it will not store the data.
     :param function preprocess_fn: a function called before the data has been
         added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults
         to ``None``.
-    :param int stat_size: for the moving average of recording speed, defaults
-        to 100.
     :param BaseNoise action_noise: add a noise to continuous action. Normally
         a policy already has a noise param for exploration in training phase,
         so this is recommended to use in test collector for some purpose.
@@ -56,12 +51,9 @@ class Collector(object):
         # the collector supports vectorized environments as well
         envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])

-        buffers = [ReplayBuffer(size=5000) for _ in range(3)]
-        # you can also pass a list of replay buffer to collector, for multi-env
-        # collector = Collector(policy, envs, buffer=buffers)
         collector = Collector(policy, envs, buffer=replay_buffer)

-        # collect at least 3 episodes
+        # collect 3 episodes
         collector.collect(n_episode=3)
         # collect 1 episode for the first env, 3 for the third env
         collector.collect(n_episode=[1, 0, 3])
@@ -81,9 +73,10 @@ class Collector(object):
         # clear the buffer
         collector.reset_buffer()

-    For the scenario of collecting data from multiple environments to a single
-    buffer, the cache buffers will turn on automatically. It may return the
-    data more than the given limitation.
+    Collected data always consists of full episodes. So if only the ``n_step``
+    argument is given, the collector may return more data than the ``n_step``
+    limitation. The same holds for ``n_episode`` in the multi-environment
+    case.

     .. note::
@ -95,28 +88,22 @@ class Collector(object):
env: Union[gym.Env, BaseVectorEnv], env: Union[gym.Env, BaseVectorEnv],
buffer: Optional[ReplayBuffer] = None, buffer: Optional[ReplayBuffer] = None,
preprocess_fn: Callable[[Any], Union[dict, Batch]] = None, preprocess_fn: Callable[[Any], Union[dict, Batch]] = None,
stat_size: Optional[int] = 100,
action_noise: Optional[BaseNoise] = None, action_noise: Optional[BaseNoise] = None,
reward_metric: Optional[Callable[[np.ndarray], float]] = None, reward_metric: Optional[Callable[[np.ndarray], float]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if not isinstance(env, BaseVectorEnv):
env = VectorEnv([lambda: env])
self.env = env self.env = env
self.env_num = 1 self.env_num = len(env)
# need cache buffers before storing in the main buffer
self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0 self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
self.buffer = buffer self.buffer = buffer
self.policy = policy self.policy = policy
self.preprocess_fn = preprocess_fn self.preprocess_fn = preprocess_fn
self.process_fn = policy.process_fn self.process_fn = policy.process_fn
self._multi_env = isinstance(env, BaseVectorEnv)
# need multiple cache buffers only if storing in one buffer
self._cached_buf = []
if self._multi_env:
self.env_num = len(env)
self._cached_buf = [ListReplayBuffer()
for _ in range(self.env_num)]
self.stat_size = stat_size
self._action_noise = action_noise self._action_noise = action_noise
self._rew_metric = reward_metric or Collector._default_rew_metric self._rew_metric = reward_metric or Collector._default_rew_metric
self.reset() self.reset()
@ -135,8 +122,6 @@ class Collector(object):
obs_next={}, policy={}) obs_next={}, policy={})
self.reset_env() self.reset_env()
self.reset_buffer() self.reset_buffer()
self.step_speed = MovAvg(self.stat_size)
self.episode_speed = MovAvg(self.stat_size)
self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0 self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
if self._action_noise is not None: if self._action_noise is not None:
self._action_noise.reset() self._action_noise.reset()
@ -155,13 +140,9 @@ class Collector(object):
buffers (if need). buffers (if need).
""" """
obs = self.env.reset() obs = self.env.reset()
if not self._multi_env:
obs = self._make_batch(obs)
if self.preprocess_fn: if self.preprocess_fn:
obs = self.preprocess_fn(obs=obs).get('obs', obs) obs = self.preprocess_fn(obs=obs).get('obs', obs)
self.data.obs = obs self.data.obs = obs
self.reward = 0. # will be specified when the first data is ready
self.length = np.zeros(self.env_num)
for b in self._cached_buf: for b in self._cached_buf:
b.reset() b.reset()
@ -177,13 +158,6 @@ class Collector(object):
"""Close the environment(s).""" """Close the environment(s)."""
self.env.close() self.env.close()
def _make_batch(self, data: Any) -> np.ndarray:
"""Return [data]."""
if isinstance(data, np.ndarray):
return data[None]
else:
return np.array([data])
def _reset_state(self, id: Union[int, List[int]]) -> None: def _reset_state(self, id: Union[int, List[int]]) -> None:
"""Reset self.data.state[id].""" """Reset self.data.state[id]."""
state = self.data.state # it is a reference state = self.data.state # it is a reference
@@ -195,24 +169,22 @@ class Collector(object):
             state.empty_(id)

     def collect(self,
-                n_step: int = 0,
-                n_episode: Union[int, List[int]] = 0,
+                n_step: Optional[int] = None,
+                n_episode: Optional[Union[int, List[int]]] = None,
                 random: bool = False,
                 render: Optional[float] = None,
-                log_fn: Optional[Callable[[dict], None]] = None
                 ) -> Dict[str, float]:
         """Collect a specified number of step or episode.

         :param int n_step: how many steps you want to collect.
-        :param n_episode: how many episodes you want to collect (in each
-            environment).
-        :type n_episode: int or list
+        :param n_episode: how many episodes you want to collect. If it is an
+            int, it means to collect at least ``n_episode`` episodes; if it is
+            a list, it means to collect exactly ``n_episode[i]`` episodes in
+            the i-th environment.
         :param bool random: whether to use random policy for collecting data,
             defaults to ``False``.
         :param float render: the sleep time between rendering consecutive
             frames, defaults to ``None`` (no rendering).
-        :param function log_fn: a function which receives env info, typically
-            for tensorboard logging.

         .. note::
@ -228,15 +200,15 @@ class Collector(object):
* ``rew`` the mean reward over collected episodes. * ``rew`` the mean reward over collected episodes.
* ``len`` the mean length over collected episodes. * ``len`` the mean length over collected episodes.
""" """
if not self._multi_env: assert (n_step and not n_episode) or (not n_step and n_episode), \
n_episode = np.sum(n_episode)
start_time = time.time()
assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
"One and only one collection number specification is permitted!" "One and only one collection number specification is permitted!"
cur_step, cur_episode = 0, np.zeros(self.env_num) start_time = time.time()
reward_sum, length_sum = 0., 0 step_count = 0
# episode of each environment
episode_count = np.zeros(self.env_num)
reward_total = 0.0
while True: while True:
if cur_step >= 100000 and cur_episode.sum() == 0: if step_count >= 100000 and episode_count.sum() == 0:
warnings.warn( warnings.warn(
'There are already many steps in an episode. ' 'There are already many steps in an episode. '
'You should add a time limitation to your environment!', 'You should add a time limitation to your environment!',
@ -250,11 +222,8 @@ class Collector(object):
# calculate the next action # calculate the next action
if random: if random:
action_space = self.env.action_space result = Batch(
if isinstance(action_space, list): act=[a.sample() for a in self.env.action_space])
result = Batch(act=[a.sample() for a in action_space])
else:
result = Batch(act=self._make_batch(action_space.sample()))
else: else:
with torch.no_grad(): with torch.no_grad():
result = self.policy(self.data, last_state) result = self.policy(self.data, last_state)
@ -274,48 +243,29 @@ class Collector(object):
self.data.act += self._action_noise(self.data.act.shape) self.data.act += self._action_noise(self.data.act.shape)
# step in env # step in env
obs_next, rew, done, info = self.env.step( obs_next, rew, done, info = self.env.step(self.data.act)
self.data.act if self._multi_env else self.data.act[0])
# move data to self.data # move data to self.data
if not self._multi_env: self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
obs_next = self._make_batch(obs_next)
rew = self._make_batch(rew)
done = self._make_batch(done)
info = self._make_batch(info)
self.data.obs_next = obs_next
self.data.rew = rew
self.data.done = done
self.data.info = info
if log_fn:
log_fn(info if self._multi_env else info[0])
if render: if render:
self.render() self.render()
if render > 0:
time.sleep(render) time.sleep(render)
# add data into the buffer # add data into the buffer
self.length += 1
self.reward += self.data.rew
if self.preprocess_fn: if self.preprocess_fn:
result = self.preprocess_fn(**self.data) result = self.preprocess_fn(**self.data)
self.data.update(result) self.data.update(result)
if self._multi_env: # cache_buffer branch
for i in range(self.env_num): for i in range(self.env_num):
self._cached_buf[i].add(**self.data[i]) self._cached_buf[i].add(**self.data[i])
if self.data.done[i]: if self.data.done[i]:
if n_step != 0 or np.isscalar(n_episode) or \ if n_step or np.isscalar(n_episode) or \
cur_episode[i] < n_episode[i]: episode_count[i] < n_episode[i]:
cur_episode[i] += 1 episode_count[i] += 1
reward_sum += self.reward[i] reward_total += np.sum(self._cached_buf[i].rew, axis=0)
length_sum += self.length[i] step_count += len(self._cached_buf[i])
if self._cached_buf:
cur_step += len(self._cached_buf[i])
if self.buffer is not None: if self.buffer is not None:
self.buffer.update(self._cached_buf[i]) self.buffer.update(self._cached_buf[i])
self.reward[i], self.length[i] = 0., 0
if self._cached_buf:
self._cached_buf[i].reset() self._cached_buf[i].reset()
self._reset_state(i) self._reset_state(i)
obs_next = self.data.obs_next obs_next = self.data.obs_next
@ -327,57 +277,35 @@ class Collector(object):
obs=obs_reset).get('obs', obs_reset) obs=obs_reset).get('obs', obs_reset)
else: else:
obs_next[env_ind] = obs_reset obs_next[env_ind] = obs_reset
self.data.obs_next = obs_next self.data.obs = obs_next
if n_episode != 0: if n_step:
if step_count >= n_step:
break
else:
if isinstance(n_episode, int) and \
episode_count.sum() >= n_episode:
break
if isinstance(n_episode, list) and \ if isinstance(n_episode, list) and \
(cur_episode >= np.array(n_episode)).all() or \ (episode_count >= n_episode).all():
np.isscalar(n_episode) and \
cur_episode.sum() >= n_episode:
break break
else: # single buffer, without cache_buffer
if self.buffer is not None:
self.buffer.add(**self.data[0])
cur_step += 1
if self.data.done[0]:
cur_episode += 1
reward_sum += self.reward[0]
length_sum += self.length[0]
self.reward, self.length = 0., np.zeros(self.env_num)
self.data.state = Batch()
obs_next = self._make_batch(self.env.reset())
if self.preprocess_fn:
obs_next = self.preprocess_fn(obs=obs_next).get(
'obs', obs_next)
self.data.obs_next = obs_next
if n_episode != 0 and cur_episode >= n_episode:
break
if n_step != 0 and cur_step >= n_step:
break
self.data.obs = self.data.obs_next
self.data.obs = self.data.obs_next
# generate the statistics # generate the statistics
cur_episode = sum(cur_episode) episode_count = sum(episode_count)
duration = max(time.time() - start_time, 1e-9) duration = max(time.time() - start_time, 1e-9)
self.step_speed.add(cur_step / duration) self.collect_step += step_count
self.episode_speed.add(cur_episode / duration) self.collect_episode += episode_count
self.collect_step += cur_step
self.collect_episode += cur_episode
self.collect_time += duration self.collect_time += duration
if isinstance(n_episode, list): # average reward across the number of episodes
n_episode = np.sum(n_episode) reward_avg = reward_total / episode_count
else: if np.asanyarray(reward_avg).size > 1: # non-scalar reward_avg
n_episode = max(cur_episode, 1) reward_avg = self._rew_metric(reward_avg)
reward_sum /= n_episode
if np.asanyarray(reward_sum).size > 1: # non-scalar reward_sum
reward_sum = self._rew_metric(reward_sum)
return { return {
'n/ep': cur_episode, 'n/ep': episode_count,
'n/st': cur_step, 'n/st': step_count,
'v/st': self.step_speed.get(), 'v/st': step_count / duration,
'v/ep': self.episode_speed.get(), 'v/ep': episode_count / duration,
'rew': reward_sum, 'rew': reward_avg,
'len': length_sum / n_episode, 'len': step_count / episode_count,
} }
def sample(self, batch_size: int) -> Batch: def sample(self, batch_size: int) -> Batch:


@@ -23,7 +23,6 @@ def offpolicy_trainer(
         test_fn: Optional[Callable[[int], None]] = None,
         stop_fn: Optional[Callable[[float], bool]] = None,
         save_fn: Optional[Callable[[BasePolicy], None]] = None,
-        log_fn: Optional[Callable[[dict], None]] = None,
         writer: Optional[SummaryWriter] = None,
         log_interval: int = 1,
         verbose: bool = True,
@@ -61,7 +60,6 @@ def offpolicy_trainer(
     :param function stop_fn: a function receives the average undiscounted
         returns of the testing result, return a boolean which indicates whether
        reaching the goal.
-    :param function log_fn: a function receives env info for logging.
     :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
         SummaryWriter.
     :param int log_interval: the log interval of the writer.
@@ -83,8 +81,7 @@ def offpolicy_trainer(
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
-                result = train_collector.collect(n_step=collect_per_step,
-                                                 log_fn=log_fn)
+                result = train_collector.collect(n_step=collect_per_step)
                 data = {}
                 if test_in_train and stop_fn and stop_fn(result['rew']):
                     test_result = test_episode(


@@ -23,7 +23,6 @@ def onpolicy_trainer(
         test_fn: Optional[Callable[[int], None]] = None,
         stop_fn: Optional[Callable[[float], bool]] = None,
        save_fn: Optional[Callable[[BasePolicy], None]] = None,
-        log_fn: Optional[Callable[[dict], None]] = None,
        writer: Optional[SummaryWriter] = None,
        log_interval: int = 1,
        verbose: bool = True,
@@ -62,7 +61,6 @@ def onpolicy_trainer(
     :param function stop_fn: a function receives the average undiscounted
         returns of the testing result, return a boolean which indicates whether
        reaching the goal.
-    :param function log_fn: a function receives env info for logging.
     :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
         SummaryWriter.
     :param int log_interval: the log interval of the writer.
@@ -84,8 +82,7 @@ def onpolicy_trainer(
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
-                result = train_collector.collect(n_episode=collect_per_step,
-                                                 log_fn=log_fn)
+                result = train_collector.collect(n_episode=collect_per_step)
                 data = {}
                 if test_in_train and stop_fn and stop_fn(result['rew']):
                     test_result = test_episode(