From ca4dad113980e96c8632b1cce8153f874c8950e4 Mon Sep 17 00:00:00 2001
From: Dominik Jain
Date: Thu, 2 May 2024 18:06:01 +0200
Subject: [PATCH] BaseTrainer: Refactoring

New method training_step, which
* collects training data (method _collect_training_data)
* performs "test in train" (method _test_in_train)
* performs the policy update

The old method train_step performed only the first two of these steps;
it has now been split into the two separate methods named above.
---
 tianshou/trainer/base.py | 102 ++++++++++++++++++++++++---------------
 1 file changed, 63 insertions(+), 39 deletions(-)

diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py
index f657f63..825c80c 100644
--- a/tianshou/trainer/base.py
+++ b/tianshou/trainer/base.py
@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
 from collections import defaultdict, deque
 from collections.abc import Callable
 from dataclasses import asdict
+from typing import Optional, Tuple
 
 import numpy as np
 import tqdm
@@ -303,8 +304,10 @@ class BaseTrainer(ABC):
         with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t:
             train_stat: CollectStatsBase
             while t.n < t.total and not self.stop_fn_flag:
-                if self.train_collector is not None:
-                    train_stat, self.stop_fn_flag = self.train_step()
+
+                train_stat, update_stat, self.stop_fn_flag = self.training_step()
+
+                if isinstance(train_stat, CollectStats):
                     pbar_data_dict = {
                         "env_step": str(self.env_step),
                         "rew": f"{self.last_rew:.2f}",
@@ -313,23 +316,17 @@ class BaseTrainer(ABC):
                         "n/st": str(train_stat.n_collected_steps),
                     }
                     t.update(train_stat.n_collected_steps)
-                    if self.stop_fn_flag:
-                        t.set_postfix(**pbar_data_dict)
-                        break
                 else:
                     pbar_data_dict = {}
-                    assert self.buffer, "No train_collector or buffer specified"
-                    train_stat = CollectStatsBase(
-                        n_collected_episodes=len(self.buffer),
-                    )
                     t.update()
 
-                update_stat = self.policy_update_fn(train_stat)
                 pbar_data_dict = set_numerical_fields_to_precision(pbar_data_dict)
                 pbar_data_dict["gradient_step"] = str(self._gradient_step)
-                t.set_postfix(**pbar_data_dict)
+                if self.stop_fn_flag:
+                    break
+

             if t.n <= t.total and not self.stop_fn_flag:
                 t.update()
@@ -410,45 +407,71 @@ class BaseTrainer(ABC):

         return test_stat, stop_fn_flag

-    def train_step(self) -> tuple[CollectStats, bool]:
-        """Perform one training step.
+    def training_step(self) -> Tuple[CollectStatsBase, Optional[TrainingStats], bool]:
+        should_stop_training = False
+        if self.train_collector is not None:
+            collect_stats = self._collect_training_data()
+            should_stop_training = self._test_in_train(collect_stats)
+        else:
+            collect_stats = CollectStatsBase(
+                n_collected_episodes=len(self.buffer),
+            )
+
+        if not should_stop_training:
+            training_stats = self.policy_update_fn(collect_stats)
+        else:
+            training_stats = None
+
+        return collect_stats, training_stats, should_stop_training
+
+    def _collect_training_data(self) -> CollectStats:
+        """Performs training data collection
+
+        :return: the data collection stats
+        """
+        assert self.episode_per_test is not None
+        assert self.train_collector is not None
+        if self.train_fn:
+            self.train_fn(self.epoch, self.env_step)
+        collect_stats = self.train_collector.collect(
+            n_step=self.step_per_collect,
+            n_episode=self.episode_per_collect,
+        )
+
+        self.env_step += collect_stats.n_collected_steps
+
+        if collect_stats.n_collected_episodes > 0:
+            assert collect_stats.returns_stat is not None  # for mypy
+            assert collect_stats.lens_stat is not None  # for mypy
+            self.last_rew = collect_stats.returns_stat.mean
+            self.last_len = collect_stats.lens_stat.mean
+            if self.reward_metric:  # TODO: move inside collector
+                rew = self.reward_metric(collect_stats.returns)
+                collect_stats.returns = rew
+                collect_stats.returns_stat = SequenceSummaryStats.from_sequence(rew)
+
+        self.logger.log_train_data(asdict(collect_stats), self.env_step)
+
+        return collect_stats
+
+    def _test_in_train(self, collect_stats: CollectStats) -> bool:
+        """

         If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data.
         Then, if the stop_fn is True there, will collect test data also compute the stop_fn of the mean return
         on it.
         Finally, if the latter is also True, will set should_stop_training to True.

-        :return: A tuple of the training stats and a boolean indicating whether to stop training.
+        :param collect_stats: the data collection stats
+        :return: flag indicating whether to stop training
         """
-        assert self.episode_per_test is not None
-        assert self.train_collector is not None
         should_stop_training = False
-        if self.train_fn:
-            self.train_fn(self.epoch, self.env_step)
-        result = self.train_collector.collect(
-            n_step=self.step_per_collect,
-            n_episode=self.episode_per_collect,
-        )
-
-        self.env_step += result.n_collected_steps
-
-        if result.n_collected_episodes > 0:
-            assert result.returns_stat is not None  # for mypy
-            assert result.lens_stat is not None  # for mypy
-            self.last_rew = result.returns_stat.mean
-            self.last_len = result.lens_stat.mean
-            if self.reward_metric:  # TODO: move inside collector
-                rew = self.reward_metric(result.returns)
-                result.returns = rew
-                result.returns_stat = SequenceSummaryStats.from_sequence(rew)
-
-        self.logger.log_train_data(asdict(result), self.env_step)

         if (
-            result.n_collected_episodes > 0
+            collect_stats.n_collected_episodes > 0
             and self.test_in_train
             and self.stop_fn
-            and self.stop_fn(result.returns_stat.mean)  # type: ignore
+            and self.stop_fn(collect_stats.returns_stat.mean)  # type: ignore
         ):
             assert self.test_collector is not None
             test_result = test_episode(
@@ -464,7 +487,8 @@ class BaseTrainer(ABC):
                 should_stop_training = True
                 self.best_reward = test_result.returns_stat.mean
                 self.best_reward_std = test_result.returns_stat.std
-        return result, should_stop_training
+
+        return should_stop_training

     # TODO: move moving average computation and logging into its own logger
     # TODO: maybe think about a command line logger instead of always printing data dict
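
Note (illustrative sketch, not part of the patch): the per-iteration control flow that the new training_step method encapsulates is roughly the following. Here `trainer` stands for any BaseTrainer instance; the attributes and helper methods referenced are the ones appearing in the diff above, and the import path for CollectStatsBase is assumed to be tianshou.data.

    from tianshou.data import CollectStatsBase  # assumed import path

    def training_step_sketch(trainer):
        # Sketch of BaseTrainer.training_step as introduced by this patch.
        should_stop_training = False
        if trainer.train_collector is not None:
            # Online setting: collect fresh transitions, then run the
            # "test in train" early-stopping check on the test collector.
            collect_stats = trainer._collect_training_data()
            should_stop_training = trainer._test_in_train(collect_stats)
        else:
            # Offline setting: the replay buffer already holds the data.
            collect_stats = CollectStatsBase(n_collected_episodes=len(trainer.buffer))
        # The policy update is skipped once the stop criterion has been met.
        training_stats = None if should_stop_training else trainer.policy_update_fn(collect_stats)
        return collect_stats, training_stats, should_stop_training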