diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py
index 84faf79..c6d87d1 100644
--- a/tianshou/trainer/base.py
+++ b/tianshou/trainer/base.py
@@ -4,7 +4,6 @@ from abc import ABC, abstractmethod
 from collections import defaultdict, deque
 from collections.abc import Callable
 from dataclasses import asdict
-from typing import Any
 
 import numpy as np
 import tqdm
@@ -312,7 +311,14 @@ class BaseTrainer(ABC):
             while t.n < t.total and not self.stop_fn_flag:
                 train_stat: CollectStatsBase
                 if self.train_collector is not None:
-                    pbar_data_dict, train_stat, self.stop_fn_flag = self.train_step()
+                    train_stat, self.stop_fn_flag = self.train_step()
+                    pbar_data_dict = {
+                        "env_step": str(self.env_step),
+                        "rew": f"{self.last_rew:.2f}",
+                        "len": str(int(self.last_len)),
+                        "n/ep": str(train_stat.n_collected_episodes),
+                        "n/st": str(train_stat.n_collected_steps),
+                    }
                     t.update(train_stat.n_collected_steps)
                     if self.stop_fn_flag:
                         t.set_postfix(**pbar_data_dict)
@@ -322,13 +328,12 @@ class BaseTrainer(ABC):
                 else:
                     assert self.buffer, "No train_collector or buffer specified"
                     train_stat = CollectStatsBase(
                         n_collected_episodes=len(self.buffer),
-                        n_collected_steps=int(self._gradient_step),
                     )
                     t.update()
 
                 update_stat = self.policy_update_fn(train_stat)
                 pbar_data_dict = set_numerical_fields_to_precision(pbar_data_dict)
-                pbar_data_dict["gradient_step"] = self._gradient_step
+                pbar_data_dict["gradient_step"] = str(self._gradient_step)
 
                 t.set_postfix(**pbar_data_dict)
@@ -413,11 +418,19 @@ class BaseTrainer(ABC):
 
         return test_stat, stop_fn_flag
 
-    def train_step(self) -> tuple[dict[str, Any], CollectStats, bool]:
-        """Perform one training step."""
+    def train_step(self) -> tuple[CollectStats, bool]:
+        """Perform one training step.
+
+        If test_in_train and stop_fn are set, evaluates stop_fn on the mean return of the
+        training data. If stop_fn is satisfied there, additionally collects test data and
+        evaluates stop_fn on its mean return.
+        Only if the latter is also satisfied is the returned stop flag set to True.
+
+        :return: A tuple of the training stats and a boolean indicating whether to stop training.
+        """
         assert self.episode_per_test is not None
         assert self.train_collector is not None
-        stop_fn_flag = False
+        should_stop_training = False
         if self.train_fn:
             self.train_fn(self.epoch, self.env_step)
         result = self.train_collector.collect(
@@ -439,13 +452,6 @@ class BaseTrainer(ABC):
 
             self.logger.log_train_data(asdict(result), self.env_step)
 
-        data = {
-            "env_step": str(self.env_step),
-            "rew": f"{self.last_rew:.2f}",
-            "len": str(int(self.last_len)),
-            "n/ep": str(result.n_collected_episodes),
-            "n/st": str(result.n_collected_steps),
-        }
         if (
             result.n_collected_episodes > 0
             and self.test_in_train
@@ -464,12 +470,12 @@ class BaseTrainer(ABC):
             )
             assert test_result.returns_stat is not None  # for mypy
             if self.stop_fn(test_result.returns_stat.mean):
-                stop_fn_flag = True
+                should_stop_training = True
                 self.best_reward = test_result.returns_stat.mean
                 self.best_reward_std = test_result.returns_stat.std
             else:
                 self.policy.train()
-        return data, result, stop_fn_flag
+        return result, should_stop_training
 
     # TODO: move moving average computation and logging into its own logger
     # TODO: maybe think about a command line logger instead of always printing data dict
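
Reviewer note: the two-stage check described in the new train_step docstring can be summarized in isolation. The sketch below is illustrative only, not the trainer's actual code path; two_stage_stop_check, train_mean_return, and collect_test_mean_return are hypothetical names introduced for this example.

# Illustrative sketch only: condenses the two-stage early-stop check that
# train_step performs when test_in_train and stop_fn are both set.
from collections.abc import Callable


def two_stage_stop_check(
    train_mean_return: float,
    collect_test_mean_return: Callable[[], float],
    stop_fn: Callable[[float], bool],
) -> bool:
    # Stage 1: evaluate stop_fn on the mean return of the training data.
    if not stop_fn(train_mean_return):
        return False
    # Stage 2: collect test data and confirm stop_fn on its mean return.
    return stop_fn(collect_test_mean_return())


def threshold_stop_fn(mean_return: float) -> bool:
    # Stop once the mean return reaches a (hypothetical) threshold of 195.
    return mean_return >= 195.0


print(two_stage_stop_check(200.0, lambda: 198.0, threshold_stop_fn))  # True: both stages pass
print(two_stage_stop_check(200.0, lambda: 150.0, threshold_stop_fn))  # False: test returns disagree

Confirming on freshly collected test data guards against stopping on a lucky training batch; this is also why the patch's else branch calls self.policy.train() to switch the policy back to train mode when the confirmation fails.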
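
For context on how this code path is reached: stop_fn and test_in_train are supplied by the user when constructing a trainer. A minimal sketch follows; the constructor arguments are abbreviated, the threshold value is hypothetical, and policy and the collectors are assumed to be defined elsewhere (parameter names follow the tianshou 1.x trainer API as I understand it; verify against your version).

# Hypothetical usage sketch, not part of this patch.
from tianshou.trainer import OffpolicyTrainer

trainer = OffpolicyTrainer(
    policy=policy,                    # assumed to be defined elsewhere
    train_collector=train_collector,  # assumed to be defined elsewhere
    test_collector=test_collector,    # assumed to be defined elsewhere
    max_epoch=10,
    step_per_epoch=10_000,
    step_per_collect=10,
    episode_per_test=100,
    batch_size=64,
    stop_fn=lambda mean_rewards: mean_rewards >= 195,  # evaluated by train_step as above
    test_in_train=True,  # enables the two-stage check inside train_step
)
result = trainer.run()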