From 26a6cca76e8c10353b898bc60ba9500df6613d5f Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 5 May 2024 15:56:06 +0200 Subject: [PATCH] Improved docstrings, added asserts to make mypy happy --- tianshou/data/collector.py | 6 ------ tianshou/trainer/base.py | 14 +++++++++----- tianshou/trainer/utils.py | 20 +++----------------- 3 files changed, 12 insertions(+), 28 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index cfc4b3d..b498f45 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -279,7 +279,6 @@ class BaseCollector(ABC): :param n_episode: how many episodes you want to collect. :param random: whether to use random policy for collecting data. :param render: the sleep time between rendering consecutive frames. - :param no_grad: whether to retain gradient in policy.forward(). :param reset_before_collect: whether to reset the environment before collecting data. (The collector needs the initial obs and info to function properly.) :param gym_reset_kwargs: extra keyword arguments to pass into the environment's @@ -343,11 +342,6 @@ class Collector(BaseCollector): # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case. # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration. # Only used in n_episode case. Then, R becomes R-S. - - # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy - # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on - # policy.deterministic_eval) - def __init__( self, policy: BasePolicy, diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 6e46374..ef9a154 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -2,7 +2,7 @@ import logging import time from abc import ABC, abstractmethod from collections import defaultdict, deque -from collections.abc import Callable +from collections.abc import Callable, Iterator from contextlib import contextmanager from dataclasses import asdict @@ -360,7 +360,10 @@ class BaseTrainer(ABC): self.logger.log_info_data(asdict(info_stat), self.epoch) # in case trainer is used with run(), epoch_stat will not be returned - epoch_stat: EpochStats = EpochStats( + assert ( + update_stat is not None + ), "Defined in the loop above, this shouldn't have happened and is likely a bug!" + return EpochStats( epoch=self.epoch, train_collect_stat=train_stat, test_collect_stat=test_stat, @@ -368,8 +371,6 @@ class BaseTrainer(ABC): info_stat=info_stat, ) - return epoch_stat - def test_step(self) -> tuple[CollectStats, bool]: """Perform one testing step.""" assert self.episode_per_test is not None @@ -407,7 +408,7 @@ class BaseTrainer(ABC): return test_stat, stop_fn_flag @contextmanager - def _is_within_training_step_enabled(self, is_within_training_step: bool): + def _is_within_training_step_enabled(self, is_within_training_step: bool) -> Iterator[None]: old_value = self.policy.is_within_training_step try: self.policy.is_within_training_step = is_within_training_step @@ -419,10 +420,12 @@ class BaseTrainer(ABC): with self._is_within_training_step_enabled(True): should_stop_training = False + collect_stats: CollectStatsBase | CollectStats if self.train_collector is not None: collect_stats = self._collect_training_data() should_stop_training = self._test_in_train(collect_stats) else: + assert self.buffer is not None, "Either train_collector or buffer must be provided." collect_stats = CollectStatsBase( n_collected_episodes=len(self.buffer), ) @@ -484,6 +487,7 @@ class BaseTrainer(ABC): and self.stop_fn(collect_stats.returns_stat.mean) # type: ignore ): assert self.test_collector is not None + assert self.episode_per_test is not None and self.episode_per_test > 0 test_result = test_episode( self.test_collector, self.test_fn, diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 767e76d..de730ce 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -49,23 +49,9 @@ def gather_info( ) -> InfoStats: """A simple wrapper of gathering information from collectors. - :return: A dataclass object with the following members (depending on available collectors): - - * ``gradient_step`` the total number of gradient steps; - * ``best_reward`` the best reward over the test results; - * ``best_reward_std`` the standard deviation of best reward over the test results; - * ``train_step`` the total collected step of training collector; - * ``train_episode`` the total collected episode of training collector; - * ``test_step`` the total collected step of test collector; - * ``test_episode`` the total collected episode of test collector; - * ``timing`` the timing statistics, with the following members: - * ``total_time`` the total time elapsed; - * ``train_time`` the total time elapsed for learning training (collecting samples plus model update); - * ``train_time_collect`` the time for collecting transitions in the \ - training collector; - * ``train_time_update`` the time for training models; - * ``test_time`` the time for testing; - * ``update_speed`` the speed of updating (env_step per second). + :return: InfoStats object with times computed based on the `start_time` and + episode/step counts read off the collectors. No computation of + expensive statistics is done here. """ duration = max(0.0, time.time() - start_time) test_time = 0.0