unify single-env and multi-env in collector (#157)

Unify the implementation with multi-environments (wrap a single environment as a multi-environment containing one env) to greatly simplify the code.

This changes the behavior of single-environment collection.
Prior to this PR, for a single environment, collector.collect(n_step=n) would step exactly n steps.
After this PR, for a single environment, collector.collect(n_step=n) steps through whole episodes until the total number of collected steps reaches at least n.

That is to say, collectors now always collect full episodes.
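A minimal sketch of the new semantics (``policy`` is assumed to be an already-constructed tianshou policy, and CartPole is only a stand-in environment):
::
    import gym
    from tianshou.data import Collector, ReplayBuffer

    env = gym.make('CartPole-v0')
    collector = Collector(policy, env, ReplayBuffer(size=10000))
    # before this PR: exactly 3 transitions would be stored
    # after this PR: whole episodes are collected until at least 3 transitions
    # are stored, so the buffer may end up with more than 3 transitions
    result = collector.collect(n_step=3)
    print(result['n/st'])  # >= 3, the sum of the collected episode lengths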
youkaichao 2020-07-23 16:40:53 +08:00 committed by GitHub
parent 352a518399
commit bfeffe1f97
6 changed files with 128 additions and 193 deletions


@@ -130,7 +130,7 @@ In short, :class:`~tianshou.data.Collector` has two main methods:
* :meth:`~tianshou.data.Collector.collect`: let the policy perform (at least) a specified number of step ``n_step`` or episode ``n_episode`` and store the data in the replay buffer;
* :meth:`~tianshou.data.Collector.sample`: sample a data batch from replay buffer; it will call :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the final batch data.
Why do we mention **at least** here? For a single environment, the collector will finish exactly ``n_step`` or ``n_episode``. However, for multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically.
Why do we mention **at least** here? For multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically.
The solution is to add some cache buffers inside the collector. Once collecting **a full episode of trajectory**, it will move the stored data from the cache buffer to the main buffer. To satisfy this condition, the collector will interact with environments that may exceed the given step number or episode number.
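A rough illustration of this behaviour (a sketch, assuming three CartPole environments and an already-constructed ``policy``):
::
    import gym
    from tianshou.data import Collector, ReplayBuffer
    from tianshou.env import VectorEnv

    envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
    collector = Collector(policy, envs, ReplayBuffer(size=20000))
    # each environment keeps its transitions in a private cache buffer and only
    # flushes a complete episode into the main buffer, so the number of stored
    # steps may exceed the request
    result = collector.collect(n_step=100)
    assert result['n/st'] >= 100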


@@ -179,7 +179,7 @@ Train a Policy with Customized Codes
Tianshou supports user-defined training code. Here is the code snippet:
::
# pre-collect 5000 frames with random action before training
# pre-collect at least 5000 frames with random action before training
policy.set_eps(1)
train_collector.collect(n_step=5000)


@@ -25,29 +25,39 @@ class MyPolicy(BasePolicy):
pass
def preprocess_fn(**kwargs):
# modify info before adding into the buffer
class Logger:
def __init__(self, writer):
self.cnt = 0
self.writer = writer
def preprocess_fn(self, **kwargs):
# modify info before adding it into the buffer, and record it into tfb
# if info is not provided from env, it will be a ``Batch()``.
if not kwargs.get('info', Batch()).is_empty():
n = len(kwargs['obs'])
info = kwargs['info']
for i in range(n):
info[i].update(rew=kwargs['rew'][i])
return {'info': info}
# or: return Batch(info=info)
self.writer.add_scalar('key', np.mean(
info['key']), global_step=self.cnt)
self.cnt += 1
return Batch(info=info)
# or: return {'info': info}
else:
return Batch()
class Logger(object):
def __init__(self, writer):
self.cnt = 0
self.writer = writer
def log(self, info):
self.writer.add_scalar(
'key', np.mean(info['key']), global_step=self.cnt)
self.cnt += 1
@staticmethod
def single_preprocess_fn(**kwargs):
# same as above, without tfb
if not kwargs.get('info', Batch()).is_empty():
n = len(kwargs['obs'])
info = kwargs['info']
for i in range(n):
info[i].update(rew=kwargs['rew'][i])
return Batch(info=info)
# or: return {'info': info}
else:
return Batch()
def test_collector():
@@ -60,16 +70,16 @@ def test_collector():
policy = MyPolicy()
env = env_fns[0]()
c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False),
preprocess_fn)
c0.collect(n_step=3, log_fn=logger.log)
assert np.allclose(c0.buffer.obs[:3], [0, 1, 0])
assert np.allclose(c0.buffer[:3].obs_next, [1, 2, 1])
c0.collect(n_episode=3, log_fn=logger.log)
assert np.allclose(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
assert np.allclose(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
logger.preprocess_fn)
c0.collect(n_step=3)
assert np.allclose(c0.buffer.obs[:4], [0, 1, 0, 1])
assert np.allclose(c0.buffer[:4].obs_next, [1, 2, 1, 2])
c0.collect(n_episode=3)
assert np.allclose(c0.buffer.obs[:10], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
assert np.allclose(c0.buffer[:10].obs_next, [1, 2, 1, 2, 1, 2, 1, 2, 1, 2])
c0.collect(n_step=3, random=True)
c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
preprocess_fn)
logger.preprocess_fn)
c1.collect(n_step=6)
assert np.allclose(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3])
assert np.allclose(c1.buffer[:11].obs_next,
@@ -80,7 +90,7 @@ def test_collector():
[1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
c1.collect(n_episode=3, random=True)
c2 = Collector(policy, dum, ReplayBuffer(size=100, ignore_obs_next=False),
preprocess_fn)
logger.preprocess_fn)
c2.collect(n_episode=[1, 2, 2, 2])
assert np.allclose(c2.buffer.obs_next[:26], [
1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,
@@ -96,13 +106,15 @@ def test_collector():
def test_collector_with_dict_state():
env = MyTestEnv(size=5, sleep=0, dict_state=True)
policy = MyPolicy(dict_state=True)
c0 = Collector(policy, env, ReplayBuffer(size=100), preprocess_fn)
c0 = Collector(policy, env, ReplayBuffer(size=100),
Logger.single_preprocess_fn)
c0.collect(n_step=3)
c0.collect(n_episode=3)
c0.collect(n_episode=2)
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True)
for i in [2, 3, 4, 5]]
envs = VectorEnv(env_fns)
c1 = Collector(policy, envs, ReplayBuffer(size=100), preprocess_fn)
c1 = Collector(policy, envs, ReplayBuffer(size=100),
Logger.single_preprocess_fn)
c1.collect(n_step=10)
c1.collect(n_episode=[2, 1, 1, 2])
batch = c1.sample(10)
@@ -113,7 +125,7 @@ def test_collector_with_dict_state():
0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0.,
1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.])
c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
preprocess_fn)
Logger.single_preprocess_fn)
c2.collect(n_episode=[0, 0, 0, 10])
batch = c2.sample(10)
print(batch['obs_next']['index'])
@@ -125,16 +137,17 @@ def test_collector_with_ma():
env = MyTestEnv(size=5, sleep=0, ma_rew=4)
policy = MyPolicy()
c0 = Collector(policy, env, ReplayBuffer(size=100),
preprocess_fn, reward_metric=reward_metric)
Logger.single_preprocess_fn, reward_metric=reward_metric)
# n_step=3 will collect a full episode
r = c0.collect(n_step=3)['rew']
assert np.asanyarray(r).size == 1 and r == 0.
r = c0.collect(n_episode=3)['rew']
assert np.asanyarray(r).size == 1 and r == 4.
r = c0.collect(n_episode=2)['rew']
assert np.asanyarray(r).size == 1 and r == 4.
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4)
for i in [2, 3, 4, 5]]
envs = VectorEnv(env_fns)
c1 = Collector(policy, envs, ReplayBuffer(size=100),
preprocess_fn, reward_metric=reward_metric)
Logger.single_preprocess_fn, reward_metric=reward_metric)
r = c1.collect(n_step=10)['rew']
assert np.asanyarray(r).size == 1 and r == 4.
r = c1.collect(n_episode=[2, 1, 1, 2])['rew']
@@ -153,7 +166,7 @@ def test_collector_with_ma():
assert np.allclose(c0.buffer[:len(c0.buffer)].rew,
[[x] * 4 for x in rew])
c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
preprocess_fn, reward_metric=reward_metric)
Logger.single_preprocess_fn, reward_metric=reward_metric)
r = c2.collect(n_episode=[0, 0, 0, 10])['rew']
assert np.asanyarray(r).size == 1 and r == 4.
batch = c2.sample(10)


@@ -5,8 +5,7 @@ import warnings
import numpy as np
from typing import Any, Dict, List, Union, Optional, Callable
from tianshou.utils import MovAvg
from tianshou.env import BaseVectorEnv
from tianshou.env import BaseVectorEnv, VectorEnv
from tianshou.policy import BasePolicy
from tianshou.exploration import BaseNoise
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
@@ -21,14 +20,10 @@ class Collector(object):
:param env: a ``gym.Env`` environment or an instance of the
:class:`~tianshou.env.BaseVectorEnv` class.
:param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to
``None``, it will automatically assign a small-size
:class:`~tianshou.data.ReplayBuffer`.
class. If set to ``None`` (testing phase), it will not store the data.
:param function preprocess_fn: a function called before the data has been
added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults
to ``None``.
:param int stat_size: for the moving average of recording speed, defaults
to 100.
:param BaseNoise action_noise: add a noise to continuous action. Normally
a policy already has a noise param for exploration in training phase,
so this is recommended to use in test collector for some purpose.
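The ``buffer=None`` and ``action_noise`` options above are typically combined in a test collector; a minimal sketch, assuming a continuous-action ``policy``, existing ``test_envs``, and that ``GaussianNoise`` from ``tianshou.exploration`` is the chosen noise:
::
    from tianshou.data import Collector
    from tianshou.exploration import GaussianNoise

    # buffer=None: collected test data is not stored;
    # a small Gaussian noise is added to the continuous actions
    test_collector = Collector(policy, test_envs, buffer=None,
                               action_noise=GaussianNoise(sigma=0.1))
    test_collector.collect(n_episode=10)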
@@ -56,12 +51,9 @@ class Collector(object):
# the collector supports vectorized environments as well
envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
buffers = [ReplayBuffer(size=5000) for _ in range(3)]
# you can also pass a list of replay buffer to collector, for multi-env
# collector = Collector(policy, envs, buffer=buffers)
collector = Collector(policy, envs, buffer=replay_buffer)
# collect at least 3 episodes
# collect 3 episodes
collector.collect(n_episode=3)
# collect 1 episode for the first env, 3 for the third env
collector.collect(n_episode=[1, 0, 3])
@@ -81,9 +73,10 @@ class Collector(object):
# clear the buffer
collector.reset_buffer()
For the scenario of collecting data from multiple environments to a single
buffer, the cache buffers will turn on automatically. It may return the
data more than the given limitation.
Collected data always consists of full episodes. So if only the ``n_step``
argument is given, the collector may return more data than the
``n_step`` limitation. The same holds for ``n_episode`` in the multiple
environment case.
.. note::
@@ -95,28 +88,22 @@ class Collector(object):
env: Union[gym.Env, BaseVectorEnv],
buffer: Optional[ReplayBuffer] = None,
preprocess_fn: Callable[[Any], Union[dict, Batch]] = None,
stat_size: Optional[int] = 100,
action_noise: Optional[BaseNoise] = None,
reward_metric: Optional[Callable[[np.ndarray], float]] = None,
) -> None:
super().__init__()
if not isinstance(env, BaseVectorEnv):
env = VectorEnv([lambda: env])
self.env = env
self.env_num = 1
self.env_num = len(env)
# need cache buffers before storing in the main buffer
self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
self.buffer = buffer
self.policy = policy
self.preprocess_fn = preprocess_fn
self.process_fn = policy.process_fn
self._multi_env = isinstance(env, BaseVectorEnv)
# need multiple cache buffers only if storing in one buffer
self._cached_buf = []
if self._multi_env:
self.env_num = len(env)
self._cached_buf = [ListReplayBuffer()
for _ in range(self.env_num)]
self.stat_size = stat_size
self._action_noise = action_noise
self._rew_metric = reward_metric or Collector._default_rew_metric
self.reset()
@@ -135,8 +122,6 @@ class Collector(object):
obs_next={}, policy={})
self.reset_env()
self.reset_buffer()
self.step_speed = MovAvg(self.stat_size)
self.episode_speed = MovAvg(self.stat_size)
self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
if self._action_noise is not None:
self._action_noise.reset()
@@ -155,13 +140,9 @@ class Collector(object):
buffers (if need).
"""
obs = self.env.reset()
if not self._multi_env:
obs = self._make_batch(obs)
if self.preprocess_fn:
obs = self.preprocess_fn(obs=obs).get('obs', obs)
self.data.obs = obs
self.reward = 0. # will be specified when the first data is ready
self.length = np.zeros(self.env_num)
for b in self._cached_buf:
b.reset()
@@ -177,13 +158,6 @@ class Collector(object):
"""Close the environment(s)."""
self.env.close()
def _make_batch(self, data: Any) -> np.ndarray:
"""Return [data]."""
if isinstance(data, np.ndarray):
return data[None]
else:
return np.array([data])
def _reset_state(self, id: Union[int, List[int]]) -> None:
"""Reset self.data.state[id]."""
state = self.data.state # it is a reference
@@ -195,24 +169,22 @@ class Collector(object):
state.empty_(id)
def collect(self,
n_step: int = 0,
n_episode: Union[int, List[int]] = 0,
n_step: Optional[int] = None,
n_episode: Optional[Union[int, List[int]]] = None,
random: bool = False,
render: Optional[float] = None,
log_fn: Optional[Callable[[dict], None]] = None
) -> Dict[str, float]:
"""Collect a specified number of step or episode.
:param int n_step: how many steps you want to collect.
:param n_episode: how many episodes you want to collect (in each
environment).
:type n_episode: int or list
:param n_episode: how many episodes you want to collect. If it is an
int, it means to collect at least ``n_episode`` episodes; if it is
a list, it means to collect exactly ``n_episode[i]`` episodes in
the i-th environment.
:param bool random: whether to use random policy for collecting data,
defaults to ``False``.
:param float render: the sleep time between rendering consecutive
frames, defaults to ``None`` (no rendering).
:param function log_fn: a function which receives env info, typically
for tensorboard logging.
.. note::
@@ -228,15 +200,15 @@ class Collector(object):
* ``rew`` the mean reward over collected episodes.
* ``len`` the mean length over collected episodes.
"""
if not self._multi_env:
n_episode = np.sum(n_episode)
start_time = time.time()
assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
assert (n_step and not n_episode) or (not n_step and n_episode), \
"One and only one collection number specification is permitted!"
cur_step, cur_episode = 0, np.zeros(self.env_num)
reward_sum, length_sum = 0., 0
start_time = time.time()
step_count = 0
# episode count of each environment
episode_count = np.zeros(self.env_num)
reward_total = 0.0
while True:
if cur_step >= 100000 and cur_episode.sum() == 0:
if step_count >= 100000 and episode_count.sum() == 0:
warnings.warn(
'There are already many steps in an episode. '
'You should add a time limitation to your environment!',
@@ -250,11 +222,8 @@ class Collector(object):
# calculate the next action
if random:
action_space = self.env.action_space
if isinstance(action_space, list):
result = Batch(act=[a.sample() for a in action_space])
else:
result = Batch(act=self._make_batch(action_space.sample()))
result = Batch(
act=[a.sample() for a in self.env.action_space])
else:
with torch.no_grad():
result = self.policy(self.data, last_state)
@@ -274,48 +243,29 @@ class Collector(object):
self.data.act += self._action_noise(self.data.act.shape)
# step in env
obs_next, rew, done, info = self.env.step(
self.data.act if self._multi_env else self.data.act[0])
obs_next, rew, done, info = self.env.step(self.data.act)
# move data to self.data
if not self._multi_env:
obs_next = self._make_batch(obs_next)
rew = self._make_batch(rew)
done = self._make_batch(done)
info = self._make_batch(info)
self.data.obs_next = obs_next
self.data.rew = rew
self.data.done = done
self.data.info = info
self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
if log_fn:
log_fn(info if self._multi_env else info[0])
if render:
self.render()
if render > 0:
time.sleep(render)
# add data into the buffer
self.length += 1
self.reward += self.data.rew
if self.preprocess_fn:
result = self.preprocess_fn(**self.data)
self.data.update(result)
if self._multi_env: # cache_buffer branch
for i in range(self.env_num):
self._cached_buf[i].add(**self.data[i])
if self.data.done[i]:
if n_step != 0 or np.isscalar(n_episode) or \
cur_episode[i] < n_episode[i]:
cur_episode[i] += 1
reward_sum += self.reward[i]
length_sum += self.length[i]
if self._cached_buf:
cur_step += len(self._cached_buf[i])
if n_step or np.isscalar(n_episode) or \
episode_count[i] < n_episode[i]:
episode_count[i] += 1
reward_total += np.sum(self._cached_buf[i].rew, axis=0)
step_count += len(self._cached_buf[i])
if self.buffer is not None:
self.buffer.update(self._cached_buf[i])
self.reward[i], self.length[i] = 0., 0
if self._cached_buf:
self._cached_buf[i].reset()
self._reset_state(i)
obs_next = self.data.obs_next
@@ -327,57 +277,35 @@ class Collector(object):
obs=obs_reset).get('obs', obs_reset)
else:
obs_next[env_ind] = obs_reset
self.data.obs_next = obs_next
if n_episode != 0:
self.data.obs = obs_next
if n_step:
if step_count >= n_step:
break
else:
if isinstance(n_episode, int) and \
episode_count.sum() >= n_episode:
break
if isinstance(n_episode, list) and \
(cur_episode >= np.array(n_episode)).all() or \
np.isscalar(n_episode) and \
cur_episode.sum() >= n_episode:
(episode_count >= n_episode).all():
break
else: # single buffer, without cache_buffer
if self.buffer is not None:
self.buffer.add(**self.data[0])
cur_step += 1
if self.data.done[0]:
cur_episode += 1
reward_sum += self.reward[0]
length_sum += self.length[0]
self.reward, self.length = 0., np.zeros(self.env_num)
self.data.state = Batch()
obs_next = self._make_batch(self.env.reset())
if self.preprocess_fn:
obs_next = self.preprocess_fn(obs=obs_next).get(
'obs', obs_next)
self.data.obs_next = obs_next
if n_episode != 0 and cur_episode >= n_episode:
break
if n_step != 0 and cur_step >= n_step:
break
self.data.obs = self.data.obs_next
self.data.obs = self.data.obs_next
# generate the statistics
cur_episode = sum(cur_episode)
episode_count = sum(episode_count)
duration = max(time.time() - start_time, 1e-9)
self.step_speed.add(cur_step / duration)
self.episode_speed.add(cur_episode / duration)
self.collect_step += cur_step
self.collect_episode += cur_episode
self.collect_step += step_count
self.collect_episode += episode_count
self.collect_time += duration
if isinstance(n_episode, list):
n_episode = np.sum(n_episode)
else:
n_episode = max(cur_episode, 1)
reward_sum /= n_episode
if np.asanyarray(reward_sum).size > 1: # non-scalar reward_sum
reward_sum = self._rew_metric(reward_sum)
# average reward across the number of episodes
reward_avg = reward_total / episode_count
if np.asanyarray(reward_avg).size > 1: # non-scalar reward_avg
reward_avg = self._rew_metric(reward_avg)
return {
'n/ep': cur_episode,
'n/st': cur_step,
'v/st': self.step_speed.get(),
'v/ep': self.episode_speed.get(),
'rew': reward_sum,
'len': length_sum / n_episode,
'n/ep': episode_count,
'n/st': step_count,
'v/st': step_count / duration,
'v/ep': episode_count / duration,
'rew': reward_avg,
'len': step_count / episode_count,
}
def sample(self, batch_size: int) -> Batch:
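With the ``MovAvg``-based speed statistics removed, the dictionary returned by ``collect`` now reports per-call rates; a sketch of how it might be read (``train_collector`` is assumed to exist):
::
    result = train_collector.collect(n_step=1000)
    result['n/ep']  # number of complete episodes collected in this call
    result['n/st']  # number of steps collected (>= 1000, full episodes only)
    result['v/st']  # steps per second of this call (no longer a moving average)
    result['v/ep']  # episodes per second of this call
    result['rew']   # mean episode reward (reward_metric applied if non-scalar)
    result['len']   # mean episode length, i.e. n/st divided by n/ep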


@@ -23,7 +23,6 @@ def offpolicy_trainer(
test_fn: Optional[Callable[[int], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
log_fn: Optional[Callable[[dict], None]] = None,
writer: Optional[SummaryWriter] = None,
log_interval: int = 1,
verbose: bool = True,
@@ -61,7 +60,6 @@ def offpolicy_trainer(
:param function stop_fn: a function receives the average undiscounted
returns of the testing result, return a boolean which indicates whether
reaching the goal.
:param function log_fn: a function receives env info for logging.
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
SummaryWriter.
:param int log_interval: the log interval of the writer.
@@ -83,8 +81,7 @@ def offpolicy_trainer(
with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
**tqdm_config) as t:
while t.n < t.total:
result = train_collector.collect(n_step=collect_per_step,
log_fn=log_fn)
result = train_collector.collect(n_step=collect_per_step)
data = {}
if test_in_train and stop_fn and stop_fn(result['rew']):
test_result = test_episode(


@@ -23,7 +23,6 @@ def onpolicy_trainer(
test_fn: Optional[Callable[[int], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
log_fn: Optional[Callable[[dict], None]] = None,
writer: Optional[SummaryWriter] = None,
log_interval: int = 1,
verbose: bool = True,
@@ -62,7 +61,6 @@ def onpolicy_trainer(
:param function stop_fn: a function receives the average undiscounted
returns of the testing result, return a boolean which indicates whether
reaching the goal.
:param function log_fn: a function receives env info for logging.
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
SummaryWriter.
:param int log_interval: the log interval of the writer.
@@ -84,8 +82,7 @@ def onpolicy_trainer(
with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
**tqdm_config) as t:
while t.n < t.total:
result = train_collector.collect(n_episode=collect_per_step,
log_fn=log_fn)
result = train_collector.collect(n_episode=collect_per_step)
data = {}
if test_in_train and stop_fn and stop_fn(result['rew']):
test_result = test_episode(
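Since both trainers drop the ``log_fn`` argument, per-step env-info logging is now wired through the collector's ``preprocess_fn`` instead; a sketch under that assumption, reusing the ``Logger`` helper from the test file above and assuming the usual trainer arguments are already defined:
::
    logger = Logger(writer)
    train_collector = Collector(policy, train_envs, ReplayBuffer(size=20000),
                                preprocess_fn=logger.preprocess_fn)
    result = offpolicy_trainer(
        policy, train_collector, test_collector, max_epoch, step_per_epoch,
        collect_per_step, episode_per_test, batch_size, writer=writer)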