Improve naming of callback classes and related methods/attributes

Add EpochStopCallbackRewardThreshold
This commit is contained in:
Dominik Jain 2024-01-10 15:28:48 +01:00
parent 24b7b82e56
commit 1e5ebc2a2d
8 changed files with 83 additions and 51 deletions

View File

@ -6,7 +6,7 @@ from examples.atari.atari_network import (
IntermediateModuleFactoryAtariDQN, IntermediateModuleFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures, IntermediateModuleFactoryAtariDQNFeatures,
) )
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ( from tianshou.highlevel.experiment import (
DQNExperimentBuilder, DQNExperimentBuilder,
@ -17,8 +17,8 @@ from tianshou.highlevel.params.policy_wrapper import (
PolicyWrapperFactoryIntrinsicCuriosity, PolicyWrapperFactoryIntrinsicCuriosity,
) )
from tianshou.highlevel.trainer import ( from tianshou.highlevel.trainer import (
TrainerEpochCallbackTestDQNSetEps, EpochTestCallbackDQNSetEps,
TrainerEpochCallbackTrainDQNEpsLinearDecay, EpochTrainCallbackDQNEpsLinearDecay,
) )
from tianshou.utils import logging from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag from tianshou.utils.logging import datetime_tag
@ -79,11 +79,11 @@ def main(
), ),
) )
.with_model_factory(IntermediateModuleFactoryAtariDQN()) .with_model_factory(IntermediateModuleFactoryAtariDQN())
.with_trainer_epoch_callback_train( .with_epoch_train_callback(
TrainerEpochCallbackTrainDQNEpsLinearDecay(eps_train, eps_train_final), EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final),
) )
.with_trainer_epoch_callback_test(TrainerEpochCallbackTestDQNSetEps(eps_test)) .with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test))
.with_trainer_stop_callback(AtariStopCallback(task)) .with_epoch_stop_callback(AtariEpochStopCallback(task))
) )
if icm_lr_scale > 0: if icm_lr_scale > 0:
builder.with_policy_wrapper_factory( builder.with_policy_wrapper_factory(

View File

@ -6,7 +6,7 @@ from collections.abc import Sequence
from examples.atari.atari_network import ( from examples.atari.atari_network import (
IntermediateModuleFactoryAtariDQN, IntermediateModuleFactoryAtariDQN,
) )
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ( from tianshou.highlevel.experiment import (
ExperimentConfig, ExperimentConfig,
@ -14,8 +14,8 @@ from tianshou.highlevel.experiment import (
) )
from tianshou.highlevel.params.policy_params import IQNParams from tianshou.highlevel.params.policy_params import IQNParams
from tianshou.highlevel.trainer import ( from tianshou.highlevel.trainer import (
TrainerEpochCallbackTestDQNSetEps, EpochTestCallbackDQNSetEps,
TrainerEpochCallbackTrainDQNEpsLinearDecay, EpochTrainCallbackDQNEpsLinearDecay,
) )
from tianshou.utils import logging from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag from tianshou.utils.logging import datetime_tag
@ -83,11 +83,11 @@ def main(
), ),
) )
.with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True)) .with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True))
.with_trainer_epoch_callback_train( .with_epoch_train_callback(
TrainerEpochCallbackTrainDQNEpsLinearDecay(eps_train, eps_train_final), EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final),
) )
.with_trainer_epoch_callback_test(TrainerEpochCallbackTestDQNSetEps(eps_test)) .with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test))
.with_trainer_stop_callback(AtariStopCallback(task)) .with_epoch_stop_callback(AtariEpochStopCallback(task))
.build() .build()
) )
experiment.run(log_name) experiment.run(log_name)

View File

@ -7,7 +7,7 @@ from examples.atari.atari_network import (
ActorFactoryAtariDQN, ActorFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures, IntermediateModuleFactoryAtariDQNFeatures,
) )
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ( from tianshou.highlevel.experiment import (
ExperimentConfig, ExperimentConfig,
@ -95,7 +95,7 @@ def main(
) )
.with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs, features_only=True)) .with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs, features_only=True))
.with_critic_factory_use_actor() .with_critic_factory_use_actor()
.with_trainer_stop_callback(AtariStopCallback(task)) .with_epoch_stop_callback(AtariEpochStopCallback(task))
) )
if icm_lr_scale > 0: if icm_lr_scale > 0:
builder.with_policy_wrapper_factory( builder.with_policy_wrapper_factory(

View File

@ -6,7 +6,7 @@ from examples.atari.atari_network import (
ActorFactoryAtariDQN, ActorFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures, IntermediateModuleFactoryAtariDQNFeatures,
) )
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ( from tianshou.highlevel.experiment import (
DiscreteSACExperimentBuilder, DiscreteSACExperimentBuilder,
@ -82,7 +82,7 @@ def main(
) )
.with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs=False, features_only=True)) .with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs=False, features_only=True))
.with_common_critic_factory_use_actor() .with_common_critic_factory_use_actor()
.with_trainer_stop_callback(AtariStopCallback(task)) .with_epoch_stop_callback(AtariEpochStopCallback(task))
) )
if icm_lr_scale > 0: if icm_lr_scale > 0:
builder.with_policy_wrapper_factory( builder.with_policy_wrapper_factory(

View File

@ -15,7 +15,7 @@ from tianshou.highlevel.env import (
EnvPoolFactory, EnvPoolFactory,
VectorEnvType, VectorEnvType,
) )
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext from tianshou.highlevel.trainer import EpochStopCallback, TrainingContext
try: try:
import envpool import envpool
@ -387,7 +387,7 @@ class AtariEnvFactory(EnvFactoryGymnasium):
return kwargs return kwargs
class AtariStopCallback(TrainerStopCallback): class AtariEpochStopCallback(EpochStopCallback):
def __init__(self, task: str): def __init__(self, task: str):
self.task = task self.task = task

View File

@ -163,17 +163,19 @@ class OnPolicyAgentFactory(AgentFactory, ABC):
callbacks = self.trainer_callbacks callbacks = self.trainer_callbacks
context = TrainingContext(world.policy, world.envs, world.logger) context = TrainingContext(world.policy, world.envs, world.logger)
train_fn = ( train_fn = (
callbacks.epoch_callback_train.get_trainer_fn(context) callbacks.epoch_train_callback.get_trainer_fn(context)
if callbacks.epoch_callback_train if callbacks.epoch_train_callback
else None else None
) )
test_fn = ( test_fn = (
callbacks.epoch_callback_test.get_trainer_fn(context) callbacks.epoch_test_callback.get_trainer_fn(context)
if callbacks.epoch_callback_test if callbacks.epoch_test_callback
else None else None
) )
stop_fn = ( stop_fn = (
callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None callbacks.epoch_stop_callback.get_trainer_fn(context)
if callbacks.epoch_stop_callback
else None
) )
return OnpolicyTrainer( return OnpolicyTrainer(
policy=world.policy, policy=world.policy,
@ -205,17 +207,19 @@ class OffPolicyAgentFactory(AgentFactory, ABC):
callbacks = self.trainer_callbacks callbacks = self.trainer_callbacks
context = TrainingContext(world.policy, world.envs, world.logger) context = TrainingContext(world.policy, world.envs, world.logger)
train_fn = ( train_fn = (
callbacks.epoch_callback_train.get_trainer_fn(context) callbacks.epoch_train_callback.get_trainer_fn(context)
if callbacks.epoch_callback_train if callbacks.epoch_train_callback
else None else None
) )
test_fn = ( test_fn = (
callbacks.epoch_callback_test.get_trainer_fn(context) callbacks.epoch_test_callback.get_trainer_fn(context)
if callbacks.epoch_callback_test if callbacks.epoch_test_callback
else None else None
) )
stop_fn = ( stop_fn = (
callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None callbacks.epoch_stop_callback.get_trainer_fn(context)
if callbacks.epoch_stop_callback
else None
) )
return OffpolicyTrainer( return OffpolicyTrainer(
policy=world.policy, policy=world.policy,

View File

@ -70,10 +70,10 @@ from tianshou.highlevel.persistence import (
PolicyPersistence, PolicyPersistence,
) )
from tianshou.highlevel.trainer import ( from tianshou.highlevel.trainer import (
EpochStopCallback,
EpochTestCallback,
EpochTrainCallback,
TrainerCallbacks, TrainerCallbacks,
TrainerEpochCallbackTest,
TrainerEpochCallbackTrain,
TrainerStopCallback,
) )
from tianshou.highlevel.world import World from tianshou.highlevel.world import World
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
@ -383,25 +383,25 @@ class ExperimentBuilder:
self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay) self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay)
return self return self
def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallbackTrain) -> Self: def with_epoch_train_callback(self, callback: EpochTrainCallback) -> Self:
"""Allows to define a callback function which is called at the beginning of every epoch during training. """Allows to define a callback function which is called at the beginning of every epoch during training.
:param callback: the callback :param callback: the callback
:return: the builder :return: the builder
""" """
self._trainer_callbacks.epoch_callback_train = callback self._trainer_callbacks.epoch_train_callback = callback
return self return self
def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallbackTest) -> Self: def with_epoch_test_callback(self, callback: EpochTestCallback) -> Self:
"""Allows to define a callback function which is called at the beginning of testing in each epoch. """Allows to define a callback function which is called at the beginning of testing in each epoch.
:param callback: the callback :param callback: the callback
:return: the builder :return: the builder
""" """
self._trainer_callbacks.epoch_callback_test = callback self._trainer_callbacks.epoch_test_callback = callback
return self return self
def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self: def with_epoch_stop_callback(self, callback: EpochStopCallback) -> Self:
"""Allows to define a callback that decides whether training shall stop early. """Allows to define a callback that decides whether training shall stop early.
The callback receives the undiscounted returns of the testing result. The callback receives the undiscounted returns of the testing result.
@ -409,7 +409,7 @@ class ExperimentBuilder:
:param callback: the callback :param callback: the callback
:return: the builder :return: the builder
""" """
self._trainer_callbacks.stop_callback = callback self._trainer_callbacks.epoch_stop_callback = callback
return self return self
@abstractmethod @abstractmethod

View File

@ -1,3 +1,4 @@
import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
@ -9,6 +10,7 @@ from tianshou.policy import BasePolicy, DQNPolicy
from tianshou.utils.string import ToStringMixin from tianshou.utils.string import ToStringMixin
TPolicy = TypeVar("TPolicy", bound=BasePolicy) TPolicy = TypeVar("TPolicy", bound=BasePolicy)
log = logging.getLogger(__name__)
class TrainingContext: class TrainingContext:
@ -18,8 +20,10 @@ class TrainingContext:
self.logger = logger self.logger = logger
class TrainerEpochCallbackTrain(ToStringMixin, ABC): class EpochTrainCallback(ToStringMixin, ABC):
"""Callback which is called at the beginning of each epoch.""" """Callback which is called at the beginning of each epoch, i.e. prior to the data collection phase
of each epoch.
"""
@abstractmethod @abstractmethod
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
@ -32,8 +36,8 @@ class TrainerEpochCallbackTrain(ToStringMixin, ABC):
return fn return fn
class TrainerEpochCallbackTest(ToStringMixin, ABC): class EpochTestCallback(ToStringMixin, ABC):
"""Callback which is called at the beginning of each epoch.""" """Callback which is called at the beginning of the test phase of each epoch."""
@abstractmethod @abstractmethod
def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None: def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
@ -46,8 +50,10 @@ class TrainerEpochCallbackTest(ToStringMixin, ABC):
return fn return fn
class TrainerStopCallback(ToStringMixin, ABC): class EpochStopCallback(ToStringMixin, ABC):
"""Callback indicating whether training should stop.""" """Callback which is called after the test phase of each epoch in order to determine
whether training should stop early.
"""
@abstractmethod @abstractmethod
def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool: def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
@ -69,12 +75,12 @@ class TrainerStopCallback(ToStringMixin, ABC):
class TrainerCallbacks: class TrainerCallbacks:
"""Container for callbacks used during training.""" """Container for callbacks used during training."""
epoch_callback_train: TrainerEpochCallbackTrain | None = None epoch_train_callback: EpochTrainCallback | None = None
epoch_callback_test: TrainerEpochCallbackTest | None = None epoch_test_callback: EpochTestCallback | None = None
stop_callback: TrainerStopCallback | None = None epoch_stop_callback: EpochStopCallback | None = None
class TrainerEpochCallbackTrainDQNSetEps(TrainerEpochCallbackTrain): class EpochTrainCallbackDQNSetEps(EpochTrainCallback):
"""Sets the epsilon value for DQN-based policies at the beginning of the training """Sets the epsilon value for DQN-based policies at the beginning of the training
stage in each epoch. stage in each epoch.
""" """
@ -87,7 +93,7 @@ class TrainerEpochCallbackTrainDQNSetEps(TrainerEpochCallbackTrain):
policy.set_eps(self.eps_test) policy.set_eps(self.eps_test)
class TrainerEpochCallbackTrainDQNEpsLinearDecay(TrainerEpochCallbackTrain): class EpochTrainCallbackDQNEpsLinearDecay(EpochTrainCallback):
"""Sets the epsilon value for DQN-based policies at the beginning of the training """Sets the epsilon value for DQN-based policies at the beginning of the training
stage in each epoch, using a linear decay in the first `decay_steps` steps. stage in each epoch, using a linear decay in the first `decay_steps` steps.
""" """
@ -110,7 +116,7 @@ class TrainerEpochCallbackTrainDQNEpsLinearDecay(TrainerEpochCallbackTrain):
logger.write("train/env_step", env_step, {"train/eps": eps}) logger.write("train/env_step", env_step, {"train/eps": eps})
class TrainerEpochCallbackTestDQNSetEps(TrainerEpochCallbackTest): class EpochTestCallbackDQNSetEps(EpochTestCallback):
"""Sets the epsilon value for DQN-based policies at the beginning of the test """Sets the epsilon value for DQN-based policies at the beginning of the test
stage in each epoch. stage in each epoch.
""" """
@ -121,3 +127,25 @@ class TrainerEpochCallbackTestDQNSetEps(TrainerEpochCallbackTest):
def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None: def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
policy = cast(DQNPolicy, context.policy) policy = cast(DQNPolicy, context.policy)
policy.set_eps(self.eps_test) policy.set_eps(self.eps_test)
class EpochStopCallbackRewardThreshold(EpochStopCallback):
    """Stops training once the mean rewards reach or exceed the given reward threshold or the
    threshold that is specified in the gymnasium environment (i.e. `env.spec.reward_threshold`).
    """

    def __init__(self, threshold: float | None = None):
        """:param threshold: the reward threshold at or beyond which to stop training.
        If it is None, use threshold given by the environment, i.e. `env.spec.reward_threshold`.
        """
        self.threshold = threshold

    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
        """Decides whether the reward threshold has been reached.

        :param mean_rewards: the mean rewards value to compare against the threshold
        :param context: the training context; its environments are consulted for the
            threshold (`context.envs.env.spec.reward_threshold`) if none was given at construction
        :return: True if `mean_rewards` is greater than or equal to the threshold, False otherwise
        :raises ValueError: if no threshold was given at construction and the environment
            does not specify one either
        """
        threshold = self.threshold
        if threshold is None:
            threshold = context.envs.env.spec.reward_threshold  # type: ignore
            if threshold is None:
                # Explicit check rather than `assert`: assertions are stripped under `python -O`,
                # which would turn a missing threshold into an obscure TypeError in the comparison below.
                raise ValueError(
                    "No reward threshold was given and the environment does not specify one "
                    "(env.spec.reward_threshold is None)",
                )
        is_reached = mean_rewards >= threshold
        if is_reached:
            log.info(f"Reward threshold ({threshold}) exceeded")
        return is_reached