Improve naming of callback classes and related methods/attributes
Add EpochStopCallbackRewardThreshold
This commit is contained in:
parent
24b7b82e56
commit
1e5ebc2a2d
@ -6,7 +6,7 @@ from examples.atari.atari_network import (
|
||||
IntermediateModuleFactoryAtariDQN,
|
||||
IntermediateModuleFactoryAtariDQNFeatures,
|
||||
)
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.experiment import (
|
||||
DQNExperimentBuilder,
|
||||
@ -17,8 +17,8 @@ from tianshou.highlevel.params.policy_wrapper import (
|
||||
PolicyWrapperFactoryIntrinsicCuriosity,
|
||||
)
|
||||
from tianshou.highlevel.trainer import (
|
||||
TrainerEpochCallbackTestDQNSetEps,
|
||||
TrainerEpochCallbackTrainDQNEpsLinearDecay,
|
||||
EpochTestCallbackDQNSetEps,
|
||||
EpochTrainCallbackDQNEpsLinearDecay,
|
||||
)
|
||||
from tianshou.utils import logging
|
||||
from tianshou.utils.logging import datetime_tag
|
||||
@ -79,11 +79,11 @@ def main(
|
||||
),
|
||||
)
|
||||
.with_model_factory(IntermediateModuleFactoryAtariDQN())
|
||||
.with_trainer_epoch_callback_train(
|
||||
TrainerEpochCallbackTrainDQNEpsLinearDecay(eps_train, eps_train_final),
|
||||
.with_epoch_train_callback(
|
||||
EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final),
|
||||
)
|
||||
.with_trainer_epoch_callback_test(TrainerEpochCallbackTestDQNSetEps(eps_test))
|
||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||
.with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test))
|
||||
.with_epoch_stop_callback(AtariEpochStopCallback(task))
|
||||
)
|
||||
if icm_lr_scale > 0:
|
||||
builder.with_policy_wrapper_factory(
|
||||
|
@ -6,7 +6,7 @@ from collections.abc import Sequence
|
||||
from examples.atari.atari_network import (
|
||||
IntermediateModuleFactoryAtariDQN,
|
||||
)
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.experiment import (
|
||||
ExperimentConfig,
|
||||
@ -14,8 +14,8 @@ from tianshou.highlevel.experiment import (
|
||||
)
|
||||
from tianshou.highlevel.params.policy_params import IQNParams
|
||||
from tianshou.highlevel.trainer import (
|
||||
TrainerEpochCallbackTestDQNSetEps,
|
||||
TrainerEpochCallbackTrainDQNEpsLinearDecay,
|
||||
EpochTestCallbackDQNSetEps,
|
||||
EpochTrainCallbackDQNEpsLinearDecay,
|
||||
)
|
||||
from tianshou.utils import logging
|
||||
from tianshou.utils.logging import datetime_tag
|
||||
@ -83,11 +83,11 @@ def main(
|
||||
),
|
||||
)
|
||||
.with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True))
|
||||
.with_trainer_epoch_callback_train(
|
||||
TrainerEpochCallbackTrainDQNEpsLinearDecay(eps_train, eps_train_final),
|
||||
.with_epoch_train_callback(
|
||||
EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final),
|
||||
)
|
||||
.with_trainer_epoch_callback_test(TrainerEpochCallbackTestDQNSetEps(eps_test))
|
||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||
.with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test))
|
||||
.with_epoch_stop_callback(AtariEpochStopCallback(task))
|
||||
.build()
|
||||
)
|
||||
experiment.run(log_name)
|
||||
|
@ -7,7 +7,7 @@ from examples.atari.atari_network import (
|
||||
ActorFactoryAtariDQN,
|
||||
IntermediateModuleFactoryAtariDQNFeatures,
|
||||
)
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.experiment import (
|
||||
ExperimentConfig,
|
||||
@ -95,7 +95,7 @@ def main(
|
||||
)
|
||||
.with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs, features_only=True))
|
||||
.with_critic_factory_use_actor()
|
||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||
.with_epoch_stop_callback(AtariEpochStopCallback(task))
|
||||
)
|
||||
if icm_lr_scale > 0:
|
||||
builder.with_policy_wrapper_factory(
|
||||
|
@ -6,7 +6,7 @@ from examples.atari.atari_network import (
|
||||
ActorFactoryAtariDQN,
|
||||
IntermediateModuleFactoryAtariDQNFeatures,
|
||||
)
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.experiment import (
|
||||
DiscreteSACExperimentBuilder,
|
||||
@ -82,7 +82,7 @@ def main(
|
||||
)
|
||||
.with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs=False, features_only=True))
|
||||
.with_common_critic_factory_use_actor()
|
||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||
.with_epoch_stop_callback(AtariEpochStopCallback(task))
|
||||
)
|
||||
if icm_lr_scale > 0:
|
||||
builder.with_policy_wrapper_factory(
|
||||
|
@ -15,7 +15,7 @@ from tianshou.highlevel.env import (
|
||||
EnvPoolFactory,
|
||||
VectorEnvType,
|
||||
)
|
||||
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
|
||||
from tianshou.highlevel.trainer import EpochStopCallback, TrainingContext
|
||||
|
||||
try:
|
||||
import envpool
|
||||
@ -387,7 +387,7 @@ class AtariEnvFactory(EnvFactoryGymnasium):
|
||||
return kwargs
|
||||
|
||||
|
||||
class AtariStopCallback(TrainerStopCallback):
|
||||
class AtariEpochStopCallback(EpochStopCallback):
|
||||
def __init__(self, task: str):
|
||||
self.task = task
|
||||
|
||||
|
@ -163,17 +163,19 @@ class OnPolicyAgentFactory(AgentFactory, ABC):
|
||||
callbacks = self.trainer_callbacks
|
||||
context = TrainingContext(world.policy, world.envs, world.logger)
|
||||
train_fn = (
|
||||
callbacks.epoch_callback_train.get_trainer_fn(context)
|
||||
if callbacks.epoch_callback_train
|
||||
callbacks.epoch_train_callback.get_trainer_fn(context)
|
||||
if callbacks.epoch_train_callback
|
||||
else None
|
||||
)
|
||||
test_fn = (
|
||||
callbacks.epoch_callback_test.get_trainer_fn(context)
|
||||
if callbacks.epoch_callback_test
|
||||
callbacks.epoch_test_callback.get_trainer_fn(context)
|
||||
if callbacks.epoch_test_callback
|
||||
else None
|
||||
)
|
||||
stop_fn = (
|
||||
callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
|
||||
callbacks.epoch_stop_callback.get_trainer_fn(context)
|
||||
if callbacks.epoch_stop_callback
|
||||
else None
|
||||
)
|
||||
return OnpolicyTrainer(
|
||||
policy=world.policy,
|
||||
@ -205,17 +207,19 @@ class OffPolicyAgentFactory(AgentFactory, ABC):
|
||||
callbacks = self.trainer_callbacks
|
||||
context = TrainingContext(world.policy, world.envs, world.logger)
|
||||
train_fn = (
|
||||
callbacks.epoch_callback_train.get_trainer_fn(context)
|
||||
if callbacks.epoch_callback_train
|
||||
callbacks.epoch_train_callback.get_trainer_fn(context)
|
||||
if callbacks.epoch_train_callback
|
||||
else None
|
||||
)
|
||||
test_fn = (
|
||||
callbacks.epoch_callback_test.get_trainer_fn(context)
|
||||
if callbacks.epoch_callback_test
|
||||
callbacks.epoch_test_callback.get_trainer_fn(context)
|
||||
if callbacks.epoch_test_callback
|
||||
else None
|
||||
)
|
||||
stop_fn = (
|
||||
callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
|
||||
callbacks.epoch_stop_callback.get_trainer_fn(context)
|
||||
if callbacks.epoch_stop_callback
|
||||
else None
|
||||
)
|
||||
return OffpolicyTrainer(
|
||||
policy=world.policy,
|
||||
|
@ -70,10 +70,10 @@ from tianshou.highlevel.persistence import (
|
||||
PolicyPersistence,
|
||||
)
|
||||
from tianshou.highlevel.trainer import (
|
||||
EpochStopCallback,
|
||||
EpochTestCallback,
|
||||
EpochTrainCallback,
|
||||
TrainerCallbacks,
|
||||
TrainerEpochCallbackTest,
|
||||
TrainerEpochCallbackTrain,
|
||||
TrainerStopCallback,
|
||||
)
|
||||
from tianshou.highlevel.world import World
|
||||
from tianshou.policy import BasePolicy
|
||||
@ -383,25 +383,25 @@ class ExperimentBuilder:
|
||||
self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
return self
|
||||
|
||||
def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallbackTrain) -> Self:
|
||||
def with_epoch_train_callback(self, callback: EpochTrainCallback) -> Self:
|
||||
"""Allows to define a callback function which is called at the beginning of every epoch during training.
|
||||
|
||||
:param callback: the callback
|
||||
:return: the builder
|
||||
"""
|
||||
self._trainer_callbacks.epoch_callback_train = callback
|
||||
self._trainer_callbacks.epoch_train_callback = callback
|
||||
return self
|
||||
|
||||
def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallbackTest) -> Self:
|
||||
def with_epoch_test_callback(self, callback: EpochTestCallback) -> Self:
|
||||
"""Allows to define a callback function which is called at the beginning of testing in each epoch.
|
||||
|
||||
:param callback: the callback
|
||||
:return: the builder
|
||||
"""
|
||||
self._trainer_callbacks.epoch_callback_test = callback
|
||||
self._trainer_callbacks.epoch_test_callback = callback
|
||||
return self
|
||||
|
||||
def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self:
|
||||
def with_epoch_stop_callback(self, callback: EpochStopCallback) -> Self:
|
||||
"""Allows to define a callback that decides whether training shall stop early.
|
||||
|
||||
The callback receives the undiscounted returns of the testing result.
|
||||
@ -409,7 +409,7 @@ class ExperimentBuilder:
|
||||
:param callback: the callback
|
||||
:return: the builder
|
||||
"""
|
||||
self._trainer_callbacks.stop_callback = callback
|
||||
self._trainer_callbacks.epoch_stop_callback = callback
|
||||
return self
|
||||
|
||||
@abstractmethod
|
||||
|
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
@ -9,6 +10,7 @@ from tianshou.policy import BasePolicy, DQNPolicy
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrainingContext:
|
||||
@ -18,8 +20,10 @@ class TrainingContext:
|
||||
self.logger = logger
|
||||
|
||||
|
||||
class TrainerEpochCallbackTrain(ToStringMixin, ABC):
|
||||
"""Callback which is called at the beginning of each epoch."""
|
||||
class EpochTrainCallback(ToStringMixin, ABC):
|
||||
"""Callback which is called at the beginning of each epoch, i.e. prior to the data collection phase
|
||||
of each epoch.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
|
||||
@ -32,8 +36,8 @@ class TrainerEpochCallbackTrain(ToStringMixin, ABC):
|
||||
return fn
|
||||
|
||||
|
||||
class TrainerEpochCallbackTest(ToStringMixin, ABC):
|
||||
"""Callback which is called at the beginning of each epoch."""
|
||||
class EpochTestCallback(ToStringMixin, ABC):
|
||||
"""Callback which is called at the beginning of the test phase of each epoch."""
|
||||
|
||||
@abstractmethod
|
||||
def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
|
||||
@ -46,8 +50,10 @@ class TrainerEpochCallbackTest(ToStringMixin, ABC):
|
||||
return fn
|
||||
|
||||
|
||||
class TrainerStopCallback(ToStringMixin, ABC):
|
||||
"""Callback indicating whether training should stop."""
|
||||
class EpochStopCallback(ToStringMixin, ABC):
|
||||
"""Callback which is called after the test phase of each epoch in order to determine
|
||||
whether training should stop early.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
|
||||
@ -69,12 +75,12 @@ class TrainerStopCallback(ToStringMixin, ABC):
|
||||
class TrainerCallbacks:
|
||||
"""Container for callbacks used during training."""
|
||||
|
||||
epoch_callback_train: TrainerEpochCallbackTrain | None = None
|
||||
epoch_callback_test: TrainerEpochCallbackTest | None = None
|
||||
stop_callback: TrainerStopCallback | None = None
|
||||
epoch_train_callback: EpochTrainCallback | None = None
|
||||
epoch_test_callback: EpochTestCallback | None = None
|
||||
epoch_stop_callback: EpochStopCallback | None = None
|
||||
|
||||
|
||||
class TrainerEpochCallbackTrainDQNSetEps(TrainerEpochCallbackTrain):
|
||||
class EpochTrainCallbackDQNSetEps(EpochTrainCallback):
|
||||
"""Sets the epsilon value for DQN-based policies at the beginning of the training
|
||||
stage in each epoch.
|
||||
"""
|
||||
@ -87,7 +93,7 @@ class TrainerEpochCallbackTrainDQNSetEps(TrainerEpochCallbackTrain):
|
||||
policy.set_eps(self.eps_test)
|
||||
|
||||
|
||||
class TrainerEpochCallbackTrainDQNEpsLinearDecay(TrainerEpochCallbackTrain):
|
||||
class EpochTrainCallbackDQNEpsLinearDecay(EpochTrainCallback):
|
||||
"""Sets the epsilon value for DQN-based policies at the beginning of the training
|
||||
stage in each epoch, using a linear decay in the first `decay_steps` steps.
|
||||
"""
|
||||
@ -110,7 +116,7 @@ class TrainerEpochCallbackTrainDQNEpsLinearDecay(TrainerEpochCallbackTrain):
|
||||
logger.write("train/env_step", env_step, {"train/eps": eps})
|
||||
|
||||
|
||||
class TrainerEpochCallbackTestDQNSetEps(TrainerEpochCallbackTest):
|
||||
class EpochTestCallbackDQNSetEps(EpochTestCallback):
|
||||
"""Sets the epsilon value for DQN-based policies at the beginning of the test
|
||||
stage in each epoch.
|
||||
"""
|
||||
@ -121,3 +127,25 @@ class TrainerEpochCallbackTestDQNSetEps(TrainerEpochCallbackTest):
|
||||
def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
|
||||
policy = cast(DQNPolicy, context.policy)
|
||||
policy.set_eps(self.eps_test)
|
||||
|
||||
|
||||
class EpochStopCallbackRewardThreshold(EpochStopCallback):
|
||||
"""Stops training once the mean rewards exceed the given reward threshold or the threshold that
|
||||
is specified in the gymnasium environment (i.e. `env.spec.reward_threshold`).
|
||||
"""
|
||||
|
||||
def __init__(self, threshold: float | None = None):
|
||||
""":param threshold: the reward threshold beyond which to stop training.
|
||||
If it is None, use threshold given by the environment, i.e. `env.spec.reward_threshold`.
|
||||
"""
|
||||
self.threshold = threshold
|
||||
|
||||
def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
|
||||
threshold = self.threshold
|
||||
if threshold is None:
|
||||
threshold = context.envs.env.spec.reward_threshold # type: ignore
|
||||
assert threshold is not None
|
||||
is_reached = mean_rewards >= threshold
|
||||
if is_reached:
|
||||
log.info(f"Reward threshold ({threshold}) exceeded")
|
||||
return is_reached
|
||||
|
Loading…
x
Reference in New Issue
Block a user