Apply centrally defined callbacks
This commit is contained in:
parent
ae4850692f
commit
d84e936430
@ -4,6 +4,10 @@ import os
|
|||||||
|
|
||||||
from jsonargparse import CLI
|
from jsonargparse import CLI
|
||||||
|
|
||||||
|
from examples.atari.atari_callbacks import (
|
||||||
|
TestEpochCallbackDQNSetEps,
|
||||||
|
TrainEpochCallbackNatureDQNEpsLinearDecay,
|
||||||
|
)
|
||||||
from examples.atari.atari_network import (
|
from examples.atari.atari_network import (
|
||||||
IntermediateModuleFactoryAtariDQN,
|
IntermediateModuleFactoryAtariDQN,
|
||||||
IntermediateModuleFactoryAtariDQNFeatures,
|
IntermediateModuleFactoryAtariDQNFeatures,
|
||||||
@ -18,12 +22,6 @@ from tianshou.highlevel.params.policy_params import DQNParams
|
|||||||
from tianshou.highlevel.params.policy_wrapper import (
|
from tianshou.highlevel.params.policy_wrapper import (
|
||||||
PolicyWrapperFactoryIntrinsicCuriosity,
|
PolicyWrapperFactoryIntrinsicCuriosity,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.trainer import (
|
|
||||||
TrainerEpochCallbackTest,
|
|
||||||
TrainerEpochCallbackTrain,
|
|
||||||
TrainingContext,
|
|
||||||
)
|
|
||||||
from tianshou.policy import DQNPolicy
|
|
||||||
from tianshou.utils import logging
|
from tianshou.utils import logging
|
||||||
from tianshou.utils.logging import datetime_tag
|
from tianshou.utils.logging import datetime_tag
|
||||||
|
|
||||||
@ -78,24 +76,6 @@ def main(
|
|||||||
scale=scale_obs,
|
scale=scale_obs,
|
||||||
)
|
)
|
||||||
|
|
||||||
class TrainEpochCallback(TrainerEpochCallbackTrain):
|
|
||||||
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
|
|
||||||
policy: DQNPolicy = context.policy
|
|
||||||
logger = context.logger
|
|
||||||
# nature DQN setting, linear decay in the first 1M steps
|
|
||||||
if env_step <= 1e6:
|
|
||||||
eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final)
|
|
||||||
else:
|
|
||||||
eps = eps_train_final
|
|
||||||
policy.set_eps(eps)
|
|
||||||
if env_step % 1000 == 0:
|
|
||||||
logger.write("train/env_step", env_step, {"train/eps": eps})
|
|
||||||
|
|
||||||
class TestEpochCallback(TrainerEpochCallbackTest):
|
|
||||||
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
|
|
||||||
policy: DQNPolicy = context.policy
|
|
||||||
policy.set_eps(eps_test)
|
|
||||||
|
|
||||||
builder = (
|
builder = (
|
||||||
DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
|
DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
.with_dqn_params(
|
.with_dqn_params(
|
||||||
@ -107,8 +87,10 @@ def main(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
.with_model_factory(IntermediateModuleFactoryAtariDQN())
|
.with_model_factory(IntermediateModuleFactoryAtariDQN())
|
||||||
.with_trainer_epoch_callback_train(TrainEpochCallback())
|
.with_trainer_epoch_callback_train(
|
||||||
.with_trainer_epoch_callback_test(TestEpochCallback())
|
TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
|
||||||
|
)
|
||||||
|
.with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
|
||||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||||
)
|
)
|
||||||
if icm_lr_scale > 0:
|
if icm_lr_scale > 0:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user