Improve naming of callback classes and related methods/attributes
Add EpochStopCallbackRewardThreshold
This commit is contained in:
parent
24b7b82e56
commit
1e5ebc2a2d
@ -6,7 +6,7 @@ from examples.atari.atari_network import (
|
|||||||
IntermediateModuleFactoryAtariDQN,
|
IntermediateModuleFactoryAtariDQN,
|
||||||
IntermediateModuleFactoryAtariDQNFeatures,
|
IntermediateModuleFactoryAtariDQNFeatures,
|
||||||
)
|
)
|
||||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
DQNExperimentBuilder,
|
DQNExperimentBuilder,
|
||||||
@ -17,8 +17,8 @@ from tianshou.highlevel.params.policy_wrapper import (
|
|||||||
PolicyWrapperFactoryIntrinsicCuriosity,
|
PolicyWrapperFactoryIntrinsicCuriosity,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.trainer import (
|
from tianshou.highlevel.trainer import (
|
||||||
TrainerEpochCallbackTestDQNSetEps,
|
EpochTestCallbackDQNSetEps,
|
||||||
TrainerEpochCallbackTrainDQNEpsLinearDecay,
|
EpochTrainCallbackDQNEpsLinearDecay,
|
||||||
)
|
)
|
||||||
from tianshou.utils import logging
|
from tianshou.utils import logging
|
||||||
from tianshou.utils.logging import datetime_tag
|
from tianshou.utils.logging import datetime_tag
|
||||||
@ -79,11 +79,11 @@ def main(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
.with_model_factory(IntermediateModuleFactoryAtariDQN())
|
.with_model_factory(IntermediateModuleFactoryAtariDQN())
|
||||||
.with_trainer_epoch_callback_train(
|
.with_epoch_train_callback(
|
||||||
TrainerEpochCallbackTrainDQNEpsLinearDecay(eps_train, eps_train_final),
|
EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final),
|
||||||
)
|
)
|
||||||
.with_trainer_epoch_callback_test(TrainerEpochCallbackTestDQNSetEps(eps_test))
|
.with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test))
|
||||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
.with_epoch_stop_callback(AtariEpochStopCallback(task))
|
||||||
)
|
)
|
||||||
if icm_lr_scale > 0:
|
if icm_lr_scale > 0:
|
||||||
builder.with_policy_wrapper_factory(
|
builder.with_policy_wrapper_factory(
|
||||||
|
@ -6,7 +6,7 @@ from collections.abc import Sequence
|
|||||||
from examples.atari.atari_network import (
|
from examples.atari.atari_network import (
|
||||||
IntermediateModuleFactoryAtariDQN,
|
IntermediateModuleFactoryAtariDQN,
|
||||||
)
|
)
|
||||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
ExperimentConfig,
|
ExperimentConfig,
|
||||||
@ -14,8 +14,8 @@ from tianshou.highlevel.experiment import (
|
|||||||
)
|
)
|
||||||
from tianshou.highlevel.params.policy_params import IQNParams
|
from tianshou.highlevel.params.policy_params import IQNParams
|
||||||
from tianshou.highlevel.trainer import (
|
from tianshou.highlevel.trainer import (
|
||||||
TrainerEpochCallbackTestDQNSetEps,
|
EpochTestCallbackDQNSetEps,
|
||||||
TrainerEpochCallbackTrainDQNEpsLinearDecay,
|
EpochTrainCallbackDQNEpsLinearDecay,
|
||||||
)
|
)
|
||||||
from tianshou.utils import logging
|
from tianshou.utils import logging
|
||||||
from tianshou.utils.logging import datetime_tag
|
from tianshou.utils.logging import datetime_tag
|
||||||
@ -83,11 +83,11 @@ def main(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
.with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True))
|
.with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True))
|
||||||
.with_trainer_epoch_callback_train(
|
.with_epoch_train_callback(
|
||||||
TrainerEpochCallbackTrainDQNEpsLinearDecay(eps_train, eps_train_final),
|
EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final),
|
||||||
)
|
)
|
||||||
.with_trainer_epoch_callback_test(TrainerEpochCallbackTestDQNSetEps(eps_test))
|
.with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test))
|
||||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
.with_epoch_stop_callback(AtariEpochStopCallback(task))
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(log_name)
|
experiment.run(log_name)
|
||||||
|
@ -7,7 +7,7 @@ from examples.atari.atari_network import (
|
|||||||
ActorFactoryAtariDQN,
|
ActorFactoryAtariDQN,
|
||||||
IntermediateModuleFactoryAtariDQNFeatures,
|
IntermediateModuleFactoryAtariDQNFeatures,
|
||||||
)
|
)
|
||||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
ExperimentConfig,
|
ExperimentConfig,
|
||||||
@ -95,7 +95,7 @@ def main(
|
|||||||
)
|
)
|
||||||
.with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs, features_only=True))
|
.with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs, features_only=True))
|
||||||
.with_critic_factory_use_actor()
|
.with_critic_factory_use_actor()
|
||||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
.with_epoch_stop_callback(AtariEpochStopCallback(task))
|
||||||
)
|
)
|
||||||
if icm_lr_scale > 0:
|
if icm_lr_scale > 0:
|
||||||
builder.with_policy_wrapper_factory(
|
builder.with_policy_wrapper_factory(
|
||||||
|
@ -6,7 +6,7 @@ from examples.atari.atari_network import (
|
|||||||
ActorFactoryAtariDQN,
|
ActorFactoryAtariDQN,
|
||||||
IntermediateModuleFactoryAtariDQNFeatures,
|
IntermediateModuleFactoryAtariDQNFeatures,
|
||||||
)
|
)
|
||||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
DiscreteSACExperimentBuilder,
|
DiscreteSACExperimentBuilder,
|
||||||
@ -82,7 +82,7 @@ def main(
|
|||||||
)
|
)
|
||||||
.with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs=False, features_only=True))
|
.with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs=False, features_only=True))
|
||||||
.with_common_critic_factory_use_actor()
|
.with_common_critic_factory_use_actor()
|
||||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
.with_epoch_stop_callback(AtariEpochStopCallback(task))
|
||||||
)
|
)
|
||||||
if icm_lr_scale > 0:
|
if icm_lr_scale > 0:
|
||||||
builder.with_policy_wrapper_factory(
|
builder.with_policy_wrapper_factory(
|
||||||
|
@ -15,7 +15,7 @@ from tianshou.highlevel.env import (
|
|||||||
EnvPoolFactory,
|
EnvPoolFactory,
|
||||||
VectorEnvType,
|
VectorEnvType,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
|
from tianshou.highlevel.trainer import EpochStopCallback, TrainingContext
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import envpool
|
import envpool
|
||||||
@ -387,7 +387,7 @@ class AtariEnvFactory(EnvFactoryGymnasium):
|
|||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
class AtariStopCallback(TrainerStopCallback):
|
class AtariEpochStopCallback(EpochStopCallback):
|
||||||
def __init__(self, task: str):
|
def __init__(self, task: str):
|
||||||
self.task = task
|
self.task = task
|
||||||
|
|
||||||
|
@ -163,17 +163,19 @@ class OnPolicyAgentFactory(AgentFactory, ABC):
|
|||||||
callbacks = self.trainer_callbacks
|
callbacks = self.trainer_callbacks
|
||||||
context = TrainingContext(world.policy, world.envs, world.logger)
|
context = TrainingContext(world.policy, world.envs, world.logger)
|
||||||
train_fn = (
|
train_fn = (
|
||||||
callbacks.epoch_callback_train.get_trainer_fn(context)
|
callbacks.epoch_train_callback.get_trainer_fn(context)
|
||||||
if callbacks.epoch_callback_train
|
if callbacks.epoch_train_callback
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
test_fn = (
|
test_fn = (
|
||||||
callbacks.epoch_callback_test.get_trainer_fn(context)
|
callbacks.epoch_test_callback.get_trainer_fn(context)
|
||||||
if callbacks.epoch_callback_test
|
if callbacks.epoch_test_callback
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
stop_fn = (
|
stop_fn = (
|
||||||
callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
|
callbacks.epoch_stop_callback.get_trainer_fn(context)
|
||||||
|
if callbacks.epoch_stop_callback
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
return OnpolicyTrainer(
|
return OnpolicyTrainer(
|
||||||
policy=world.policy,
|
policy=world.policy,
|
||||||
@ -205,17 +207,19 @@ class OffPolicyAgentFactory(AgentFactory, ABC):
|
|||||||
callbacks = self.trainer_callbacks
|
callbacks = self.trainer_callbacks
|
||||||
context = TrainingContext(world.policy, world.envs, world.logger)
|
context = TrainingContext(world.policy, world.envs, world.logger)
|
||||||
train_fn = (
|
train_fn = (
|
||||||
callbacks.epoch_callback_train.get_trainer_fn(context)
|
callbacks.epoch_train_callback.get_trainer_fn(context)
|
||||||
if callbacks.epoch_callback_train
|
if callbacks.epoch_train_callback
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
test_fn = (
|
test_fn = (
|
||||||
callbacks.epoch_callback_test.get_trainer_fn(context)
|
callbacks.epoch_test_callback.get_trainer_fn(context)
|
||||||
if callbacks.epoch_callback_test
|
if callbacks.epoch_test_callback
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
stop_fn = (
|
stop_fn = (
|
||||||
callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
|
callbacks.epoch_stop_callback.get_trainer_fn(context)
|
||||||
|
if callbacks.epoch_stop_callback
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
return OffpolicyTrainer(
|
return OffpolicyTrainer(
|
||||||
policy=world.policy,
|
policy=world.policy,
|
||||||
|
@ -70,10 +70,10 @@ from tianshou.highlevel.persistence import (
|
|||||||
PolicyPersistence,
|
PolicyPersistence,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.trainer import (
|
from tianshou.highlevel.trainer import (
|
||||||
|
EpochStopCallback,
|
||||||
|
EpochTestCallback,
|
||||||
|
EpochTrainCallback,
|
||||||
TrainerCallbacks,
|
TrainerCallbacks,
|
||||||
TrainerEpochCallbackTest,
|
|
||||||
TrainerEpochCallbackTrain,
|
|
||||||
TrainerStopCallback,
|
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.world import World
|
from tianshou.highlevel.world import World
|
||||||
from tianshou.policy import BasePolicy
|
from tianshou.policy import BasePolicy
|
||||||
@ -383,25 +383,25 @@ class ExperimentBuilder:
|
|||||||
self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay)
|
self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallbackTrain) -> Self:
|
def with_epoch_train_callback(self, callback: EpochTrainCallback) -> Self:
|
||||||
"""Allows to define a callback function which is called at the beginning of every epoch during training.
|
"""Allows to define a callback function which is called at the beginning of every epoch during training.
|
||||||
|
|
||||||
:param callback: the callback
|
:param callback: the callback
|
||||||
:return: the builder
|
:return: the builder
|
||||||
"""
|
"""
|
||||||
self._trainer_callbacks.epoch_callback_train = callback
|
self._trainer_callbacks.epoch_train_callback = callback
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallbackTest) -> Self:
|
def with_epoch_test_callback(self, callback: EpochTestCallback) -> Self:
|
||||||
"""Allows to define a callback function which is called at the beginning of testing in each epoch.
|
"""Allows to define a callback function which is called at the beginning of testing in each epoch.
|
||||||
|
|
||||||
:param callback: the callback
|
:param callback: the callback
|
||||||
:return: the builder
|
:return: the builder
|
||||||
"""
|
"""
|
||||||
self._trainer_callbacks.epoch_callback_test = callback
|
self._trainer_callbacks.epoch_test_callback = callback
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self:
|
def with_epoch_stop_callback(self, callback: EpochStopCallback) -> Self:
|
||||||
"""Allows to define a callback that decides whether training shall stop early.
|
"""Allows to define a callback that decides whether training shall stop early.
|
||||||
|
|
||||||
The callback receives the undiscounted returns of the testing result.
|
The callback receives the undiscounted returns of the testing result.
|
||||||
@ -409,7 +409,7 @@ class ExperimentBuilder:
|
|||||||
:param callback: the callback
|
:param callback: the callback
|
||||||
:return: the builder
|
:return: the builder
|
||||||
"""
|
"""
|
||||||
self._trainer_callbacks.stop_callback = callback
|
self._trainer_callbacks.epoch_stop_callback = callback
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -9,6 +10,7 @@ from tianshou.policy import BasePolicy, DQNPolicy
|
|||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TrainingContext:
|
class TrainingContext:
|
||||||
@ -18,8 +20,10 @@ class TrainingContext:
|
|||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
|
|
||||||
class TrainerEpochCallbackTrain(ToStringMixin, ABC):
|
class EpochTrainCallback(ToStringMixin, ABC):
|
||||||
"""Callback which is called at the beginning of each epoch."""
|
"""Callback which is called at the beginning of each epoch, i.e. prior to the data collection phase
|
||||||
|
of each epoch.
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
|
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
|
||||||
@ -32,8 +36,8 @@ class TrainerEpochCallbackTrain(ToStringMixin, ABC):
|
|||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
class TrainerEpochCallbackTest(ToStringMixin, ABC):
|
class EpochTestCallback(ToStringMixin, ABC):
|
||||||
"""Callback which is called at the beginning of each epoch."""
|
"""Callback which is called at the beginning of the test phase of each epoch."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
|
def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
|
||||||
@ -46,8 +50,10 @@ class TrainerEpochCallbackTest(ToStringMixin, ABC):
|
|||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
class TrainerStopCallback(ToStringMixin, ABC):
|
class EpochStopCallback(ToStringMixin, ABC):
|
||||||
"""Callback indicating whether training should stop."""
|
"""Callback which is called after the test phase of each epoch in order to determine
|
||||||
|
whether training should stop early.
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
|
def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
|
||||||
@ -69,12 +75,12 @@ class TrainerStopCallback(ToStringMixin, ABC):
|
|||||||
class TrainerCallbacks:
|
class TrainerCallbacks:
|
||||||
"""Container for callbacks used during training."""
|
"""Container for callbacks used during training."""
|
||||||
|
|
||||||
epoch_callback_train: TrainerEpochCallbackTrain | None = None
|
epoch_train_callback: EpochTrainCallback | None = None
|
||||||
epoch_callback_test: TrainerEpochCallbackTest | None = None
|
epoch_test_callback: EpochTestCallback | None = None
|
||||||
stop_callback: TrainerStopCallback | None = None
|
epoch_stop_callback: EpochStopCallback | None = None
|
||||||
|
|
||||||
|
|
||||||
class TrainerEpochCallbackTrainDQNSetEps(TrainerEpochCallbackTrain):
|
class EpochTrainCallbackDQNSetEps(EpochTrainCallback):
|
||||||
"""Sets the epsilon value for DQN-based policies at the beginning of the training
|
"""Sets the epsilon value for DQN-based policies at the beginning of the training
|
||||||
stage in each epoch.
|
stage in each epoch.
|
||||||
"""
|
"""
|
||||||
@ -87,7 +93,7 @@ class TrainerEpochCallbackTrainDQNSetEps(TrainerEpochCallbackTrain):
|
|||||||
policy.set_eps(self.eps_test)
|
policy.set_eps(self.eps_test)
|
||||||
|
|
||||||
|
|
||||||
class TrainerEpochCallbackTrainDQNEpsLinearDecay(TrainerEpochCallbackTrain):
|
class EpochTrainCallbackDQNEpsLinearDecay(EpochTrainCallback):
|
||||||
"""Sets the epsilon value for DQN-based policies at the beginning of the training
|
"""Sets the epsilon value for DQN-based policies at the beginning of the training
|
||||||
stage in each epoch, using a linear decay in the first `decay_steps` steps.
|
stage in each epoch, using a linear decay in the first `decay_steps` steps.
|
||||||
"""
|
"""
|
||||||
@ -110,7 +116,7 @@ class TrainerEpochCallbackTrainDQNEpsLinearDecay(TrainerEpochCallbackTrain):
|
|||||||
logger.write("train/env_step", env_step, {"train/eps": eps})
|
logger.write("train/env_step", env_step, {"train/eps": eps})
|
||||||
|
|
||||||
|
|
||||||
class TrainerEpochCallbackTestDQNSetEps(TrainerEpochCallbackTest):
|
class EpochTestCallbackDQNSetEps(EpochTestCallback):
|
||||||
"""Sets the epsilon value for DQN-based policies at the beginning of the test
|
"""Sets the epsilon value for DQN-based policies at the beginning of the test
|
||||||
stage in each epoch.
|
stage in each epoch.
|
||||||
"""
|
"""
|
||||||
@ -121,3 +127,25 @@ class TrainerEpochCallbackTestDQNSetEps(TrainerEpochCallbackTest):
|
|||||||
def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
|
def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
|
||||||
policy = cast(DQNPolicy, context.policy)
|
policy = cast(DQNPolicy, context.policy)
|
||||||
policy.set_eps(self.eps_test)
|
policy.set_eps(self.eps_test)
|
||||||
|
|
||||||
|
|
||||||
|
class EpochStopCallbackRewardThreshold(EpochStopCallback):
    """Stops training once the mean rewards reach (i.e. are greater than or equal to) the given
    reward threshold or, if no threshold is given, the threshold that is specified in the
    gymnasium environment (i.e. `env.spec.reward_threshold`).
    """

    def __init__(self, threshold: float | None = None):
        """:param threshold: the reward threshold at or beyond which to stop training.
        If it is None, use threshold given by the environment, i.e. `env.spec.reward_threshold`.
        """
        self.threshold = threshold

    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
        """Decides whether the mean rewards have reached the reward threshold.

        :param mean_rewards: the mean rewards to compare against the threshold
        :param context: the training context; its environment is consulted for the
            threshold when none was given at construction
        :return: True if `mean_rewards >= threshold`, False otherwise
        :raises ValueError: if no threshold was given at construction and the environment
            does not specify one either
        """
        threshold = self.threshold
        if threshold is None:
            threshold = context.envs.env.spec.reward_threshold  # type: ignore
            if threshold is None:
                # Explicit check instead of `assert`: assertions are stripped under
                # `python -O`, which would let the comparison below fail with an
                # opaque TypeError (float >= None).
                raise ValueError(
                    "No reward threshold was given and the environment does not specify "
                    "one (env.spec.reward_threshold is None)",
                )
        is_reached = mean_rewards >= threshold
        if is_reached:
            log.info(f"Reward threshold ({threshold}) exceeded")
        return is_reached
|
||||||
|
Loading…
x
Reference in New Issue
Block a user