diff --git a/examples/atari/atari_callbacks.py b/examples/atari/atari_callbacks.py
deleted file mode 100644
index d0b4315..0000000
--- a/examples/atari/atari_callbacks.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from tianshou.highlevel.trainer import (
-    TrainerEpochCallbackTest,
-    TrainerEpochCallbackTrain,
-    TrainingContext,
-)
-from tianshou.policy import DQNPolicy
-
-
-class TestEpochCallbackDQNSetEps(TrainerEpochCallbackTest):
-    def __init__(self, eps_test: float):
-        self.eps_test = eps_test
-
-    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
-        policy: DQNPolicy = context.policy
-        policy.set_eps(self.eps_test)
-
-
-class TrainEpochCallbackNatureDQNEpsLinearDecay(TrainerEpochCallbackTrain):
-    def __init__(self, eps_train: float, eps_train_final: float):
-        self.eps_train = eps_train
-        self.eps_train_final = eps_train_final
-
-    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
-        policy: DQNPolicy = context.policy
-        logger = context.logger
-        # nature DQN setting, linear decay in the first 1M steps
-        if env_step <= 1e6:
-            eps = self.eps_train - env_step / 1e6 * (self.eps_train - self.eps_train_final)
-        else:
-            eps = self.eps_train_final
-        policy.set_eps(eps)
-        if env_step % 1000 == 0:
-            logger.write("train/env_step", env_step, {"train/eps": eps})
diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py
index 0c435ee..1253529 100644
--- a/examples/atari/atari_dqn_hl.py
+++ b/examples/atari/atari_dqn_hl.py
@@ -2,10 +2,6 @@
 import os
 
-from examples.atari.atari_callbacks import (
-    TestEpochCallbackDQNSetEps,
-    TrainEpochCallbackNatureDQNEpsLinearDecay,
-)
 from examples.atari.atari_network import (
     IntermediateModuleFactoryAtariDQN,
     IntermediateModuleFactoryAtariDQNFeatures,
 )
@@ -20,6 +16,10 @@
 from tianshou.highlevel.params.policy_params import DQNParams
 from tianshou.highlevel.params.policy_wrapper import (
     PolicyWrapperFactoryIntrinsicCuriosity,
 )
+from tianshou.highlevel.trainer import (
+    TrainerEpochCallbackTestDQNSetEps,
+    TrainerEpochCallbackTrainDQNEpsLinearDecay,
+)
 from tianshou.utils import logging
 from tianshou.utils.logging import datetime_tag
@@ -80,9 +80,9 @@ def main(
         )
         .with_model_factory(IntermediateModuleFactoryAtariDQN())
         .with_trainer_epoch_callback_train(
-            TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
+            TrainerEpochCallbackTrainDQNEpsLinearDecay(eps_train, eps_train_final),
         )
-        .with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
+        .with_trainer_epoch_callback_test(TrainerEpochCallbackTestDQNSetEps(eps_test))
         .with_trainer_stop_callback(AtariStopCallback(task))
     )
     if icm_lr_scale > 0:
diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py
index 8f94ece..fd160df 100644
--- a/examples/atari/atari_iqn_hl.py
+++ b/examples/atari/atari_iqn_hl.py
@@ -3,10 +3,6 @@
 import os
 from collections.abc import Sequence
 
-from examples.atari.atari_callbacks import (
-    TestEpochCallbackDQNSetEps,
-    TrainEpochCallbackNatureDQNEpsLinearDecay,
-)
 from examples.atari.atari_network import (
     IntermediateModuleFactoryAtariDQN,
 )
@@ -17,6 +13,10 @@
 from tianshou.highlevel.experiment import (
     IQNExperimentBuilder,
 )
 from tianshou.highlevel.params.policy_params import IQNParams
+from tianshou.highlevel.trainer import (
+    TrainerEpochCallbackTestDQNSetEps,
+    TrainerEpochCallbackTrainDQNEpsLinearDecay,
+)
 from tianshou.utils import logging
 from tianshou.utils.logging import datetime_tag
@@ -84,9 +84,9 @@ def main(
         )
         .with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True))
         .with_trainer_epoch_callback_train(
-            TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
+            TrainerEpochCallbackTrainDQNEpsLinearDecay(eps_train, eps_train_final),
         )
-        .with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
+        .with_trainer_epoch_callback_test(TrainerEpochCallbackTestDQNSetEps(eps_test))
         .with_trainer_stop_callback(AtariStopCallback(task))
         .build()
     )
diff --git a/tianshou/highlevel/trainer.py b/tianshou/highlevel/trainer.py
index 7f70563..da22a32 100644
--- a/tianshou/highlevel/trainer.py
+++ b/tianshou/highlevel/trainer.py
@@ -1,11 +1,11 @@
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from dataclasses import dataclass
-from typing import TypeVar
+from typing import TypeVar, cast
 
 from tianshou.highlevel.env import Environments
 from tianshou.highlevel.logger import TLogger
-from tianshou.policy import BasePolicy
+from tianshou.policy import BasePolicy, DQNPolicy
 from tianshou.utils.string import ToStringMixin
 
 TPolicy = TypeVar("TPolicy", bound=BasePolicy)
@@ -72,3 +72,52 @@ class TrainerCallbacks:
     epoch_callback_train: TrainerEpochCallbackTrain | None = None
     epoch_callback_test: TrainerEpochCallbackTest | None = None
     stop_callback: TrainerStopCallback | None = None
+
+
+class TrainerEpochCallbackTrainDQNSetEps(TrainerEpochCallbackTrain):
+    """Sets the epsilon value for DQN-based policies at the beginning of the training
+    stage in each epoch.
+    """
+
+    def __init__(self, eps_train: float):
+        self.eps_train = eps_train
+
+    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
+        policy = cast(DQNPolicy, context.policy)
+        policy.set_eps(self.eps_train)
+
+
+class TrainerEpochCallbackTrainDQNEpsLinearDecay(TrainerEpochCallbackTrain):
+    """Sets the epsilon value for DQN-based policies at the beginning of the training
+    stage in each epoch, using a linear decay in the first `decay_steps` steps.
+    """
+
+    def __init__(self, eps_train: float, eps_train_final: float, decay_steps: int = 1000000):
+        self.eps_train = eps_train
+        self.eps_train_final = eps_train_final
+        self.decay_steps = decay_steps
+
+    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
+        policy = cast(DQNPolicy, context.policy)
+        logger = context.logger
+        if env_step <= self.decay_steps:
+            eps = self.eps_train - env_step / self.decay_steps * (
+                self.eps_train - self.eps_train_final
+            )
+        else:
+            eps = self.eps_train_final
+        policy.set_eps(eps)
+        logger.write("train/env_step", env_step, {"train/eps": eps})
+
+
+class TrainerEpochCallbackTestDQNSetEps(TrainerEpochCallbackTest):
+    """Sets the epsilon value for DQN-based policies at the beginning of the test
+    stage in each epoch.
+    """
+
+    def __init__(self, eps_test: float):
+        self.eps_test = eps_test
+
+    def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
+        policy = cast(DQNPolicy, context.policy)
+        policy.set_eps(self.eps_test)
diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py
index 74e0fa4..e6bc09a 100644
--- a/tianshou/utils/logger/base.py
+++ b/tianshou/utils/logger/base.py
@@ -7,7 +7,7 @@ from typing import Any
 import numpy as np
 
-VALID_LOG_VALS_TYPE = int | Number | np.number | np.ndarray
+VALID_LOG_VALS_TYPE = int | Number | np.number | np.ndarray | float
 VALID_LOG_VALS = typing.get_args(
     VALID_LOG_VALS_TYPE,
 )  # I know it's stupid, but we can't use Union type in isinstance
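
Reviewer note: the decay implemented by `TrainerEpochCallbackTrainDQNEpsLinearDecay` interpolates epsilon linearly from `eps_train` down to `eps_train_final` over the first `decay_steps` environment steps (default 1000000, matching the "nature DQN" setting noted in the deleted example code) and holds it at `eps_train_final` afterwards. Here is a standalone sketch of that schedule; the concrete values are illustrative, not defaults asserted by this change:

```python
# Standalone recomputation of the schedule implemented by
# TrainerEpochCallbackTrainDQNEpsLinearDecay (values are illustrative).
eps_train, eps_train_final, decay_steps = 1.0, 0.05, 1_000_000


def eps_at(env_step: int) -> float:
    """Epsilon after `env_step` environment steps."""
    if env_step <= decay_steps:
        # linear interpolation from eps_train down to eps_train_final
        return eps_train - env_step / decay_steps * (eps_train - eps_train_final)
    return eps_train_final


assert eps_at(0) == 1.0
assert abs(eps_at(500_000) - 0.525) < 1e-12  # halfway: midpoint of 1.0 and 0.05
assert eps_at(2_000_000) == 0.05  # held constant once decay_steps is exceeded
```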
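
Because the callbacks now live in `tianshou.highlevel.trainer` rather than under `examples/`, any DQN-family experiment can reuse them through the builder methods shown in the diffs above. A minimal wiring sketch follows; `make_builder` is a hypothetical placeholder for constructing a `DQNExperimentBuilder` or `IQNExperimentBuilder`, and the epsilon values are illustrative:

```python
# Wiring the relocated callbacks into a high-level experiment builder,
# mirroring the updated Atari examples. make_builder() is a hypothetical
# placeholder; epsilon values are illustrative.
from tianshou.highlevel.trainer import (
    TrainerEpochCallbackTestDQNSetEps,
    TrainerEpochCallbackTrainDQNEpsLinearDecay,
)

experiment = (
    make_builder()  # hypothetical: returns e.g. an IQNExperimentBuilder
    .with_trainer_epoch_callback_train(
        # at the start of each epoch's training stage, set the decayed epsilon
        TrainerEpochCallbackTrainDQNEpsLinearDecay(1.0, 0.05),
    )
    # before each epoch's test stage, pin epsilon to a small constant
    .with_trainer_epoch_callback_test(TrainerEpochCallbackTestDQNSetEps(0.005))
    .build()
)
```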