Improve naming of callback classes and related methods/attributes

Add EpochStopCallbackRewardThreshold
Dominik Jain 2024-01-10 15:28:48 +01:00
parent 24b7b82e56
commit 1e5ebc2a2d
8 changed files with 83 additions and 51 deletions

View File

@@ -6,7 +6,7 @@ from examples.atari.atari_network import (
IntermediateModuleFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
DQNExperimentBuilder,
@@ -17,8 +17,8 @@ from tianshou.highlevel.params.policy_wrapper import (
PolicyWrapperFactoryIntrinsicCuriosity,
)
from tianshou.highlevel.trainer import (
TrainerEpochCallbackTestDQNSetEps,
TrainerEpochCallbackTrainDQNEpsLinearDecay,
EpochTestCallbackDQNSetEps,
EpochTrainCallbackDQNEpsLinearDecay,
)
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag
@@ -79,11 +79,11 @@ def main(
),
)
.with_model_factory(IntermediateModuleFactoryAtariDQN())
.with_trainer_epoch_callback_train(
TrainerEpochCallbackTrainDQNEpsLinearDecay(eps_train, eps_train_final),
.with_epoch_train_callback(
EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final),
)
.with_trainer_epoch_callback_test(TrainerEpochCallbackTestDQNSetEps(eps_test))
.with_trainer_stop_callback(AtariStopCallback(task))
.with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test))
.with_epoch_stop_callback(AtariEpochStopCallback(task))
)
if icm_lr_scale > 0:
builder.with_policy_wrapper_factory(

View File

@@ -6,7 +6,7 @@ from collections.abc import Sequence
from examples.atari.atari_network import (
IntermediateModuleFactoryAtariDQN,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
ExperimentConfig,
@@ -14,8 +14,8 @@ from tianshou.highlevel.experiment import (
)
from tianshou.highlevel.params.policy_params import IQNParams
from tianshou.highlevel.trainer import (
TrainerEpochCallbackTestDQNSetEps,
TrainerEpochCallbackTrainDQNEpsLinearDecay,
EpochTestCallbackDQNSetEps,
EpochTrainCallbackDQNEpsLinearDecay,
)
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag
@@ -83,11 +83,11 @@ def main(
),
)
.with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True))
.with_trainer_epoch_callback_train(
TrainerEpochCallbackTrainDQNEpsLinearDecay(eps_train, eps_train_final),
.with_epoch_train_callback(
EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final),
)
.with_trainer_epoch_callback_test(TrainerEpochCallbackTestDQNSetEps(eps_test))
.with_trainer_stop_callback(AtariStopCallback(task))
.with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test))
.with_epoch_stop_callback(AtariEpochStopCallback(task))
.build()
)
experiment.run(log_name)

View File

@@ -7,7 +7,7 @@ from examples.atari.atari_network import (
ActorFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
ExperimentConfig,
@@ -95,7 +95,7 @@ def main(
)
.with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs, features_only=True))
.with_critic_factory_use_actor()
.with_trainer_stop_callback(AtariStopCallback(task))
.with_epoch_stop_callback(AtariEpochStopCallback(task))
)
if icm_lr_scale > 0:
builder.with_policy_wrapper_factory(

View File

@@ -6,7 +6,7 @@ from examples.atari.atari_network import (
ActorFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
DiscreteSACExperimentBuilder,
@@ -82,7 +82,7 @@ def main(
)
.with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs=False, features_only=True))
.with_common_critic_factory_use_actor()
.with_trainer_stop_callback(AtariStopCallback(task))
.with_epoch_stop_callback(AtariEpochStopCallback(task))
)
if icm_lr_scale > 0:
builder.with_policy_wrapper_factory(

View File

@@ -15,7 +15,7 @@ from tianshou.highlevel.env import (
EnvPoolFactory,
VectorEnvType,
)
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
from tianshou.highlevel.trainer import EpochStopCallback, TrainingContext
try:
import envpool
@@ -387,7 +387,7 @@ class AtariEnvFactory(EnvFactoryGymnasium):
return kwargs
class AtariStopCallback(TrainerStopCallback):
class AtariEpochStopCallback(EpochStopCallback):
def __init__(self, task: str):
self.task = task
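Only the class name changes here; the stopping logic is untouched by this commit and its body is not shown in the hunk. For orientation, a should_stop implementation in the style of this callback might look as follows (a sketch, not part of this diff; the Pong-specific fallback value of 20 is an assumption):

    from tianshou.highlevel.trainer import EpochStopCallback, TrainingContext

    class AtariEpochStopCallback(EpochStopCallback):
        def __init__(self, task: str):
            self.task = task

        def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
            # prefer the reward threshold declared by the gymnasium environment, if any
            env = context.envs.env
            if env.spec.reward_threshold:
                return mean_rewards >= env.spec.reward_threshold
            if "Pong" in self.task:
                return mean_rewards >= 20  # assumed fallback for Pong-style tasks
            return False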

View File

@@ -163,17 +163,19 @@ class OnPolicyAgentFactory(AgentFactory, ABC):
callbacks = self.trainer_callbacks
context = TrainingContext(world.policy, world.envs, world.logger)
train_fn = (
callbacks.epoch_callback_train.get_trainer_fn(context)
if callbacks.epoch_callback_train
callbacks.epoch_train_callback.get_trainer_fn(context)
if callbacks.epoch_train_callback
else None
)
test_fn = (
callbacks.epoch_callback_test.get_trainer_fn(context)
if callbacks.epoch_callback_test
callbacks.epoch_test_callback.get_trainer_fn(context)
if callbacks.epoch_test_callback
else None
)
stop_fn = (
callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
callbacks.epoch_stop_callback.get_trainer_fn(context)
if callbacks.epoch_stop_callback
else None
)
return OnpolicyTrainer(
policy=world.policy,
@@ -205,17 +207,19 @@ class OffPolicyAgentFactory(AgentFactory, ABC):
callbacks = self.trainer_callbacks
context = TrainingContext(world.policy, world.envs, world.logger)
train_fn = (
callbacks.epoch_callback_train.get_trainer_fn(context)
if callbacks.epoch_callback_train
callbacks.epoch_train_callback.get_trainer_fn(context)
if callbacks.epoch_train_callback
else None
)
test_fn = (
callbacks.epoch_callback_test.get_trainer_fn(context)
if callbacks.epoch_callback_test
callbacks.epoch_test_callback.get_trainer_fn(context)
if callbacks.epoch_test_callback
else None
)
stop_fn = (
callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
callbacks.epoch_stop_callback.get_trainer_fn(context)
if callbacks.epoch_stop_callback
else None
)
return OffpolicyTrainer(
policy=world.policy,
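Both factories now read the renamed attributes (epoch_train_callback, epoch_test_callback, epoch_stop_callback) and hand the resulting functions to the trainer. The conversion from a callback object to the plain function the trainer expects happens in get_trainer_fn; a minimal sketch of that adapter, assuming the callback signature shown in trainer.py below:

    # method on the callback base class (sketch); binds the TrainingContext
    # so the trainer can simply call fn(epoch, env_step)
    def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int], None]:
        def fn(epoch: int, env_step: int) -> None:
            self.callback(epoch, env_step, context)

        return fn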

View File

@@ -70,10 +70,10 @@ from tianshou.highlevel.persistence import (
PolicyPersistence,
)
from tianshou.highlevel.trainer import (
EpochStopCallback,
EpochTestCallback,
EpochTrainCallback,
TrainerCallbacks,
TrainerEpochCallbackTest,
TrainerEpochCallbackTrain,
TrainerStopCallback,
)
from tianshou.highlevel.world import World
from tianshou.policy import BasePolicy
@@ -383,25 +383,25 @@ class ExperimentBuilder:
self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay)
return self
def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallbackTrain) -> Self:
def with_epoch_train_callback(self, callback: EpochTrainCallback) -> Self:
"""Allows to define a callback function which is called at the beginning of every epoch during training.
:param callback: the callback
:return: the builder
"""
self._trainer_callbacks.epoch_callback_train = callback
self._trainer_callbacks.epoch_train_callback = callback
return self
def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallbackTest) -> Self:
def with_epoch_test_callback(self, callback: EpochTestCallback) -> Self:
"""Allows to define a callback function which is called at the beginning of testing in each epoch.
:param callback: the callback
:return: the builder
"""
self._trainer_callbacks.epoch_callback_test = callback
self._trainer_callbacks.epoch_test_callback = callback
return self
def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self:
def with_epoch_stop_callback(self, callback: EpochStopCallback) -> Self:
"""Allows to define a callback that decides whether training shall stop early.
The callback receives the undiscounted returns of the testing result.
@@ -409,7 +409,7 @@ class ExperimentBuilder:
:param callback: the callback
:return: the builder
"""
self._trainer_callbacks.stop_callback = callback
self._trainer_callbacks.epoch_stop_callback = callback
return self
@abstractmethod
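Taken together, the renamed builder methods read as in the Atari examples above; a usage sketch (builder constructor arguments abbreviated, parameter values illustrative, not part of this commit):

    experiment = (
        DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
        .with_epoch_train_callback(
            EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final),
        )
        .with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test))
        .with_epoch_stop_callback(AtariEpochStopCallback(task))
        .build()
    )
    experiment.run(log_name)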

View File

@@ -1,3 +1,4 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
@@ -9,6 +10,7 @@ from tianshou.policy import BasePolicy, DQNPolicy
from tianshou.utils.string import ToStringMixin
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
log = logging.getLogger(__name__)
class TrainingContext:
@@ -18,8 +20,10 @@ class TrainingContext:
self.logger = logger
class TrainerEpochCallbackTrain(ToStringMixin, ABC):
"""Callback which is called at the beginning of each epoch."""
class EpochTrainCallback(ToStringMixin, ABC):
"""Callback which is called at the beginning of each epoch, i.e. prior to the data collection phase
of each epoch.
"""
@abstractmethod
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
@@ -32,8 +36,8 @@ class TrainerEpochCallbackTrain(ToStringMixin, ABC):
return fn
class TrainerEpochCallbackTest(ToStringMixin, ABC):
"""Callback which is called at the beginning of each epoch."""
class EpochTestCallback(ToStringMixin, ABC):
"""Callback which is called at the beginning of the test phase of each epoch."""
@abstractmethod
def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
@@ -46,8 +50,10 @@ class TrainerEpochCallbackTest(ToStringMixin, ABC):
return fn
class TrainerStopCallback(ToStringMixin, ABC):
"""Callback indicating whether training should stop."""
class EpochStopCallback(ToStringMixin, ABC):
"""Callback which is called after the test phase of each epoch in order to determine
whether training should stop early.
"""
@abstractmethod
def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
@@ -69,12 +75,12 @@ class TrainerStopCallback(ToStringMixin, ABC):
class TrainerCallbacks:
"""Container for callbacks used during training."""
epoch_callback_train: TrainerEpochCallbackTrain | None = None
epoch_callback_test: TrainerEpochCallbackTest | None = None
stop_callback: TrainerStopCallback | None = None
epoch_train_callback: EpochTrainCallback | None = None
epoch_test_callback: EpochTestCallback | None = None
epoch_stop_callback: EpochStopCallback | None = None
class TrainerEpochCallbackTrainDQNSetEps(TrainerEpochCallbackTrain):
class EpochTrainCallbackDQNSetEps(EpochTrainCallback):
"""Sets the epsilon value for DQN-based policies at the beginning of the training
stage in each epoch.
"""
@@ -87,7 +93,7 @@ class TrainerEpochCallbackTrainDQNSetEps(TrainerEpochCallbackTrain):
policy.set_eps(self.eps_test)
class TrainerEpochCallbackTrainDQNEpsLinearDecay(TrainerEpochCallbackTrain):
class EpochTrainCallbackDQNEpsLinearDecay(EpochTrainCallback):
"""Sets the epsilon value for DQN-based policies at the beginning of the training
stage in each epoch, using a linear decay in the first `decay_steps` steps.
"""
@@ -110,7 +116,7 @@ class TrainerEpochCallbackTrainDQNEpsLinearDecay(TrainerEpochCallbackTrain):
logger.write("train/env_step", env_step, {"train/eps": eps})
class TrainerEpochCallbackTestDQNSetEps(TrainerEpochCallbackTest):
class EpochTestCallbackDQNSetEps(EpochTestCallback):
"""Sets the epsilon value for DQN-based policies at the beginning of the test
stage in each epoch.
"""
@@ -121,3 +127,25 @@ class TrainerEpochCallbackTestDQNSetEps(TrainerEpochCallbackTest):
def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
policy = cast(DQNPolicy, context.policy)
policy.set_eps(self.eps_test)
class EpochStopCallbackRewardThreshold(EpochStopCallback):
"""Stops training once the mean rewards exceed the given reward threshold or the threshold that
is specified in the gymnasium environment (i.e. `env.spec.reward_threshold`).
"""
def __init__(self, threshold: float | None = None):
""":param threshold: the reward threshold beyond which to stop training.
If it is None, use threshold given by the environment, i.e. `env.spec.reward_threshold`.
"""
self.threshold = threshold
def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
threshold = self.threshold
if threshold is None:
threshold = context.envs.env.spec.reward_threshold # type: ignore
assert threshold is not None
is_reached = mean_rewards >= threshold
if is_reached:
log.info(f"Reward threshold ({threshold}) exceeded")
return is_reached
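The new callback plugs into any experiment through with_epoch_stop_callback; a brief usage sketch (the explicit threshold of 195.0 is illustrative):

    # stop once the test-phase mean return reaches env.spec.reward_threshold
    builder.with_epoch_stop_callback(EpochStopCallbackRewardThreshold())

    # or stop at an explicitly chosen target return
    builder.with_epoch_stop_callback(EpochStopCallbackRewardThreshold(threshold=195.0))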