Apply centrally defined callbacks

This commit is contained in:
Dominik Jain 2023-10-16 18:51:30 +02:00
parent ae4850692f
commit d84e936430

View File

@ -4,6 +4,10 @@ import os
from jsonargparse import CLI from jsonargparse import CLI
from examples.atari.atari_callbacks import (
TestEpochCallbackDQNSetEps,
TrainEpochCallbackNatureDQNEpsLinearDecay,
)
from examples.atari.atari_network import ( from examples.atari.atari_network import (
IntermediateModuleFactoryAtariDQN, IntermediateModuleFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures, IntermediateModuleFactoryAtariDQNFeatures,
@ -18,12 +22,6 @@ from tianshou.highlevel.params.policy_params import DQNParams
from tianshou.highlevel.params.policy_wrapper import ( from tianshou.highlevel.params.policy_wrapper import (
PolicyWrapperFactoryIntrinsicCuriosity, PolicyWrapperFactoryIntrinsicCuriosity,
) )
from tianshou.highlevel.trainer import (
TrainerEpochCallbackTest,
TrainerEpochCallbackTrain,
TrainingContext,
)
from tianshou.policy import DQNPolicy
from tianshou.utils import logging from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag from tianshou.utils.logging import datetime_tag
@ -78,24 +76,6 @@ def main(
scale=scale_obs, scale=scale_obs,
) )
class TrainEpochCallback(TrainerEpochCallbackTrain):
    """Train-epoch hook that anneals the policy's exploration rate.

    ``eps_train`` and ``eps_train_final`` are not defined in this class; they
    are resolved from the surrounding scope (presumably arguments of the
    enclosing ``main`` -- confirm against the full file).
    """

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Set epsilon for the current ``env_step`` and periodically log it.

        :param epoch: current epoch index (not used by this callback).
        :param env_step: number of environment steps taken so far.
        :param context: training context providing ``policy`` and ``logger``.
        """
        dqn_policy: DQNPolicy = context.policy
        run_logger = context.logger
        # Nature DQN setting: epsilon decays linearly from eps_train to
        # eps_train_final over the first 1M environment steps, then stays
        # at eps_train_final.
        still_decaying = env_step <= 1e6
        current_eps = (
            eps_train - env_step / 1e6 * (eps_train - eps_train_final)
            if still_decaying
            else eps_train_final
        )
        dqn_policy.set_eps(current_eps)
        # Only write a log entry when env_step is a multiple of 1000.
        if env_step % 1000 != 0:
            return
        run_logger.write("train/env_step", env_step, {"train/eps": current_eps})
class TestEpochCallback(TrainerEpochCallbackTest):
    """Test-epoch hook that applies the evaluation exploration rate.

    ``eps_test`` is not defined in this class; it is resolved from the
    surrounding scope (presumably an argument of the enclosing ``main`` --
    confirm against the full file).
    """

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Set the policy's epsilon to ``eps_test``.

        :param epoch: current epoch index (not used by this callback).
        :param env_step: number of environment steps taken so far (not used).
        :param context: training context providing the ``policy`` to update.
        """
        context.policy.set_eps(eps_test)
builder = ( builder = (
DQNExperimentBuilder(env_factory, experiment_config, sampling_config) DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_dqn_params( .with_dqn_params(
@ -107,8 +87,10 @@ def main(
), ),
) )
.with_model_factory(IntermediateModuleFactoryAtariDQN()) .with_model_factory(IntermediateModuleFactoryAtariDQN())
.with_trainer_epoch_callback_train(TrainEpochCallback()) .with_trainer_epoch_callback_train(
.with_trainer_epoch_callback_test(TestEpochCallback()) TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
)
.with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
.with_trainer_stop_callback(AtariStopCallback(task)) .with_trainer_stop_callback(AtariStopCallback(task))
) )
if icm_lr_scale > 0: if icm_lr_scale > 0: