From bfeffe1f977d16d6b26c638e1eca8280ad53e2e5 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Thu, 23 Jul 2020 16:40:53 +0800
Subject: [PATCH] unify single-env and multi-env in collector (#157)

Unify the implementation with multi-environments (wrap a single environment
in a multi-environment with one env) to greatly simplify the code.

This changes the behavior of the single-environment case. Prior to this PR,
collector.collect(n_step=n) on a single environment would step exactly n
steps. After this PR, it steps through whole episodes until the number of
collected steps is at least n. That is to say, collectors now always collect
full episodes.
---
 docs/tutorials/concepts.rst   |   2 +-
 docs/tutorials/dqn.rst        |   2 +-
 test/base/test_collector.py   |  87 ++++++++------
 tianshou/data/collector.py    | 220 ++++++++++++----------------------
 tianshou/trainer/offpolicy.py |   5 +-
 tianshou/trainer/onpolicy.py  |   5 +-
 6 files changed, 128 insertions(+), 193 deletions(-)

diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst
index 1ba271e..3f033ea 100644
--- a/docs/tutorials/concepts.rst
+++ b/docs/tutorials/concepts.rst
@@ -130,7 +130,7 @@ In short, :class:`~tianshou.data.Collector` has two main methods:
 * :meth:`~tianshou.data.Collector.collect`: let the policy perform (at least) a specified number of step ``n_step`` or episode ``n_episode`` and store the data in the replay buffer;
 * :meth:`~tianshou.data.Collector.sample`: sample a data batch from replay buffer; it will call :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the final batch data.
 
-Why do we mention **at least** here? For a single environment, the collector will finish exactly ``n_step`` or ``n_episode``. However, for multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically.
+Why do we mention **at least** here? For multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically.
 
 The solution is to add some cache buffers inside the collector. Once collecting **a full episode of trajectory**, it will move the stored data from the cache buffer to the main buffer.
 To satisfy this condition, the collector will interact with environments that may exceed the given step number or episode number.
diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst
index cd663e5..9cbb243 100644
--- a/docs/tutorials/dqn.rst
+++ b/docs/tutorials/dqn.rst
@@ -179,7 +179,7 @@ Train a Policy with Customized Codes
 Tianshou supports user-defined training code. Here is the code snippet:
 ::
 
-    # pre-collect 5000 frames with random action before training
+    # pre-collect at least 5000 frames with random action before training
     policy.set_eps(1)
     train_collector.collect(n_step=5000)
 
diff --git a/test/base/test_collector.py b/test/base/test_collector.py
index ead017a..9fa37b6 100644
--- a/test/base/test_collector.py
+++ b/test/base/test_collector.py
@@ -25,29 +25,39 @@ class MyPolicy(BasePolicy):
         pass
 
 
-def preprocess_fn(**kwargs):
-    # modify info before adding into the buffer
-    # if info is not provided from env, it will be a ``Batch()``.
- if not kwargs.get('info', Batch()).is_empty(): - n = len(kwargs['obs']) - info = kwargs['info'] - for i in range(n): - info[i].update(rew=kwargs['rew'][i]) - return {'info': info} - # or: return Batch(info=info) - else: - return Batch() - - -class Logger(object): +class Logger: def __init__(self, writer): self.cnt = 0 self.writer = writer - def log(self, info): - self.writer.add_scalar( - 'key', np.mean(info['key']), global_step=self.cnt) - self.cnt += 1 + def preprocess_fn(self, **kwargs): + # modify info before adding into the buffer, and recorded into tfb + # if info is not provided from env, it will be a ``Batch()``. + if not kwargs.get('info', Batch()).is_empty(): + n = len(kwargs['obs']) + info = kwargs['info'] + for i in range(n): + info[i].update(rew=kwargs['rew'][i]) + self.writer.add_scalar('key', np.mean( + info['key']), global_step=self.cnt) + self.cnt += 1 + return Batch(info=info) + # or: return {'info': info} + else: + return Batch() + + @staticmethod + def single_preprocess_fn(**kwargs): + # same as above, without tfb + if not kwargs.get('info', Batch()).is_empty(): + n = len(kwargs['obs']) + info = kwargs['info'] + for i in range(n): + info[i].update(rew=kwargs['rew'][i]) + return Batch(info=info) + # or: return {'info': info} + else: + return Batch() def test_collector(): @@ -60,16 +70,16 @@ def test_collector(): policy = MyPolicy() env = env_fns[0]() c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False), - preprocess_fn) - c0.collect(n_step=3, log_fn=logger.log) - assert np.allclose(c0.buffer.obs[:3], [0, 1, 0]) - assert np.allclose(c0.buffer[:3].obs_next, [1, 2, 1]) - c0.collect(n_episode=3, log_fn=logger.log) - assert np.allclose(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1]) - assert np.allclose(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2]) + logger.preprocess_fn) + c0.collect(n_step=3) + assert np.allclose(c0.buffer.obs[:4], [0, 1, 0, 1]) + assert np.allclose(c0.buffer[:4].obs_next, [1, 2, 1, 2]) + c0.collect(n_episode=3) + assert np.allclose(c0.buffer.obs[:10], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]) + assert np.allclose(c0.buffer[:10].obs_next, [1, 2, 1, 2, 1, 2, 1, 2, 1, 2]) c0.collect(n_step=3, random=True) c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False), - preprocess_fn) + logger.preprocess_fn) c1.collect(n_step=6) assert np.allclose(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3]) assert np.allclose(c1.buffer[:11].obs_next, @@ -80,7 +90,7 @@ def test_collector(): [1, 2, 3, 4, 5, 1, 2, 1, 2, 3]) c1.collect(n_episode=3, random=True) c2 = Collector(policy, dum, ReplayBuffer(size=100, ignore_obs_next=False), - preprocess_fn) + logger.preprocess_fn) c2.collect(n_episode=[1, 2, 2, 2]) assert np.allclose(c2.buffer.obs_next[:26], [ 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5, @@ -96,13 +106,15 @@ def test_collector(): def test_collector_with_dict_state(): env = MyTestEnv(size=5, sleep=0, dict_state=True) policy = MyPolicy(dict_state=True) - c0 = Collector(policy, env, ReplayBuffer(size=100), preprocess_fn) + c0 = Collector(policy, env, ReplayBuffer(size=100), + Logger.single_preprocess_fn) c0.collect(n_step=3) - c0.collect(n_episode=3) + c0.collect(n_episode=2) env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5]] envs = VectorEnv(env_fns) - c1 = Collector(policy, envs, ReplayBuffer(size=100), preprocess_fn) + c1 = Collector(policy, envs, ReplayBuffer(size=100), + Logger.single_preprocess_fn) c1.collect(n_step=10) c1.collect(n_episode=[2, 1, 1, 2]) batch = c1.sample(10) @@ -113,7 +125,7 
@@ def test_collector_with_dict_state(): 0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.]) c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), - preprocess_fn) + Logger.single_preprocess_fn) c2.collect(n_episode=[0, 0, 0, 10]) batch = c2.sample(10) print(batch['obs_next']['index']) @@ -125,16 +137,17 @@ def test_collector_with_ma(): env = MyTestEnv(size=5, sleep=0, ma_rew=4) policy = MyPolicy() c0 = Collector(policy, env, ReplayBuffer(size=100), - preprocess_fn, reward_metric=reward_metric) + Logger.single_preprocess_fn, reward_metric=reward_metric) + # n_step=3 will collect a full episode r = c0.collect(n_step=3)['rew'] - assert np.asanyarray(r).size == 1 and r == 0. - r = c0.collect(n_episode=3)['rew'] + assert np.asanyarray(r).size == 1 and r == 4. + r = c0.collect(n_episode=2)['rew'] assert np.asanyarray(r).size == 1 and r == 4. env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]] envs = VectorEnv(env_fns) c1 = Collector(policy, envs, ReplayBuffer(size=100), - preprocess_fn, reward_metric=reward_metric) + Logger.single_preprocess_fn, reward_metric=reward_metric) r = c1.collect(n_step=10)['rew'] assert np.asanyarray(r).size == 1 and r == 4. r = c1.collect(n_episode=[2, 1, 1, 2])['rew'] @@ -153,7 +166,7 @@ def test_collector_with_ma(): assert np.allclose(c0.buffer[:len(c0.buffer)].rew, [[x] * 4 for x in rew]) c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), - preprocess_fn, reward_metric=reward_metric) + Logger.single_preprocess_fn, reward_metric=reward_metric) r = c2.collect(n_episode=[0, 0, 0, 10])['rew'] assert np.asanyarray(r).size == 1 and r == 4. batch = c2.sample(10) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 98c62da..fa5108b 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -5,8 +5,7 @@ import warnings import numpy as np from typing import Any, Dict, List, Union, Optional, Callable -from tianshou.utils import MovAvg -from tianshou.env import BaseVectorEnv +from tianshou.env import BaseVectorEnv, VectorEnv from tianshou.policy import BasePolicy from tianshou.exploration import BaseNoise from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy @@ -21,14 +20,10 @@ class Collector(object): :param env: a ``gym.Env`` environment or an instance of the :class:`~tianshou.env.BaseVectorEnv` class. :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` - class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to - ``None``, it will automatically assign a small-size - :class:`~tianshou.data.ReplayBuffer`. + class. If set to ``None`` (testing phase), it will not store the data. :param function preprocess_fn: a function called before the data has been added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults to ``None``. - :param int stat_size: for the moving average of recording speed, defaults - to 100. :param BaseNoise action_noise: add a noise to continuous action. Normally a policy already has a noise param for exploration in training phase, so this is recommended to use in test collector for some purpose. 
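To make the new semantics concrete, here is a minimal usage sketch. It is not part of the patch itself: it assumes a ``policy`` object is already defined and uses arbitrary buffer sizes. A single ``gym.Env`` is now wrapped into a one-env ``VectorEnv`` internally, omitting the buffer (the testing phase) skips storage, and only full episodes are collected, so the reported step count can exceed the request::

    import gym
    from tianshou.env import VectorEnv
    from tianshou.data import Collector, ReplayBuffer

    # the same construction path handles a single env and vectorized envs
    train_collector = Collector(policy, gym.make('CartPole-v0'),
                                buffer=ReplayBuffer(size=20000))
    test_envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
    test_collector = Collector(policy, test_envs)  # buffer=None: nothing is stored

    # full episodes only, so the actual step count may exceed the request
    result = train_collector.collect(n_step=5000)
    assert result['n/st'] >= 5000
    print(result['n/ep'], result['rew'])  # episodes collected and their mean reward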
@@ -56,12 +51,9 @@ class Collector(object): # the collector supports vectorized environments as well envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)]) - buffers = [ReplayBuffer(size=5000) for _ in range(3)] - # you can also pass a list of replay buffer to collector, for multi-env - # collector = Collector(policy, envs, buffer=buffers) collector = Collector(policy, envs, buffer=replay_buffer) - # collect at least 3 episodes + # collect 3 episodes collector.collect(n_episode=3) # collect 1 episode for the first env, 3 for the third env collector.collect(n_episode=[1, 0, 3]) @@ -81,9 +73,10 @@ class Collector(object): # clear the buffer collector.reset_buffer() - For the scenario of collecting data from multiple environments to a single - buffer, the cache buffers will turn on automatically. It may return the - data more than the given limitation. + Collected data always consist of full episodes. So if only ``n_step`` + argument is give, the collector may return the data more than the + ``n_step`` limitation. Same as ``n_episode`` for the multiple environment + case. .. note:: @@ -95,28 +88,22 @@ class Collector(object): env: Union[gym.Env, BaseVectorEnv], buffer: Optional[ReplayBuffer] = None, preprocess_fn: Callable[[Any], Union[dict, Batch]] = None, - stat_size: Optional[int] = 100, action_noise: Optional[BaseNoise] = None, reward_metric: Optional[Callable[[np.ndarray], float]] = None, ) -> None: super().__init__() + if not isinstance(env, BaseVectorEnv): + env = VectorEnv([lambda: env]) self.env = env - self.env_num = 1 + self.env_num = len(env) + # need cache buffers before storing in the main buffer + self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)] self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0 self.buffer = buffer self.policy = policy self.preprocess_fn = preprocess_fn self.process_fn = policy.process_fn - self._multi_env = isinstance(env, BaseVectorEnv) - # need multiple cache buffers only if storing in one buffer - self._cached_buf = [] - if self._multi_env: - self.env_num = len(env) - self._cached_buf = [ListReplayBuffer() - for _ in range(self.env_num)] - self.stat_size = stat_size self._action_noise = action_noise - self._rew_metric = reward_metric or Collector._default_rew_metric self.reset() @@ -135,8 +122,6 @@ class Collector(object): obs_next={}, policy={}) self.reset_env() self.reset_buffer() - self.step_speed = MovAvg(self.stat_size) - self.episode_speed = MovAvg(self.stat_size) self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0 if self._action_noise is not None: self._action_noise.reset() @@ -155,13 +140,9 @@ class Collector(object): buffers (if need). """ obs = self.env.reset() - if not self._multi_env: - obs = self._make_batch(obs) if self.preprocess_fn: obs = self.preprocess_fn(obs=obs).get('obs', obs) self.data.obs = obs - self.reward = 0. 
# will be specified when the first data is ready - self.length = np.zeros(self.env_num) for b in self._cached_buf: b.reset() @@ -177,13 +158,6 @@ class Collector(object): """Close the environment(s).""" self.env.close() - def _make_batch(self, data: Any) -> np.ndarray: - """Return [data].""" - if isinstance(data, np.ndarray): - return data[None] - else: - return np.array([data]) - def _reset_state(self, id: Union[int, List[int]]) -> None: """Reset self.data.state[id].""" state = self.data.state # it is a reference @@ -195,24 +169,22 @@ class Collector(object): state.empty_(id) def collect(self, - n_step: int = 0, - n_episode: Union[int, List[int]] = 0, + n_step: Optional[int] = None, + n_episode: Optional[Union[int, List[int]]] = None, random: bool = False, render: Optional[float] = None, - log_fn: Optional[Callable[[dict], None]] = None ) -> Dict[str, float]: """Collect a specified number of step or episode. :param int n_step: how many steps you want to collect. - :param n_episode: how many episodes you want to collect (in each - environment). - :type n_episode: int or list + :param n_episode: how many episodes you want to collect. If it is an + int, it means to collect at lease ``n_episode`` episodes; if it is + a list, it means to collect exactly ``n_episode[i]`` episodes in + the i-th environment :param bool random: whether to use random policy for collecting data, defaults to ``False``. :param float render: the sleep time between rendering consecutive frames, defaults to ``None`` (no rendering). - :param function log_fn: a function which receives env info, typically - for tensorboard logging. .. note:: @@ -228,15 +200,15 @@ class Collector(object): * ``rew`` the mean reward over collected episodes. * ``len`` the mean length over collected episodes. """ - if not self._multi_env: - n_episode = np.sum(n_episode) - start_time = time.time() - assert sum([(n_step != 0), (n_episode != 0)]) == 1, \ + assert (n_step and not n_episode) or (not n_step and n_episode), \ "One and only one collection number specification is permitted!" - cur_step, cur_episode = 0, np.zeros(self.env_num) - reward_sum, length_sum = 0., 0 + start_time = time.time() + step_count = 0 + # episode of each environment + episode_count = np.zeros(self.env_num) + reward_total = 0.0 while True: - if cur_step >= 100000 and cur_episode.sum() == 0: + if step_count >= 100000 and episode_count.sum() == 0: warnings.warn( 'There are already many steps in an episode. 
' 'You should add a time limitation to your environment!', @@ -250,11 +222,8 @@ class Collector(object): # calculate the next action if random: - action_space = self.env.action_space - if isinstance(action_space, list): - result = Batch(act=[a.sample() for a in action_space]) - else: - result = Batch(act=self._make_batch(action_space.sample())) + result = Batch( + act=[a.sample() for a in self.env.action_space]) else: with torch.no_grad(): result = self.policy(self.data, last_state) @@ -274,110 +243,69 @@ class Collector(object): self.data.act += self._action_noise(self.data.act.shape) # step in env - obs_next, rew, done, info = self.env.step( - self.data.act if self._multi_env else self.data.act[0]) + obs_next, rew, done, info = self.env.step(self.data.act) # move data to self.data - if not self._multi_env: - obs_next = self._make_batch(obs_next) - rew = self._make_batch(rew) - done = self._make_batch(done) - info = self._make_batch(info) - self.data.obs_next = obs_next - self.data.rew = rew - self.data.done = done - self.data.info = info + self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) - if log_fn: - log_fn(info if self._multi_env else info[0]) if render: self.render() - if render > 0: - time.sleep(render) + time.sleep(render) # add data into the buffer - self.length += 1 - self.reward += self.data.rew if self.preprocess_fn: result = self.preprocess_fn(**self.data) self.data.update(result) - if self._multi_env: # cache_buffer branch - for i in range(self.env_num): - self._cached_buf[i].add(**self.data[i]) - if self.data.done[i]: - if n_step != 0 or np.isscalar(n_episode) or \ - cur_episode[i] < n_episode[i]: - cur_episode[i] += 1 - reward_sum += self.reward[i] - length_sum += self.length[i] - if self._cached_buf: - cur_step += len(self._cached_buf[i]) - if self.buffer is not None: - self.buffer.update(self._cached_buf[i]) - self.reward[i], self.length[i] = 0., 0 - if self._cached_buf: - self._cached_buf[i].reset() - self._reset_state(i) - obs_next = self.data.obs_next - if sum(self.data.done): - env_ind = np.where(self.data.done)[0] - obs_reset = self.env.reset(env_ind) - if self.preprocess_fn: - obs_next[env_ind] = self.preprocess_fn( - obs=obs_reset).get('obs', obs_reset) - else: - obs_next[env_ind] = obs_reset - self.data.obs_next = obs_next - if n_episode != 0: - if isinstance(n_episode, list) and \ - (cur_episode >= np.array(n_episode)).all() or \ - np.isscalar(n_episode) and \ - cur_episode.sum() >= n_episode: - break - else: # single buffer, without cache_buffer - if self.buffer is not None: - self.buffer.add(**self.data[0]) - cur_step += 1 - if self.data.done[0]: - cur_episode += 1 - reward_sum += self.reward[0] - length_sum += self.length[0] - self.reward, self.length = 0., np.zeros(self.env_num) - self.data.state = Batch() - obs_next = self._make_batch(self.env.reset()) - if self.preprocess_fn: - obs_next = self.preprocess_fn(obs=obs_next).get( - 'obs', obs_next) - self.data.obs_next = obs_next - if n_episode != 0 and cur_episode >= n_episode: + for i in range(self.env_num): + self._cached_buf[i].add(**self.data[i]) + if self.data.done[i]: + if n_step or np.isscalar(n_episode) or \ + episode_count[i] < n_episode[i]: + episode_count[i] += 1 + reward_total += np.sum(self._cached_buf[i].rew, axis=0) + step_count += len(self._cached_buf[i]) + if self.buffer is not None: + self.buffer.update(self._cached_buf[i]) + self._cached_buf[i].reset() + self._reset_state(i) + obs_next = self.data.obs_next + if sum(self.data.done): + env_ind = np.where(self.data.done)[0] + 
obs_reset = self.env.reset(env_ind) + if self.preprocess_fn: + obs_next[env_ind] = self.preprocess_fn( + obs=obs_reset).get('obs', obs_reset) + else: + obs_next[env_ind] = obs_reset + self.data.obs = obs_next + if n_step: + if step_count >= n_step: + break + else: + if isinstance(n_episode, int) and \ + episode_count.sum() >= n_episode: + break + if isinstance(n_episode, list) and \ + (episode_count >= n_episode).all(): break - if n_step != 0 and cur_step >= n_step: - break - self.data.obs = self.data.obs_next - self.data.obs = self.data.obs_next # generate the statistics - cur_episode = sum(cur_episode) + episode_count = sum(episode_count) duration = max(time.time() - start_time, 1e-9) - self.step_speed.add(cur_step / duration) - self.episode_speed.add(cur_episode / duration) - self.collect_step += cur_step - self.collect_episode += cur_episode + self.collect_step += step_count + self.collect_episode += episode_count self.collect_time += duration - if isinstance(n_episode, list): - n_episode = np.sum(n_episode) - else: - n_episode = max(cur_episode, 1) - reward_sum /= n_episode - if np.asanyarray(reward_sum).size > 1: # non-scalar reward_sum - reward_sum = self._rew_metric(reward_sum) + # average reward across the number of episodes + reward_avg = reward_total / episode_count + if np.asanyarray(reward_avg).size > 1: # non-scalar reward_avg + reward_avg = self._rew_metric(reward_avg) return { - 'n/ep': cur_episode, - 'n/st': cur_step, - 'v/st': self.step_speed.get(), - 'v/ep': self.episode_speed.get(), - 'rew': reward_sum, - 'len': length_sum / n_episode, + 'n/ep': episode_count, + 'n/st': step_count, + 'v/st': step_count / duration, + 'v/ep': episode_count / duration, + 'rew': reward_avg, + 'len': step_count / episode_count, } def sample(self, batch_size: int) -> Batch: diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index edcf0fd..408d4e7 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -23,7 +23,6 @@ def offpolicy_trainer( test_fn: Optional[Callable[[int], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None, - log_fn: Optional[Callable[[dict], None]] = None, writer: Optional[SummaryWriter] = None, log_interval: int = 1, verbose: bool = True, @@ -61,7 +60,6 @@ def offpolicy_trainer( :param function stop_fn: a function receives the average undiscounted returns of the testing result, return a boolean which indicates whether reaching the goal. - :param function log_fn: a function receives env info for logging. :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter. :param int log_interval: the log interval of the writer. 
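Since ``log_fn`` is removed from the trainers, per-transition logging of env ``info`` can instead live in the collector's ``preprocess_fn`` hook, as the ``Logger`` class added in ``test/base/test_collector.py`` demonstrates. A rough sketch, assuming a tensorboard ``writer`` and an env that reports a ``'key'`` field in ``info`` (the ``InfoLogger`` name, the ``'key'`` field, and the commented-out construction are placeholders, not part of this patch)::

    import numpy as np
    from tianshou.data import Batch

    class InfoLogger:
        def __init__(self, writer):
            self.cnt = 0
            self.writer = writer

        def preprocess_fn(self, **kwargs):
            # called on every collected transition before it enters the buffer
            if not kwargs.get('info', Batch()).is_empty():
                self.writer.add_scalar(
                    'key', np.mean(kwargs['info']['key']), global_step=self.cnt)
                self.cnt += 1
            return Batch()  # empty Batch: store the collected data unchanged

    # train_collector = Collector(policy, train_envs, ReplayBuffer(size=20000),
    #                             preprocess_fn=InfoLogger(writer).preprocess_fn)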
@@ -83,8 +81,7 @@ def offpolicy_trainer( with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}', **tqdm_config) as t: while t.n < t.total: - result = train_collector.collect(n_step=collect_per_step, - log_fn=log_fn) + result = train_collector.collect(n_step=collect_per_step) data = {} if test_in_train and stop_fn and stop_fn(result['rew']): test_result = test_episode( diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index b0d68ff..8f77dd1 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -23,7 +23,6 @@ def onpolicy_trainer( test_fn: Optional[Callable[[int], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None, - log_fn: Optional[Callable[[dict], None]] = None, writer: Optional[SummaryWriter] = None, log_interval: int = 1, verbose: bool = True, @@ -62,7 +61,6 @@ def onpolicy_trainer( :param function stop_fn: a function receives the average undiscounted returns of the testing result, return a boolean which indicates whether reaching the goal. - :param function log_fn: a function receives env info for logging. :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter. :param int log_interval: the log interval of the writer. @@ -84,8 +82,7 @@ def onpolicy_trainer( with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}', **tqdm_config) as t: while t.n < t.total: - result = train_collector.collect(n_episode=collect_per_step, - log_fn=log_fn) + result = train_collector.collect(n_episode=collect_per_step) data = {} if test_in_train and stop_fn and stop_fn(result['rew']): test_result = test_episode(