Move callbacks for setting DQN epsilon values to the library

This commit is contained in:
Dominik Jain 2024-01-11 14:57:03 +01:00
parent 63269fe198
commit ff398beed9
5 changed files with 64 additions and 48 deletions

View File

@ -1,33 +0,0 @@
from tianshou.highlevel.trainer import (
TrainerEpochCallbackTest,
TrainerEpochCallbackTrain,
TrainingContext,
)
from tianshou.policy import DQNPolicy
class TestEpochCallbackDQNSetEps(TrainerEpochCallbackTest):
    """Test-stage epoch callback that applies a fixed epsilon value to the policy.

    :param eps_test: the epsilon value handed to the policy's ``set_eps`` on every call
    """

    def __init__(self, eps_test: float):
        self.eps_test = eps_test

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Set the stored epsilon on the context's policy.

        ``epoch`` and ``env_step`` are accepted for interface compatibility but not used.
        """
        # The policy in the context is treated as a DQNPolicy, which offers `set_eps`.
        dqn_policy: DQNPolicy = context.policy
        dqn_policy.set_eps(self.eps_test)
class TrainEpochCallbackNatureDQNEpsLinearDecay(TrainerEpochCallbackTrain):
    """Training-stage epoch callback implementing the Nature DQN epsilon schedule.

    Epsilon is interpolated linearly from ``eps_train`` to ``eps_train_final`` over the
    first 1M environment steps and is held at ``eps_train_final`` afterwards.

    :param eps_train: the epsilon value at environment step 0
    :param eps_train_final: the epsilon value used once the decay window has passed
    """

    def __init__(self, eps_train: float, eps_train_final: float):
        self.eps_train = eps_train
        self.eps_train_final = eps_train_final

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Compute epsilon for ``env_step``, apply it to the policy and optionally log it."""
        dqn_policy: DQNPolicy = context.policy
        log = context.logger
        # Nature DQN setting: linear decay within the first 1M steps, constant afterwards.
        eps = (
            self.eps_train - env_step / 1e6 * (self.eps_train - self.eps_train_final)
            if env_step <= 1e6
            else self.eps_train_final
        )
        dqn_policy.set_eps(eps)
        # The value is only written to the logger when the step count is a multiple of 1000.
        if env_step % 1000 == 0:
            log.write("train/env_step", env_step, {"train/eps": eps})

View File

@ -2,10 +2,6 @@
import os
from examples.atari.atari_callbacks import (
TestEpochCallbackDQNSetEps,
TrainEpochCallbackNatureDQNEpsLinearDecay,
)
from examples.atari.atari_network import (
IntermediateModuleFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures,
@ -20,6 +16,10 @@ from tianshou.highlevel.params.policy_params import DQNParams
from tianshou.highlevel.params.policy_wrapper import (
PolicyWrapperFactoryIntrinsicCuriosity,
)
from tianshou.highlevel.trainer import (
TrainerEpochCallbackTestDQNSetEps,
TrainerEpochCallbackTrainDQNEpsLinearDecay,
)
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag
@ -80,9 +80,9 @@ def main(
)
.with_model_factory(IntermediateModuleFactoryAtariDQN())
.with_trainer_epoch_callback_train(
TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
TrainerEpochCallbackTrainDQNEpsLinearDecay(eps_train, eps_train_final),
)
.with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
.with_trainer_epoch_callback_test(TrainerEpochCallbackTestDQNSetEps(eps_test))
.with_trainer_stop_callback(AtariStopCallback(task))
)
if icm_lr_scale > 0:

View File

@ -3,10 +3,6 @@
import os
from collections.abc import Sequence
from examples.atari.atari_callbacks import (
TestEpochCallbackDQNSetEps,
TrainEpochCallbackNatureDQNEpsLinearDecay,
)
from examples.atari.atari_network import (
IntermediateModuleFactoryAtariDQN,
)
@ -17,6 +13,10 @@ from tianshou.highlevel.experiment import (
IQNExperimentBuilder,
)
from tianshou.highlevel.params.policy_params import IQNParams
from tianshou.highlevel.trainer import (
TrainerEpochCallbackTestDQNSetEps,
TrainerEpochCallbackTrainDQNEpsLinearDecay,
)
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag
@ -84,9 +84,9 @@ def main(
)
.with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True))
.with_trainer_epoch_callback_train(
TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
TrainerEpochCallbackTrainDQNEpsLinearDecay(eps_train, eps_train_final),
)
.with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
.with_trainer_epoch_callback_test(TrainerEpochCallbackTestDQNSetEps(eps_test))
.with_trainer_stop_callback(AtariStopCallback(task))
.build()
)

View File

@ -1,11 +1,11 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeVar
from typing import TypeVar, cast
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import TLogger
from tianshou.policy import BasePolicy
from tianshou.policy import BasePolicy, DQNPolicy
from tianshou.utils.string import ToStringMixin
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
@ -72,3 +72,52 @@ class TrainerCallbacks:
epoch_callback_train: TrainerEpochCallbackTrain | None = None
epoch_callback_test: TrainerEpochCallbackTest | None = None
stop_callback: TrainerStopCallback | None = None
class TrainerEpochCallbackTrainDQNSetEps(TrainerEpochCallbackTrain):
    """Sets the epsilon value for DQN-based policies at the beginning of the training
    stage in each epoch.

    :param eps_test: the epsilon value to apply on every call. Note that the parameter
        (and the attribute storing it) is named ``eps_test`` even though this is a
        training-stage callback.
    """

    def __init__(self, eps_test: float):
        self.eps_test = eps_test

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Apply the stored epsilon to the context's policy.

        ``epoch`` and ``env_step`` are accepted for interface compatibility but not used.
        """
        # The cast only informs the type checker; it performs no runtime conversion.
        cast(DQNPolicy, context.policy).set_eps(self.eps_test)
class TrainerEpochCallbackTrainDQNEpsLinearDecay(TrainerEpochCallbackTrain):
    """Sets the epsilon value for DQN-based policies at the beginning of the training
    stage in each epoch, using a linear decay in the first `decay_steps` steps.

    :param eps_train: the epsilon value at environment step 0
    :param eps_train_final: the epsilon value used once the decay window has passed
    :param decay_steps: the number of environment steps over which epsilon is linearly
        interpolated from ``eps_train`` to ``eps_train_final``. A non-positive value
        disables the decay, i.e. ``eps_train_final`` is used from the start.
    """

    def __init__(self, eps_train: float, eps_train_final: float, decay_steps: int = 1000000):
        self.eps_train = eps_train
        self.eps_train_final = eps_train_final
        self.decay_steps = decay_steps

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Compute epsilon for ``env_step``, apply it to the policy and log it.

        ``epoch`` is accepted for interface compatibility but not used.
        """
        policy = cast(DQNPolicy, context.policy)
        logger = context.logger
        # The `decay_steps > 0` guard avoids a ZeroDivisionError when decay_steps == 0
        # and env_step == 0; in that case there is no decay window at all.
        if self.decay_steps > 0 and env_step <= self.decay_steps:
            eps = self.eps_train - env_step / self.decay_steps * (
                self.eps_train - self.eps_train_final
            )
        else:
            eps = self.eps_train_final
        policy.set_eps(eps)
        logger.write("train/env_step", env_step, {"train/eps": eps})
class TrainerEpochCallbackTestDQNSetEps(TrainerEpochCallbackTest):
    """Sets the epsilon value for DQN-based policies at the beginning of the test
    stage in each epoch.

    :param eps_test: the epsilon value to apply on every call
    """

    def __init__(self, eps_test: float):
        self.eps_test = eps_test

    def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
        """Apply the stored epsilon to the context's policy.

        ``epoch`` and ``env_step`` (which may be None) are accepted for interface
        compatibility but not used.
        """
        # The cast only informs the type checker; it performs no runtime conversion.
        cast(DQNPolicy, context.policy).set_eps(self.eps_test)

View File

@ -7,7 +7,7 @@ from typing import Any
import numpy as np
# Union of the value types considered valid for logging (judging by the name).
# NOTE(review): both assignments appear in this view; the second one, which additionally
# lists `float`, is the one that takes effect because it rebinds the name.
VALID_LOG_VALS_TYPE = int | Number | np.number | np.ndarray
VALID_LOG_VALS_TYPE = int | Number | np.number | np.ndarray | float
# The member types of the union above, extracted with `typing.get_args` so that they
# are available as a plain collection of classes.
VALID_LOG_VALS = typing.get_args(
VALID_LOG_VALS_TYPE,
) # I know it's stupid, but we can't use Union type in isinstance