diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 1253529..830ade1 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -6,7 +6,7 @@ from examples.atari.atari_network import ( IntermediateModuleFactoryAtariDQN, IntermediateModuleFactoryAtariDQNFeatures, ) -from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback +from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( DQNExperimentBuilder, @@ -17,8 +17,8 @@ from tianshou.highlevel.params.policy_wrapper import ( PolicyWrapperFactoryIntrinsicCuriosity, ) from tianshou.highlevel.trainer import ( - TrainerEpochCallbackTestDQNSetEps, - TrainerEpochCallbackTrainDQNEpsLinearDecay, + EpochTestCallbackDQNSetEps, + EpochTrainCallbackDQNEpsLinearDecay, ) from tianshou.utils import logging from tianshou.utils.logging import datetime_tag @@ -79,11 +79,11 @@ def main( ), ) .with_model_factory(IntermediateModuleFactoryAtariDQN()) - .with_trainer_epoch_callback_train( - TrainerEpochCallbackTrainDQNEpsLinearDecay(eps_train, eps_train_final), + .with_epoch_train_callback( + EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final), ) - .with_trainer_epoch_callback_test(TrainerEpochCallbackTestDQNSetEps(eps_test)) - .with_trainer_stop_callback(AtariStopCallback(task)) + .with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test)) + .with_epoch_stop_callback(AtariEpochStopCallback(task)) ) if icm_lr_scale > 0: builder.with_policy_wrapper_factory( diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index fd160df..412ef1d 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -6,7 +6,7 @@ from collections.abc import Sequence from examples.atari.atari_network import ( IntermediateModuleFactoryAtariDQN, ) -from examples.atari.atari_wrapper import AtariEnvFactory, 
AtariStopCallback +from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, @@ -14,8 +14,8 @@ from tianshou.highlevel.experiment import ( ) from tianshou.highlevel.params.policy_params import IQNParams from tianshou.highlevel.trainer import ( - TrainerEpochCallbackTestDQNSetEps, - TrainerEpochCallbackTrainDQNEpsLinearDecay, + EpochTestCallbackDQNSetEps, + EpochTrainCallbackDQNEpsLinearDecay, ) from tianshou.utils import logging from tianshou.utils.logging import datetime_tag @@ -83,11 +83,11 @@ def main( ), ) .with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True)) - .with_trainer_epoch_callback_train( - TrainerEpochCallbackTrainDQNEpsLinearDecay(eps_train, eps_train_final), + .with_epoch_train_callback( + EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final), ) - .with_trainer_epoch_callback_test(TrainerEpochCallbackTestDQNSetEps(eps_test)) - .with_trainer_stop_callback(AtariStopCallback(task)) + .with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test)) + .with_epoch_stop_callback(AtariEpochStopCallback(task)) .build() ) experiment.run(log_name) diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 0a05b3f..2dafb59 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -7,7 +7,7 @@ from examples.atari.atari_network import ( ActorFactoryAtariDQN, IntermediateModuleFactoryAtariDQNFeatures, ) -from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback +from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, @@ -95,7 +95,7 @@ def main( ) .with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs, features_only=True)) .with_critic_factory_use_actor() - 
.with_trainer_stop_callback(AtariStopCallback(task)) + .with_epoch_stop_callback(AtariEpochStopCallback(task)) ) if icm_lr_scale > 0: builder.with_policy_wrapper_factory( diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index f1fd8c4..a8a5bd4 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -6,7 +6,7 @@ from examples.atari.atari_network import ( ActorFactoryAtariDQN, IntermediateModuleFactoryAtariDQNFeatures, ) -from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback +from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( DiscreteSACExperimentBuilder, @@ -82,7 +82,7 @@ def main( ) .with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs=False, features_only=True)) .with_common_critic_factory_use_actor() - .with_trainer_stop_callback(AtariStopCallback(task)) + .with_epoch_stop_callback(AtariEpochStopCallback(task)) ) if icm_lr_scale > 0: builder.with_policy_wrapper_factory( diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index 1e94068..f1b3120 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -15,7 +15,7 @@ from tianshou.highlevel.env import ( EnvPoolFactory, VectorEnvType, ) -from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext +from tianshou.highlevel.trainer import EpochStopCallback, TrainingContext try: import envpool @@ -387,7 +387,7 @@ class AtariEnvFactory(EnvFactoryGymnasium): return kwargs -class AtariStopCallback(TrainerStopCallback): +class AtariEpochStopCallback(EpochStopCallback): def __init__(self, task: str): self.task = task diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index 727bfb2..b72ab5e 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -163,17 +163,19 @@ class OnPolicyAgentFactory(AgentFactory, 
ABC): callbacks = self.trainer_callbacks context = TrainingContext(world.policy, world.envs, world.logger) train_fn = ( - callbacks.epoch_callback_train.get_trainer_fn(context) - if callbacks.epoch_callback_train + callbacks.epoch_train_callback.get_trainer_fn(context) + if callbacks.epoch_train_callback else None ) test_fn = ( - callbacks.epoch_callback_test.get_trainer_fn(context) - if callbacks.epoch_callback_test + callbacks.epoch_test_callback.get_trainer_fn(context) + if callbacks.epoch_test_callback else None ) stop_fn = ( - callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None + callbacks.epoch_stop_callback.get_trainer_fn(context) + if callbacks.epoch_stop_callback + else None ) return OnpolicyTrainer( policy=world.policy, @@ -205,17 +207,19 @@ class OffPolicyAgentFactory(AgentFactory, ABC): callbacks = self.trainer_callbacks context = TrainingContext(world.policy, world.envs, world.logger) train_fn = ( - callbacks.epoch_callback_train.get_trainer_fn(context) - if callbacks.epoch_callback_train + callbacks.epoch_train_callback.get_trainer_fn(context) + if callbacks.epoch_train_callback else None ) test_fn = ( - callbacks.epoch_callback_test.get_trainer_fn(context) - if callbacks.epoch_callback_test + callbacks.epoch_test_callback.get_trainer_fn(context) + if callbacks.epoch_test_callback else None ) stop_fn = ( - callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None + callbacks.epoch_stop_callback.get_trainer_fn(context) + if callbacks.epoch_stop_callback + else None ) return OffpolicyTrainer( policy=world.policy, diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index bac469f..6fa2853 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -70,10 +70,10 @@ from tianshou.highlevel.persistence import ( PolicyPersistence, ) from tianshou.highlevel.trainer import ( + EpochStopCallback, + EpochTestCallback, + EpochTrainCallback, 
TrainerCallbacks, - TrainerEpochCallbackTest, - TrainerEpochCallbackTrain, - TrainerStopCallback, ) from tianshou.highlevel.world import World from tianshou.policy import BasePolicy @@ -383,25 +383,25 @@ class ExperimentBuilder: self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay) return self - def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallbackTrain) -> Self: + def with_epoch_train_callback(self, callback: EpochTrainCallback) -> Self: """Allows to define a callback function which is called at the beginning of every epoch during training. :param callback: the callback :return: the builder """ - self._trainer_callbacks.epoch_callback_train = callback + self._trainer_callbacks.epoch_train_callback = callback return self - def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallbackTest) -> Self: + def with_epoch_test_callback(self, callback: EpochTestCallback) -> Self: """Allows to define a callback function which is called at the beginning of testing in each epoch. :param callback: the callback :return: the builder """ - self._trainer_callbacks.epoch_callback_test = callback + self._trainer_callbacks.epoch_test_callback = callback return self - def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self: + def with_epoch_stop_callback(self, callback: EpochStopCallback) -> Self: """Allows to define a callback that decides whether training shall stop early. The callback receives the undiscounted returns of the testing result. 
@@ -409,7 +409,7 @@ class ExperimentBuilder: :param callback: the callback :return: the builder """ - self._trainer_callbacks.stop_callback = callback + self._trainer_callbacks.epoch_stop_callback = callback return self @abstractmethod diff --git a/tianshou/highlevel/trainer.py b/tianshou/highlevel/trainer.py index da22a32..4eccc6a 100644 --- a/tianshou/highlevel/trainer.py +++ b/tianshou/highlevel/trainer.py @@ -1,3 +1,4 @@ +import logging from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass @@ -9,6 +10,7 @@ from tianshou.policy import BasePolicy, DQNPolicy from tianshou.utils.string import ToStringMixin TPolicy = TypeVar("TPolicy", bound=BasePolicy) +log = logging.getLogger(__name__) class TrainingContext: @@ -18,8 +20,10 @@ class TrainingContext: self.logger = logger -class TrainerEpochCallbackTrain(ToStringMixin, ABC): - """Callback which is called at the beginning of each epoch.""" +class EpochTrainCallback(ToStringMixin, ABC): + """Callback which is called at the beginning of each epoch, i.e. prior to the data collection phase + of each epoch. 
+ """ @abstractmethod def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: @@ -32,8 +36,8 @@ class TrainerEpochCallbackTrain(ToStringMixin, ABC): return fn -class TrainerEpochCallbackTest(ToStringMixin, ABC): - """Callback which is called at the beginning of each epoch.""" +class EpochTestCallback(ToStringMixin, ABC): + """Callback which is called at the beginning of the test phase of each epoch.""" @abstractmethod def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None: @@ -46,8 +50,10 @@ class TrainerEpochCallbackTest(ToStringMixin, ABC): return fn -class TrainerStopCallback(ToStringMixin, ABC): - """Callback indicating whether training should stop.""" +class EpochStopCallback(ToStringMixin, ABC): + """Callback which is called after the test phase of each epoch in order to determine + whether training should stop early. + """ @abstractmethod def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool: @@ -69,12 +75,12 @@ class TrainerStopCallback(ToStringMixin, ABC): class TrainerCallbacks: """Container for callbacks used during training.""" - epoch_callback_train: TrainerEpochCallbackTrain | None = None - epoch_callback_test: TrainerEpochCallbackTest | None = None - stop_callback: TrainerStopCallback | None = None + epoch_train_callback: EpochTrainCallback | None = None + epoch_test_callback: EpochTestCallback | None = None + epoch_stop_callback: EpochStopCallback | None = None -class TrainerEpochCallbackTrainDQNSetEps(TrainerEpochCallbackTrain): +class EpochTrainCallbackDQNSetEps(EpochTrainCallback): """Sets the epsilon value for DQN-based policies at the beginning of the training stage in each epoch. 
"""
@@ -87,7 +93,7 @@ class TrainerEpochCallbackTrainDQNSetEps(TrainerEpochCallbackTrain):
         policy.set_eps(self.eps_test)
 
 
-class TrainerEpochCallbackTrainDQNEpsLinearDecay(TrainerEpochCallbackTrain):
+class EpochTrainCallbackDQNEpsLinearDecay(EpochTrainCallback):
     """Sets the epsilon value for DQN-based policies at the beginning of the training
     stage in each epoch, using a linear decay in the first `decay_steps` steps.
     """
@@ -110,7 +116,7 @@ class TrainerEpochCallbackTrainDQNEpsLinearDecay(TrainerEpochCallbackTrain):
         logger.write("train/env_step", env_step, {"train/eps": eps})
 
 
-class TrainerEpochCallbackTestDQNSetEps(TrainerEpochCallbackTest):
+class EpochTestCallbackDQNSetEps(EpochTestCallback):
     """Sets the epsilon value for DQN-based policies at the beginning of the test
     stage in each epoch.
     """
@@ -121,3 +127,37 @@ class TrainerEpochCallbackTestDQNSetEps(TrainerEpochCallbackTest):
     def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
         policy = cast(DQNPolicy, context.policy)
         policy.set_eps(self.eps_test)
+
+
+class EpochStopCallbackRewardThreshold(EpochStopCallback):
+    """Stops training once the mean rewards reach the given reward threshold or the threshold that
+    is specified in the gymnasium environment (i.e. `env.spec.reward_threshold`).
+    """
+
+    def __init__(self, threshold: float | None = None):
+        """:param threshold: the reward threshold at or beyond which to stop training.
+            If it is None, use threshold given by the environment, i.e. `env.spec.reward_threshold`.
+        """
+        self.threshold = threshold
+
+    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
+        """:param mean_rewards: the mean rewards to compare against the threshold
+        :param context: the training context; its environment supplies the threshold
+            if none was given explicitly
+        :return: True if `mean_rewards >= threshold`, False otherwise
+        :raises ValueError: if neither this callback nor the environment defines a threshold
+        """
+        threshold = self.threshold
+        if threshold is None:
+            threshold = context.envs.env.spec.reward_threshold  # type: ignore
+            # Raise explicitly rather than `assert`: assertions are stripped under `python -O`,
+            # which would turn a missing threshold into a TypeError in the comparison below.
+            if threshold is None:
+                raise ValueError(
+                    "No reward threshold was given and the environment does not define one "
+                    "(`env.spec.reward_threshold` is None)",
+                )
+        is_reached = mean_rewards >= threshold
+        if is_reached:
+            log.info(f"Reward threshold ({threshold}) exceeded")
+        return is_reached