Bugfix: allow for training_stat to be None instead of asserting not-None

This commit is contained in:
Michael Panchenko 2024-05-05 22:08:22 +02:00
parent 4e38aeb829
commit a8e9df31f7
2 changed files with 24 additions and 10 deletions

View File

@ -82,7 +82,8 @@ class EpochStats(DataclassPPrintMixin):
"""The statistics of the last call to the training collector.""" """The statistics of the last call to the training collector."""
test_collect_stat: Optional["CollectStats"] test_collect_stat: Optional["CollectStats"]
"""The statistics of the last call to the test collector.""" """The statistics of the last call to the test collector."""
training_stat: "TrainingStats" training_stat: Optional["TrainingStats"]
"""The statistics of the last model update step.""" """The statistics of the last model update step.
Can be None if no model update is performed, typically in the last training iteration."""
info_stat: InfoStats info_stat: InfoStats
"""The information of the collector.""" """The information of the collector."""

View File

@ -360,9 +360,6 @@ class BaseTrainer(ABC):
self.logger.log_info_data(asdict(info_stat), self.epoch) self.logger.log_info_data(asdict(info_stat), self.epoch)
# in case trainer is used with run(), epoch_stat will not be returned # in case trainer is used with run(), epoch_stat will not be returned
assert (
update_stat is not None
), "Defined in the loop above, this shouldn't have happened and is likely a bug!"
return EpochStats( return EpochStats(
epoch=self.epoch, epoch=self.epoch,
train_collect_stat=train_stat, train_collect_stat=train_stat,
@ -417,13 +414,23 @@ class BaseTrainer(ABC):
self.policy.is_within_training_step = old_value self.policy.is_within_training_step = old_value
def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]: def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]:
"""Perform one training iteration.
A training iteration includes collecting data (for online RL), determining whether to stop training,
and performing a policy update if the training iteration should continue.
:return: the iteration's collect stats, training stats, and a flag indicating whether to stop training.
If training is to be stopped, no gradient steps will be performed and the training stats will be `None`.
"""
with self._is_within_training_step_enabled(True): with self._is_within_training_step_enabled(True):
should_stop_training = False should_stop_training = False
collect_stats: CollectStatsBase | CollectStats collect_stats: CollectStatsBase | CollectStats
if self.train_collector is not None: if self.train_collector is not None:
collect_stats = self._collect_training_data() collect_stats = self._collect_training_data()
should_stop_training = self._test_in_train(collect_stats) should_stop_training = self._update_best_reward_and_return_should_stop_training(
collect_stats,
)
else: else:
assert self.buffer is not None, "Either train_collector or buffer must be provided." assert self.buffer is not None, "Either train_collector or buffer must be provided."
collect_stats = CollectStatsBase( collect_stats = CollectStatsBase(
@ -467,11 +474,17 @@ class BaseTrainer(ABC):
return collect_stats return collect_stats
def _test_in_train(self, collect_stats: CollectStats) -> bool: def _update_best_reward_and_return_should_stop_training(
"""If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data. self,
Then, if the stop_fn is True there, will collect test data also compute the stop_fn of the mean return collect_stats: CollectStats,
) -> bool:
"""If `test_in_train` and `stop_fn` are set, will compute the `stop_fn` on the mean return of the training data.
Then, if the `stop_fn` is True there, will collect test data and also compute the `stop_fn` of the mean return
on it. on it.
Finally, if the latter is also True, will set should_stop_training to True. Finally, if the latter is also True, will return True.
**NOTE:** has a side effect of updating the best reward and corresponding std.
:param collect_stats: the data collection stats :param collect_stats: the data collection stats
:return: flag indicating whether to stop training :return: flag indicating whether to stop training