diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index fc0483f..b4ae0dd 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -4,6 +4,10 @@ import os from jsonargparse import CLI +from examples.atari.atari_callbacks import ( + TestEpochCallbackDQNSetEps, + TrainEpochCallbackNatureDQNEpsLinearDecay, +) from examples.atari.atari_network import ( IntermediateModuleFactoryAtariDQN, IntermediateModuleFactoryAtariDQNFeatures, @@ -18,12 +22,6 @@ from tianshou.highlevel.params.policy_params import DQNParams from tianshou.highlevel.params.policy_wrapper import ( PolicyWrapperFactoryIntrinsicCuriosity, ) -from tianshou.highlevel.trainer import ( - TrainerEpochCallbackTest, - TrainerEpochCallbackTrain, - TrainingContext, -) -from tianshou.policy import DQNPolicy from tianshou.utils import logging from tianshou.utils.logging import datetime_tag @@ -78,24 +76,6 @@ def main( scale=scale_obs, ) - class TrainEpochCallback(TrainerEpochCallbackTrain): - def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: - policy: DQNPolicy = context.policy - logger = context.logger - # nature DQN setting, linear decay in the first 1M steps - if env_step <= 1e6: - eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final) - else: - eps = eps_train_final - policy.set_eps(eps) - if env_step % 1000 == 0: - logger.write("train/env_step", env_step, {"train/eps": eps}) - - class TestEpochCallback(TrainerEpochCallbackTest): - def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: - policy: DQNPolicy = context.policy - policy.set_eps(eps_test) - builder = ( DQNExperimentBuilder(env_factory, experiment_config, sampling_config) .with_dqn_params( @@ -107,8 +87,10 @@ def main( ), ) .with_model_factory(IntermediateModuleFactoryAtariDQN()) - .with_trainer_epoch_callback_train(TrainEpochCallback()) - .with_trainer_epoch_callback_test(TestEpochCallback()) + .with_trainer_epoch_callback_train( + TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final), + ) + .with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test)) .with_trainer_stop_callback(AtariStopCallback(task)) ) if icm_lr_scale > 0: