Bugfix: allow for training_stat to be None instead of asserting not-None
This commit is contained in:
parent
4e38aeb829
commit
a8e9df31f7
@ -82,7 +82,8 @@ class EpochStats(DataclassPPrintMixin):
|
|||||||
"""The statistics of the last call to the training collector."""
|
"""The statistics of the last call to the training collector."""
|
||||||
test_collect_stat: Optional["CollectStats"]
|
test_collect_stat: Optional["CollectStats"]
|
||||||
"""The statistics of the last call to the test collector."""
|
"""The statistics of the last call to the test collector."""
|
||||||
training_stat: "TrainingStats"
|
training_stat: Optional["TrainingStats"]
|
||||||
"""The statistics of the last model update step."""
|
"""The statistics of the last model update step.
|
||||||
|
Can be None if no model update is performed, typically in the last training iteration."""
|
||||||
info_stat: InfoStats
|
info_stat: InfoStats
|
||||||
"""The information of the collector."""
|
"""The information of the collector."""
|
||||||
|
@ -360,9 +360,6 @@ class BaseTrainer(ABC):
|
|||||||
self.logger.log_info_data(asdict(info_stat), self.epoch)
|
self.logger.log_info_data(asdict(info_stat), self.epoch)
|
||||||
|
|
||||||
# in case trainer is used with run(), epoch_stat will not be returned
|
# in case trainer is used with run(), epoch_stat will not be returned
|
||||||
assert (
|
|
||||||
update_stat is not None
|
|
||||||
), "Defined in the loop above, this shouldn't have happened and is likely a bug!"
|
|
||||||
return EpochStats(
|
return EpochStats(
|
||||||
epoch=self.epoch,
|
epoch=self.epoch,
|
||||||
train_collect_stat=train_stat,
|
train_collect_stat=train_stat,
|
||||||
@ -417,13 +414,23 @@ class BaseTrainer(ABC):
|
|||||||
self.policy.is_within_training_step = old_value
|
self.policy.is_within_training_step = old_value
|
||||||
|
|
||||||
def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]:
|
def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]:
|
||||||
|
"""Perform one training iteration.
|
||||||
|
|
||||||
|
A training iteration includes collecting data (for online RL), determining whether to stop training,
|
||||||
|
and peforming a policy update if the training iteration should continue.
|
||||||
|
|
||||||
|
:return: the iteration's collect stats, training stats, and a flag indicating whether to stop training.
|
||||||
|
If training is to be stopped, no gradient steps will be performed and the training stats will be `None`.
|
||||||
|
"""
|
||||||
with self._is_within_training_step_enabled(True):
|
with self._is_within_training_step_enabled(True):
|
||||||
should_stop_training = False
|
should_stop_training = False
|
||||||
|
|
||||||
collect_stats: CollectStatsBase | CollectStats
|
collect_stats: CollectStatsBase | CollectStats
|
||||||
if self.train_collector is not None:
|
if self.train_collector is not None:
|
||||||
collect_stats = self._collect_training_data()
|
collect_stats = self._collect_training_data()
|
||||||
should_stop_training = self._test_in_train(collect_stats)
|
should_stop_training = self._update_best_reward_and_return_should_stop_training(
|
||||||
|
collect_stats,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
assert self.buffer is not None, "Either train_collector or buffer must be provided."
|
assert self.buffer is not None, "Either train_collector or buffer must be provided."
|
||||||
collect_stats = CollectStatsBase(
|
collect_stats = CollectStatsBase(
|
||||||
@ -467,11 +474,17 @@ class BaseTrainer(ABC):
|
|||||||
|
|
||||||
return collect_stats
|
return collect_stats
|
||||||
|
|
||||||
def _test_in_train(self, collect_stats: CollectStats) -> bool:
|
def _update_best_reward_and_return_should_stop_training(
|
||||||
"""If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data.
|
self,
|
||||||
Then, if the stop_fn is True there, will collect test data also compute the stop_fn of the mean return
|
collect_stats: CollectStats,
|
||||||
|
) -> bool:
|
||||||
|
"""If `test_in_train` and `stop_fn` are set, will compute the `stop_fn` on the mean return of the training data.
|
||||||
|
Then, if the `stop_fn` is True there, will collect test data also compute the stop_fn of the mean return
|
||||||
on it.
|
on it.
|
||||||
Finally, if the latter is also True, will set should_stop_training to True.
|
Finally, if the latter is also True, will return True.
|
||||||
|
|
||||||
|
**NOTE:** has a side effect of updating the best reward and corresponding std.
|
||||||
|
|
||||||
|
|
||||||
:param collect_stats: the data collection stats
|
:param collect_stats: the data collection stats
|
||||||
:return: flag indicating whether to stop training
|
:return: flag indicating whether to stop training
|
||||||
|
Loading…
x
Reference in New Issue
Block a user