Move callbacks for setting DQN epsilon values to the library
commit ff398beed9
parent 63269fe198
--- a/examples/atari/atari_callbacks.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from tianshou.highlevel.trainer import (
-    TrainerEpochCallbackTest,
-    TrainerEpochCallbackTrain,
-    TrainingContext,
-)
-from tianshou.policy import DQNPolicy
-
-
-class TestEpochCallbackDQNSetEps(TrainerEpochCallbackTest):
-    def __init__(self, eps_test: float):
-        self.eps_test = eps_test
-
-    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
-        policy: DQNPolicy = context.policy
-        policy.set_eps(self.eps_test)
-
-
-class TrainEpochCallbackNatureDQNEpsLinearDecay(TrainerEpochCallbackTrain):
-    def __init__(self, eps_train: float, eps_train_final: float):
-        self.eps_train = eps_train
-        self.eps_train_final = eps_train_final
-
-    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
-        policy: DQNPolicy = context.policy
-        logger = context.logger
-        # nature DQN setting, linear decay in the first 1M steps
-        if env_step <= 1e6:
-            eps = self.eps_train - env_step / 1e6 * (self.eps_train - self.eps_train_final)
-        else:
-            eps = self.eps_train_final
-        policy.set_eps(eps)
-        if env_step % 1000 == 0:
-            logger.write("train/env_step", env_step, {"train/eps": eps})
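(These classes reappear below in tianshou/highlevel/trainer.py, renamed to TrainerEpochCallbackTestDQNSetEps and TrainerEpochCallbackTrainDQNEpsLinearDecay and generalized: the policy is obtained via cast(DQNPolicy, ...), the hard-coded 1M-step decay horizon becomes a configurable decay_steps, eps is logged on every callback rather than every 1000 steps, and a train-stage constant-eps variant, TrainerEpochCallbackTrainDQNSetEps, is added.)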
--- a/examples/atari/atari_dqn_hl.py
+++ b/examples/atari/atari_dqn_hl.py
@@ -2,10 +2,6 @@
 
 import os
 
-from examples.atari.atari_callbacks import (
-    TestEpochCallbackDQNSetEps,
-    TrainEpochCallbackNatureDQNEpsLinearDecay,
-)
 from examples.atari.atari_network import (
     IntermediateModuleFactoryAtariDQN,
     IntermediateModuleFactoryAtariDQNFeatures,
@@ -20,6 +16,10 @@ from tianshou.highlevel.params.policy_params import DQNParams
 from tianshou.highlevel.params.policy_wrapper import (
     PolicyWrapperFactoryIntrinsicCuriosity,
 )
+from tianshou.highlevel.trainer import (
+    TrainerEpochCallbackTestDQNSetEps,
+    TrainerEpochCallbackTrainDQNEpsLinearDecay,
+)
 from tianshou.utils import logging
 from tianshou.utils.logging import datetime_tag
 
@@ -80,9 +80,9 @@ def main(
         )
         .with_model_factory(IntermediateModuleFactoryAtariDQN())
         .with_trainer_epoch_callback_train(
-            TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
+            TrainerEpochCallbackTrainDQNEpsLinearDecay(eps_train, eps_train_final),
         )
-        .with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
+        .with_trainer_epoch_callback_test(TrainerEpochCallbackTestDQNSetEps(eps_test))
         .with_trainer_stop_callback(AtariStopCallback(task))
     )
     if icm_lr_scale > 0:
--- a/examples/atari/atari_iqn_hl.py
+++ b/examples/atari/atari_iqn_hl.py
@@ -3,10 +3,6 @@
 import os
 from collections.abc import Sequence
 
-from examples.atari.atari_callbacks import (
-    TestEpochCallbackDQNSetEps,
-    TrainEpochCallbackNatureDQNEpsLinearDecay,
-)
 from examples.atari.atari_network import (
     IntermediateModuleFactoryAtariDQN,
 )
@@ -17,6 +13,10 @@ from tianshou.highlevel.experiment import (
     IQNExperimentBuilder,
 )
 from tianshou.highlevel.params.policy_params import IQNParams
+from tianshou.highlevel.trainer import (
+    TrainerEpochCallbackTestDQNSetEps,
+    TrainerEpochCallbackTrainDQNEpsLinearDecay,
+)
 from tianshou.utils import logging
 from tianshou.utils.logging import datetime_tag
 
@@ -84,9 +84,9 @@ def main(
         )
         .with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True))
         .with_trainer_epoch_callback_train(
-            TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
+            TrainerEpochCallbackTrainDQNEpsLinearDecay(eps_train, eps_train_final),
         )
-        .with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
+        .with_trainer_epoch_callback_test(TrainerEpochCallbackTestDQNSetEps(eps_test))
         .with_trainer_stop_callback(AtariStopCallback(task))
         .build()
     )
--- a/tianshou/highlevel/trainer.py
+++ b/tianshou/highlevel/trainer.py
@@ -1,11 +1,11 @@
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from dataclasses import dataclass
-from typing import TypeVar
+from typing import TypeVar, cast
 
 from tianshou.highlevel.env import Environments
 from tianshou.highlevel.logger import TLogger
-from tianshou.policy import BasePolicy
+from tianshou.policy import BasePolicy, DQNPolicy
 from tianshou.utils.string import ToStringMixin
 
 TPolicy = TypeVar("TPolicy", bound=BasePolicy)
@@ -72,3 +72,52 @@ class TrainerCallbacks:
     epoch_callback_train: TrainerEpochCallbackTrain | None = None
     epoch_callback_test: TrainerEpochCallbackTest | None = None
     stop_callback: TrainerStopCallback | None = None
+
+
+class TrainerEpochCallbackTrainDQNSetEps(TrainerEpochCallbackTrain):
+    """Sets the epsilon value for DQN-based policies at the beginning of the training
+    stage in each epoch.
+    """
+
+    def __init__(self, eps_test: float):
+        self.eps_test = eps_test
+
+    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
+        policy = cast(DQNPolicy, context.policy)
+        policy.set_eps(self.eps_test)
+
+
+class TrainerEpochCallbackTrainDQNEpsLinearDecay(TrainerEpochCallbackTrain):
+    """Sets the epsilon value for DQN-based policies at the beginning of the training
+    stage in each epoch, using a linear decay in the first `decay_steps` steps.
+    """
+
+    def __init__(self, eps_train: float, eps_train_final: float, decay_steps: int = 1000000):
+        self.eps_train = eps_train
+        self.eps_train_final = eps_train_final
+        self.decay_steps = decay_steps
+
+    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
+        policy = cast(DQNPolicy, context.policy)
+        logger = context.logger
+        if env_step <= self.decay_steps:
+            eps = self.eps_train - env_step / self.decay_steps * (
+                self.eps_train - self.eps_train_final
+            )
+        else:
+            eps = self.eps_train_final
+        policy.set_eps(eps)
+        logger.write("train/env_step", env_step, {"train/eps": eps})
+
+
+class TrainerEpochCallbackTestDQNSetEps(TrainerEpochCallbackTest):
+    """Sets the epsilon value for DQN-based policies at the beginning of the test
+    stage in each epoch.
+    """
+
+    def __init__(self, eps_test: float):
+        self.eps_test = eps_test
+
+    def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
+        policy = cast(DQNPolicy, context.policy)
+        policy.set_eps(self.eps_test)
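For reference, a minimal sketch of how the new library callbacks are wired into an experiment, mirroring the two example scripts above. The helper function is hypothetical and not part of this commit; only the callback classes and the with_trainer_epoch_callback_* builder methods come from the diff:

from tianshou.highlevel.trainer import (
    TrainerEpochCallbackTestDQNSetEps,
    TrainerEpochCallbackTrainDQNEpsLinearDecay,
)


def with_eps_schedule(builder, eps_train: float, eps_train_final: float, eps_test: float):
    """Attach the epsilon callbacks to an ExperimentBuilder (hypothetical helper)."""
    return (
        builder
        # train stage: decay eps linearly from eps_train to eps_train_final
        # over the first decay_steps (default 1000000) environment steps
        .with_trainer_epoch_callback_train(
            TrainerEpochCallbackTrainDQNEpsLinearDecay(eps_train, eps_train_final),
        )
        # test stage: fixed exploration value
        .with_trainer_epoch_callback_test(TrainerEpochCallbackTestDQNSetEps(eps_test))
    )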
--- a/tianshou/utils/logger/base.py
+++ b/tianshou/utils/logger/base.py
@@ -7,7 +7,7 @@ from typing import Any
 
 import numpy as np
 
-VALID_LOG_VALS_TYPE = int | Number | np.number | np.ndarray
+VALID_LOG_VALS_TYPE = int | Number | np.number | np.ndarray | float
 VALID_LOG_VALS = typing.get_args(
     VALID_LOG_VALS_TYPE,
 )  # I know it's stupid, but we can't use Union type in isinstance
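A note on this last change: at runtime a plain float already passes isinstance against Number (float registers as numbers.Real), but static type checkers do not treat float as a subtype of the Number ABC, so values annotated with VALID_LOG_VALS_TYPE, such as the eps logged by the decay callback above, would otherwise be flagged; presumably that is the motivation for adding float explicitly. A minimal sketch of the validation pattern the hunk's comment alludes to (is_valid_log_val is a hypothetical helper, not tianshou API):

import typing
from numbers import Number

import numpy as np

VALID_LOG_VALS_TYPE = int | Number | np.number | np.ndarray | float
# isinstance() rejects Union objects, so extract the member types as a tuple
VALID_LOG_VALS = typing.get_args(VALID_LOG_VALS_TYPE)


def is_valid_log_val(value: VALID_LOG_VALS_TYPE) -> bool:
    """Check a scalar the way the logger module does (hypothetical helper)."""
    return isinstance(value, VALID_LOG_VALS)


assert is_valid_log_val(0.05)  # e.g. the eps value logged during training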