Bugfix: allow for training_stat to be None instead of asserting not-None

This commit is contained in:
Michael Panchenko 2024-05-05 22:08:22 +02:00
parent 4e38aeb829
commit a8e9df31f7
2 changed files with 24 additions and 10 deletions

View File

@@ -82,7 +82,8 @@ class EpochStats(DataclassPPrintMixin):
"""The statistics of the last call to the training collector."""
test_collect_stat: Optional["CollectStats"]
"""The statistics of the last call to the test collector."""
training_stat: "TrainingStats"
"""The statistics of the last model update step."""
training_stat: Optional["TrainingStats"]
"""The statistics of the last model update step.
Can be None if no model update is performed, typically in the last training iteration."""
info_stat: InfoStats
"""The information of the collector."""

View File

@@ -360,9 +360,6 @@ class BaseTrainer(ABC):
self.logger.log_info_data(asdict(info_stat), self.epoch)
# in case trainer is used with run(), epoch_stat will not be returned
assert (
update_stat is not None
), "Defined in the loop above, this shouldn't have happened and is likely a bug!"
return EpochStats(
epoch=self.epoch,
train_collect_stat=train_stat,
@@ -417,13 +414,23 @@ class BaseTrainer(ABC):
self.policy.is_within_training_step = old_value
def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]:
"""Perform one training iteration.
A training iteration includes collecting data (for online RL), determining whether to stop training,
and performing a policy update if the training iteration should continue.
:return: the iteration's collect stats, training stats, and a flag indicating whether to stop training.
If training is to be stopped, no gradient steps will be performed and the training stats will be `None`.
"""
with self._is_within_training_step_enabled(True):
should_stop_training = False
collect_stats: CollectStatsBase | CollectStats
if self.train_collector is not None:
collect_stats = self._collect_training_data()
should_stop_training = self._test_in_train(collect_stats)
should_stop_training = self._update_best_reward_and_return_should_stop_training(
collect_stats,
)
else:
assert self.buffer is not None, "Either train_collector or buffer must be provided."
collect_stats = CollectStatsBase(
@@ -467,11 +474,17 @@ class BaseTrainer(ABC):
return collect_stats
def _test_in_train(self, collect_stats: CollectStats) -> bool:
"""If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data.
Then, if the stop_fn is True there, will collect test data also compute the stop_fn of the mean return
def _update_best_reward_and_return_should_stop_training(
self,
collect_stats: CollectStats,
) -> bool:
"""If `test_in_train` and `stop_fn` are set, will compute the `stop_fn` on the mean return of the training data.
Then, if the `stop_fn` is True there, will collect test data and also compute the `stop_fn` of the mean return
on it.
Finally, if the latter is also True, will set should_stop_training to True.
Finally, if the latter is also True, will return True.
**NOTE:** has a side effect of updating the best reward and corresponding std.
:param collect_stats: the data collection stats
:return: flag indicating whether to stop training