Apply centrally defined callbacks

This commit is contained in:
Dominik Jain 2023-10-16 18:51:30 +02:00
parent ae4850692f
commit d84e936430

View File

@@ -4,6 +4,10 @@ import os
from jsonargparse import CLI
from examples.atari.atari_callbacks import (
TestEpochCallbackDQNSetEps,
TrainEpochCallbackNatureDQNEpsLinearDecay,
)
from examples.atari.atari_network import (
IntermediateModuleFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures,
@@ -18,12 +22,6 @@ from tianshou.highlevel.params.policy_params import DQNParams
from tianshou.highlevel.params.policy_wrapper import (
PolicyWrapperFactoryIntrinsicCuriosity,
)
from tianshou.highlevel.trainer import (
TrainerEpochCallbackTest,
TrainerEpochCallbackTrain,
TrainingContext,
)
from tianshou.policy import DQNPolicy
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag
@@ -78,24 +76,6 @@ def main(
scale=scale_obs,
)
class TrainEpochCallback(TrainerEpochCallbackTrain):
    """Training-epoch callback that anneals the policy's epsilon.

    Follows the Nature DQN setting: epsilon moves linearly from ``eps_train``
    to ``eps_train_final`` over the first 1M environment steps and stays at
    ``eps_train_final`` afterwards. Both values are free variables of this
    class and must be provided by the surrounding scope.
    """

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Set epsilon for the given ``env_step`` and log it every 1000 steps.

        :param epoch: the current epoch (unused here).
        :param env_step: the number of environment steps taken so far.
        :param context: provides the ``policy`` to update and the ``logger`` to write to.
        """
        decay_horizon = 1e6
        dqn_policy: DQNPolicy = context.policy
        log = context.logger
        # linear interpolation inside the decay horizon, constant final value beyond it
        current_eps = (
            eps_train - env_step / decay_horizon * (eps_train - eps_train_final)
            if env_step <= decay_horizon
            else eps_train_final
        )
        dqn_policy.set_eps(current_eps)
        # only report on steps that are exact multiples of 1000
        if env_step % 1000 != 0:
            return
        log.write("train/env_step", env_step, {"train/eps": current_eps})
class TestEpochCallback(TrainerEpochCallbackTest):
    """Test-epoch callback that applies a fixed epsilon to the policy.

    The value ``eps_test`` is a free variable of this class and must be
    provided by the surrounding scope.
    """

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Set the epsilon of ``context.policy`` to ``eps_test``.

        :param epoch: the current epoch (unused here).
        :param env_step: the number of environment steps taken so far (unused here).
        :param context: provides the ``policy`` whose epsilon is set.
        """
        dqn_policy: DQNPolicy = context.policy
        dqn_policy.set_eps(eps_test)
builder = (
DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_dqn_params(
@@ -107,8 +87,10 @@ def main(
),
)
.with_model_factory(IntermediateModuleFactoryAtariDQN())
.with_trainer_epoch_callback_train(TrainEpochCallback())
.with_trainer_epoch_callback_test(TestEpochCallback())
.with_trainer_epoch_callback_train(
TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
)
.with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
.with_trainer_stop_callback(AtariStopCallback(task))
)
if icm_lr_scale > 0: