diff --git a/tianshou/data/stats.py b/tianshou/data/stats.py index b1ce836..b773186 100644 --- a/tianshou/data/stats.py +++ b/tianshou/data/stats.py @@ -82,7 +82,8 @@ class EpochStats(DataclassPPrintMixin): """The statistics of the last call to the training collector.""" test_collect_stat: Optional["CollectStats"] """The statistics of the last call to the test collector.""" - training_stat: "TrainingStats" - """The statistics of the last model update step.""" + training_stat: Optional["TrainingStats"] + """The statistics of the last model update step. + Can be None if no model update is performed, typically in the last training iteration.""" info_stat: InfoStats """The information of the collector.""" diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index ef9a154..3ad80c8 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -360,9 +360,6 @@ class BaseTrainer(ABC): self.logger.log_info_data(asdict(info_stat), self.epoch) # in case trainer is used with run(), epoch_stat will not be returned - assert ( - update_stat is not None - ), "Defined in the loop above, this shouldn't have happened and is likely a bug!" return EpochStats( epoch=self.epoch, train_collect_stat=train_stat, @@ -417,13 +414,23 @@ class BaseTrainer(ABC): self.policy.is_within_training_step = old_value def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]: + """Perform one training iteration. + + A training iteration includes collecting data (for online RL), determining whether to stop training, + and performing a policy update if the training iteration should continue. + + :return: the iteration's collect stats, training stats, and a flag indicating whether to stop training. + If training is to be stopped, no gradient steps will be performed and the training stats will be `None`. 
+ """ with self._is_within_training_step_enabled(True): should_stop_training = False collect_stats: CollectStatsBase | CollectStats if self.train_collector is not None: collect_stats = self._collect_training_data() - should_stop_training = self._test_in_train(collect_stats) + should_stop_training = self._update_best_reward_and_return_should_stop_training( + collect_stats, + ) else: assert self.buffer is not None, "Either train_collector or buffer must be provided." collect_stats = CollectStatsBase( @@ -467,11 +474,17 @@ class BaseTrainer(ABC): return collect_stats - def _test_in_train(self, collect_stats: CollectStats) -> bool: - """If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data. - Then, if the stop_fn is True there, will collect test data also compute the stop_fn of the mean return + def _update_best_reward_and_return_should_stop_training( + self, + collect_stats: CollectStats, + ) -> bool: + """If `test_in_train` and `stop_fn` are set, will compute the `stop_fn` on the mean return of the training data. + Then, if the `stop_fn` is True there, will collect test data and also compute the `stop_fn` of the mean return on it. - Finally, if the latter is also True, will set should_stop_training to True. + Finally, if the latter is also True, will return True. + + **NOTE:** has a side effect of updating the best reward and corresponding std. + :param collect_stats: the data collection stats :return: flag indicating whether to stop training