Bugfix: allow for training_stat to be None instead of asserting not-None
This commit is contained in:
parent
4e38aeb829
commit
a8e9df31f7
@ -82,7 +82,8 @@ class EpochStats(DataclassPPrintMixin):
|
||||
"""The statistics of the last call to the training collector."""
|
||||
test_collect_stat: Optional["CollectStats"]
|
||||
"""The statistics of the last call to the test collector."""
|
||||
training_stat: "TrainingStats"
|
||||
"""The statistics of the last model update step."""
|
||||
training_stat: Optional["TrainingStats"]
|
||||
"""The statistics of the last model update step.
|
||||
Can be None if no model update is performed, typically in the last training iteration."""
|
||||
info_stat: InfoStats
|
||||
"""The information of the collector."""
|
||||
|
@ -360,9 +360,6 @@ class BaseTrainer(ABC):
|
||||
self.logger.log_info_data(asdict(info_stat), self.epoch)
|
||||
|
||||
# in case trainer is used with run(), epoch_stat will not be returned
|
||||
assert (
|
||||
update_stat is not None
|
||||
), "Defined in the loop above, this shouldn't have happened and is likely a bug!"
|
||||
return EpochStats(
|
||||
epoch=self.epoch,
|
||||
train_collect_stat=train_stat,
|
||||
@ -417,13 +414,23 @@ class BaseTrainer(ABC):
|
||||
self.policy.is_within_training_step = old_value
|
||||
|
||||
def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]:
|
||||
"""Perform one training iteration.
|
||||
|
||||
A training iteration includes collecting data (for online RL), determining whether to stop training,
|
||||
and peforming a policy update if the training iteration should continue.
|
||||
|
||||
:return: the iteration's collect stats, training stats, and a flag indicating whether to stop training.
|
||||
If training is to be stopped, no gradient steps will be performed and the training stats will be `None`.
|
||||
"""
|
||||
with self._is_within_training_step_enabled(True):
|
||||
should_stop_training = False
|
||||
|
||||
collect_stats: CollectStatsBase | CollectStats
|
||||
if self.train_collector is not None:
|
||||
collect_stats = self._collect_training_data()
|
||||
should_stop_training = self._test_in_train(collect_stats)
|
||||
should_stop_training = self._update_best_reward_and_return_should_stop_training(
|
||||
collect_stats,
|
||||
)
|
||||
else:
|
||||
assert self.buffer is not None, "Either train_collector or buffer must be provided."
|
||||
collect_stats = CollectStatsBase(
|
||||
@ -467,11 +474,17 @@ class BaseTrainer(ABC):
|
||||
|
||||
return collect_stats
|
||||
|
||||
def _test_in_train(self, collect_stats: CollectStats) -> bool:
|
||||
"""If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data.
|
||||
Then, if the stop_fn is True there, will collect test data also compute the stop_fn of the mean return
|
||||
def _update_best_reward_and_return_should_stop_training(
|
||||
self,
|
||||
collect_stats: CollectStats,
|
||||
) -> bool:
|
||||
"""If `test_in_train` and `stop_fn` are set, will compute the `stop_fn` on the mean return of the training data.
|
||||
Then, if the `stop_fn` is True there, will collect test data also compute the stop_fn of the mean return
|
||||
on it.
|
||||
Finally, if the latter is also True, will set should_stop_training to True.
|
||||
Finally, if the latter is also True, will return True.
|
||||
|
||||
**NOTE:** has a side effect of updating the best reward and corresponding std.
|
||||
|
||||
|
||||
:param collect_stats: the data collection stats
|
||||
:return: flag indicating whether to stop training
|
||||
|
Loading…
x
Reference in New Issue
Block a user