Move callbacks for setting DQN epsilon values to the library
This commit is contained in:
parent
63269fe198
commit
ff398beed9
@ -1,33 +0,0 @@
|
|||||||
from tianshou.highlevel.trainer import (
|
|
||||||
TrainerEpochCallbackTest,
|
|
||||||
TrainerEpochCallbackTrain,
|
|
||||||
TrainingContext,
|
|
||||||
)
|
|
||||||
from tianshou.policy import DQNPolicy
|
|
||||||
|
|
||||||
|
|
||||||
class TestEpochCallbackDQNSetEps(TrainerEpochCallbackTest):
    """Test-stage epoch callback that applies a fixed epsilon to the policy."""

    def __init__(self, eps_test: float):
        # Epsilon value handed to the policy every time the callback runs.
        self.eps_test = eps_test

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Call ``set_eps`` on the context's policy with the stored ``eps_test``.

        ``epoch`` and ``env_step`` are accepted but not used by this callback.
        """
        dqn_policy: DQNPolicy = context.policy
        dqn_policy.set_eps(self.eps_test)
|
|
||||||
|
|
||||||
|
|
||||||
class TrainEpochCallbackNatureDQNEpsLinearDecay(TrainerEpochCallbackTrain):
    """Train-stage epoch callback that linearly decays the policy's epsilon.

    Follows the nature DQN setting: epsilon moves linearly from ``eps_train``
    to ``eps_train_final`` over the first 1M environment steps and stays at
    ``eps_train_final`` afterwards.
    """

    def __init__(self, eps_train: float, eps_train_final: float):
        # Start and end points of the linear schedule.
        self.eps_train_final = eps_train_final
        self.eps_train = eps_train

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Compute epsilon for ``env_step``, set it on the policy and log it.

        The value is written to the logger only when ``env_step`` is a
        multiple of 1000. ``epoch`` is not used.
        """
        dqn_policy: DQNPolicy = context.policy
        log = context.logger
        if env_step > 1e6:
            # Past the decay horizon: hold the final value.
            current_eps = self.eps_train_final
        else:
            fraction_done = env_step / 1e6
            current_eps = self.eps_train - fraction_done * (
                self.eps_train - self.eps_train_final
            )
        dqn_policy.set_eps(current_eps)
        if env_step % 1000 != 0:
            return
        log.write("train/env_step", env_step, {"train/eps": current_eps})
|
|
@ -2,10 +2,6 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from examples.atari.atari_callbacks import (
|
|
||||||
TestEpochCallbackDQNSetEps,
|
|
||||||
TrainEpochCallbackNatureDQNEpsLinearDecay,
|
|
||||||
)
|
|
||||||
from examples.atari.atari_network import (
|
from examples.atari.atari_network import (
|
||||||
IntermediateModuleFactoryAtariDQN,
|
IntermediateModuleFactoryAtariDQN,
|
||||||
IntermediateModuleFactoryAtariDQNFeatures,
|
IntermediateModuleFactoryAtariDQNFeatures,
|
||||||
@ -20,6 +16,10 @@ from tianshou.highlevel.params.policy_params import DQNParams
|
|||||||
from tianshou.highlevel.params.policy_wrapper import (
|
from tianshou.highlevel.params.policy_wrapper import (
|
||||||
PolicyWrapperFactoryIntrinsicCuriosity,
|
PolicyWrapperFactoryIntrinsicCuriosity,
|
||||||
)
|
)
|
||||||
|
from tianshou.highlevel.trainer import (
|
||||||
|
TrainerEpochCallbackTestDQNSetEps,
|
||||||
|
TrainerEpochCallbackTrainDQNEpsLinearDecay,
|
||||||
|
)
|
||||||
from tianshou.utils import logging
|
from tianshou.utils import logging
|
||||||
from tianshou.utils.logging import datetime_tag
|
from tianshou.utils.logging import datetime_tag
|
||||||
|
|
||||||
@ -80,9 +80,9 @@ def main(
|
|||||||
)
|
)
|
||||||
.with_model_factory(IntermediateModuleFactoryAtariDQN())
|
.with_model_factory(IntermediateModuleFactoryAtariDQN())
|
||||||
.with_trainer_epoch_callback_train(
|
.with_trainer_epoch_callback_train(
|
||||||
TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
|
TrainerEpochCallbackTrainDQNEpsLinearDecay(eps_train, eps_train_final),
|
||||||
)
|
)
|
||||||
.with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
|
.with_trainer_epoch_callback_test(TrainerEpochCallbackTestDQNSetEps(eps_test))
|
||||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||||
)
|
)
|
||||||
if icm_lr_scale > 0:
|
if icm_lr_scale > 0:
|
||||||
|
@ -3,10 +3,6 @@
|
|||||||
import os
|
import os
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from examples.atari.atari_callbacks import (
|
|
||||||
TestEpochCallbackDQNSetEps,
|
|
||||||
TrainEpochCallbackNatureDQNEpsLinearDecay,
|
|
||||||
)
|
|
||||||
from examples.atari.atari_network import (
|
from examples.atari.atari_network import (
|
||||||
IntermediateModuleFactoryAtariDQN,
|
IntermediateModuleFactoryAtariDQN,
|
||||||
)
|
)
|
||||||
@ -17,6 +13,10 @@ from tianshou.highlevel.experiment import (
|
|||||||
IQNExperimentBuilder,
|
IQNExperimentBuilder,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.params.policy_params import IQNParams
|
from tianshou.highlevel.params.policy_params import IQNParams
|
||||||
|
from tianshou.highlevel.trainer import (
|
||||||
|
TrainerEpochCallbackTestDQNSetEps,
|
||||||
|
TrainerEpochCallbackTrainDQNEpsLinearDecay,
|
||||||
|
)
|
||||||
from tianshou.utils import logging
|
from tianshou.utils import logging
|
||||||
from tianshou.utils.logging import datetime_tag
|
from tianshou.utils.logging import datetime_tag
|
||||||
|
|
||||||
@ -84,9 +84,9 @@ def main(
|
|||||||
)
|
)
|
||||||
.with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True))
|
.with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True))
|
||||||
.with_trainer_epoch_callback_train(
|
.with_trainer_epoch_callback_train(
|
||||||
TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
|
TrainerEpochCallbackTrainDQNEpsLinearDecay(eps_train, eps_train_final),
|
||||||
)
|
)
|
||||||
.with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
|
.with_trainer_epoch_callback_test(TrainerEpochCallbackTestDQNSetEps(eps_test))
|
||||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TypeVar
|
from typing import TypeVar, cast
|
||||||
|
|
||||||
from tianshou.highlevel.env import Environments
|
from tianshou.highlevel.env import Environments
|
||||||
from tianshou.highlevel.logger import TLogger
|
from tianshou.highlevel.logger import TLogger
|
||||||
from tianshou.policy import BasePolicy
|
from tianshou.policy import BasePolicy, DQNPolicy
|
||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
||||||
@ -72,3 +72,52 @@ class TrainerCallbacks:
|
|||||||
epoch_callback_train: TrainerEpochCallbackTrain | None = None
|
epoch_callback_train: TrainerEpochCallbackTrain | None = None
|
||||||
epoch_callback_test: TrainerEpochCallbackTest | None = None
|
epoch_callback_test: TrainerEpochCallbackTest | None = None
|
||||||
stop_callback: TrainerStopCallback | None = None
|
stop_callback: TrainerStopCallback | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TrainerEpochCallbackTrainDQNSetEps(TrainerEpochCallbackTrain):
    """Epoch callback for the training stage that assigns a fixed epsilon value
    to a DQN-based policy each time it is invoked.
    """

    def __init__(self, eps_test: float):
        # NOTE(review): the parameter/attribute is called `eps_test` although this
        # callback derives from the *train* callback base; presumably the name was
        # carried over from the test-stage variant. Kept as-is so keyword callers
        # continue to work - confirm before renaming.
        self.eps_test = eps_test

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Set the stored epsilon on the context's policy.

        The policy is treated as a ``DQNPolicy`` via ``cast`` (no runtime check);
        ``epoch`` and ``env_step`` are not used.
        """
        cast(DQNPolicy, context.policy).set_eps(self.eps_test)
|
||||||
|
|
||||||
|
|
||||||
|
class TrainerEpochCallbackTrainDQNEpsLinearDecay(TrainerEpochCallbackTrain):
    """Sets the epsilon value for DQN-based policies at the beginning of the training
    stage in each epoch, using a linear decay in the first `decay_steps` steps.

    Epsilon starts at `eps_train` (at environment step 0) and is interpolated
    linearly towards `eps_train_final`; from step `decay_steps` onwards it is
    exactly `eps_train_final`.
    """

    def __init__(self, eps_train: float, eps_train_final: float, decay_steps: int = 1000000):
        """
        :param eps_train: the epsilon value at environment step 0.
        :param eps_train_final: the epsilon value once the decay has completed.
        :param decay_steps: the number of environment steps over which epsilon is
            decayed; for values <= 0 no decay takes place and `eps_train_final`
            is applied from the start.
        """
        self.eps_train = eps_train
        self.eps_train_final = eps_train_final
        self.decay_steps = decay_steps

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Computes epsilon for the given environment step, sets it on the policy
        and writes it to the logger under the key "train/eps".

        :param epoch: the current epoch (unused).
        :param env_step: the current environment step, which determines how far
            the decay has progressed.
        :param context: the training context providing the policy and the logger.
        """
        # The cast only informs the type checker; no runtime check is performed.
        policy = cast(DQNPolicy, context.policy)
        logger = context.logger
        # Strict comparison on purpose: at env_step == decay_steps the schedule has
        # reached its end, so the final value is used directly. This avoids a
        # floating-point rounding residue at the boundary and a ZeroDivisionError
        # when decay_steps == 0.
        if env_step < self.decay_steps:
            eps = self.eps_train - env_step / self.decay_steps * (
                self.eps_train - self.eps_train_final
            )
        else:
            eps = self.eps_train_final
        policy.set_eps(eps)
        logger.write("train/env_step", env_step, {"train/eps": eps})
|
||||||
|
|
||||||
|
|
||||||
|
class TrainerEpochCallbackTestDQNSetEps(TrainerEpochCallbackTest):
    """Epoch callback for the test stage that assigns a fixed epsilon value to a
    DQN-based policy each time it is invoked.
    """

    def __init__(self, eps_test: float):
        # Epsilon applied to the policy whenever the callback runs.
        self.eps_test = eps_test

    def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
        """Set the stored ``eps_test`` on the context's policy.

        The policy is treated as a ``DQNPolicy`` via ``cast`` (no runtime check);
        ``epoch`` and ``env_step`` are not used.
        """
        dqn_policy = cast(DQNPolicy, context.policy)
        dqn_policy.set_eps(self.eps_test)
|
||||||
|
@ -7,7 +7,7 @@ from typing import Any
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
VALID_LOG_VALS_TYPE = int | Number | np.number | np.ndarray
|
VALID_LOG_VALS_TYPE = int | Number | np.number | np.ndarray | float
|
||||||
VALID_LOG_VALS = typing.get_args(
|
VALID_LOG_VALS = typing.get_args(
|
||||||
VALID_LOG_VALS_TYPE,
|
VALID_LOG_VALS_TYPE,
|
||||||
) # I know it's stupid, but we can't use Union type in isinstance
|
) # I know it's stupid, but we can't use Union type in isinstance
|
||||||
|
Loading…
x
Reference in New Issue
Block a user