Apply centrally defined callbacks
This commit is contained in:
parent
ae4850692f
commit
d84e936430
@ -4,6 +4,10 @@ import os
|
||||
|
||||
from jsonargparse import CLI
|
||||
|
||||
from examples.atari.atari_callbacks import (
|
||||
TestEpochCallbackDQNSetEps,
|
||||
TrainEpochCallbackNatureDQNEpsLinearDecay,
|
||||
)
|
||||
from examples.atari.atari_network import (
|
||||
IntermediateModuleFactoryAtariDQN,
|
||||
IntermediateModuleFactoryAtariDQNFeatures,
|
||||
@ -18,12 +22,6 @@ from tianshou.highlevel.params.policy_params import DQNParams
|
||||
from tianshou.highlevel.params.policy_wrapper import (
|
||||
PolicyWrapperFactoryIntrinsicCuriosity,
|
||||
)
|
||||
from tianshou.highlevel.trainer import (
|
||||
TrainerEpochCallbackTest,
|
||||
TrainerEpochCallbackTrain,
|
||||
TrainingContext,
|
||||
)
|
||||
from tianshou.policy import DQNPolicy
|
||||
from tianshou.utils import logging
|
||||
from tianshou.utils.logging import datetime_tag
|
||||
|
||||
@ -78,24 +76,6 @@ def main(
|
||||
scale=scale_obs,
|
||||
)
|
||||
|
||||
class TrainEpochCallback(TrainerEpochCallbackTrain):
|
||||
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
|
||||
policy: DQNPolicy = context.policy
|
||||
logger = context.logger
|
||||
# nature DQN setting, linear decay in the first 1M steps
|
||||
if env_step <= 1e6:
|
||||
eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final)
|
||||
else:
|
||||
eps = eps_train_final
|
||||
policy.set_eps(eps)
|
||||
if env_step % 1000 == 0:
|
||||
logger.write("train/env_step", env_step, {"train/eps": eps})
|
||||
|
||||
class TestEpochCallback(TrainerEpochCallbackTest):
|
||||
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
|
||||
policy: DQNPolicy = context.policy
|
||||
policy.set_eps(eps_test)
|
||||
|
||||
builder = (
|
||||
DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
.with_dqn_params(
|
||||
@ -107,8 +87,10 @@ def main(
|
||||
),
|
||||
)
|
||||
.with_model_factory(IntermediateModuleFactoryAtariDQN())
|
||||
.with_trainer_epoch_callback_train(TrainEpochCallback())
|
||||
.with_trainer_epoch_callback_test(TestEpochCallback())
|
||||
.with_trainer_epoch_callback_train(
|
||||
TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
|
||||
)
|
||||
.with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
|
||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||
)
|
||||
if icm_lr_scale > 0:
|
||||
|
Loading…
x
Reference in New Issue
Block a user