152 lines
5.3 KiB
Python
152 lines
5.3 KiB
Python
import logging
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass
|
|
from typing import TypeVar, cast
|
|
|
|
from tianshou.highlevel.env import Environments
|
|
from tianshou.highlevel.logger import TLogger
|
|
from tianshou.policy import BasePolicy, DQNPolicy
|
|
from tianshou.utils.string import ToStringMixin
|
|
|
|
# Logger used by the callbacks in this module.
log = logging.getLogger(__name__)

# Type variable for policy types; restricted to subclasses of BasePolicy.
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
|
|
|
|
|
class TrainingContext:
    """Bundles the objects that training callbacks may access: the policy, the environments
    and the logger.
    """

    def __init__(self, policy: BasePolicy, envs: Environments, logger: TLogger):
        """:param policy: the policy being trained (DQN-specific callbacks cast it to `DQNPolicy`)
        :param envs: the environments; `EpochStopCallbackRewardThreshold` reads `envs.env.spec`
        :param logger: the logger to which callbacks may write values
        """
        # NOTE: annotated with BasePolicy rather than the TypeVar TPolicy. This class is not
        # generic, so a TypeVar appearing only once in the signature binds nothing.
        self.policy = policy
        self.envs = envs
        self.logger = logger
|
|
|
|
|
|
class EpochTrainCallback(ToStringMixin, ABC):
    """Callback invoked at the start of every epoch, before that epoch's data collection
    phase begins.
    """

    @abstractmethod
    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Performs the callback's action.

        :param epoch: the epoch, as passed in by the caller of the trainer function
        :param env_step: the environment step count, as passed in by the caller
        :param context: the training context
        """

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int], None]:
        """Returns a function of ``(epoch, env_step)`` that forwards to :meth:`callback`
        with ``context`` bound as the third argument.

        :param context: the training context passed on every invocation
        :return: the forwarding function
        """

        def fn(epoch: int, env_step: int) -> None:
            # self.callback is looked up on each call, so overrides are always honoured
            return self.callback(epoch, env_step, context)

        return fn
|
|
|
|
|
|
class EpochTestCallback(ToStringMixin, ABC):
    """Callback invoked at the start of every epoch's test phase."""

    @abstractmethod
    def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
        """Performs the callback's action.

        :param epoch: the epoch, as passed in by the caller of the trainer function
        :param env_step: the environment step count as passed in by the caller; may be None
        :param context: the training context
        """

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int | None], None]:
        """Returns a function of ``(epoch, env_step)`` that forwards to :meth:`callback`
        with ``context`` bound as the third argument.

        :param context: the training context passed on every invocation
        :return: the forwarding function
        """

        def fn(epoch: int, env_step: int | None) -> None:
            # self.callback is looked up on each call, so overrides are always honoured
            return self.callback(epoch, env_step, context)

        return fn
|
|
|
|
|
|
class EpochStopCallback(ToStringMixin, ABC):
    """Callback evaluated once an epoch's test phase has finished, deciding whether training
    should be terminated early.
    """

    @abstractmethod
    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
        """Decides whether training should stop.

        :param mean_rewards: the average undiscounted returns of the testing result
        :param context: the training context
        :return: True if the goal has been reached and training should stop, False otherwise
        """

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[float], bool]:
        """Returns a function of ``mean_rewards`` that forwards to :meth:`should_stop`
        with ``context`` bound as the second argument.

        :param context: the training context passed on every invocation
        :return: the forwarding predicate
        """

        def fn(mean_rewards: float) -> bool:
            decision = self.should_stop(mean_rewards, context)
            return decision

        return fn
|
|
|
|
|
|
@dataclass
class TrainerCallbacks:
    """Container for callbacks used during training.

    All fields are optional and default to None. Presumably None means that no callback of
    that kind is used; the consuming code is not shown here, so confirm this there.
    """

    # called at the beginning of each epoch, prior to the data collection phase
    epoch_train_callback: EpochTrainCallback | None = None
    # called at the beginning of the test phase of each epoch
    epoch_test_callback: EpochTestCallback | None = None
    # called after the test phase of each epoch to decide whether to stop training early
    epoch_stop_callback: EpochStopCallback | None = None
|
|
|
|
|
|
class EpochTrainCallbackDQNSetEps(EpochTrainCallback):
    """At the start of each epoch's training stage, applies a fixed epsilon value to a
    DQN-based policy.
    """

    def __init__(self, eps_test: float):
        """:param eps_test: the epsilon value to apply at the start of each training stage.
            NOTE(review): despite the name, this value is used for the *training* stage;
            the name is kept unchanged so existing keyword callers keep working.
        """
        self.eps_test = eps_test

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        dqn_policy = cast(DQNPolicy, context.policy)
        dqn_policy.set_eps(self.eps_test)
|
|
|
|
|
|
class EpochTrainCallbackDQNEpsLinearDecay(EpochTrainCallback):
    """Sets the epsilon value for DQN-based policies at the beginning of the training
    stage in each epoch, using a linear decay in the first `decay_steps` steps.
    """

    def __init__(self, eps_train: float, eps_train_final: float, decay_steps: int = 1000000):
        """:param eps_train: the epsilon value used at env step 0
        :param eps_train_final: the epsilon value reached after `decay_steps` env steps and
            used from then on
        :param decay_steps: the number of env steps over which epsilon is decayed linearly.
            If it is not positive, `eps_train_final` is used from the start.
        """
        self.eps_train = eps_train
        self.eps_train_final = eps_train_final
        self.decay_steps = decay_steps

    def _compute_eps(self, env_step: int) -> float:
        """Returns the linearly interpolated epsilon value for the given env step."""
        # A non-positive decay horizon means "no decay". Without this guard, env_step == 0
        # with decay_steps == 0 would divide by zero.
        if self.decay_steps <= 0 or env_step >= self.decay_steps:
            return self.eps_train_final
        progress = env_step / self.decay_steps
        return self.eps_train - progress * (self.eps_train - self.eps_train_final)

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        policy = cast(DQNPolicy, context.policy)
        eps = self._compute_eps(env_step)
        policy.set_eps(eps)
        context.logger.write("train/env_step", env_step, {"train/eps": eps})
|
|
|
|
|
|
class EpochTestCallbackDQNSetEps(EpochTestCallback):
    """At the start of each epoch's test stage, applies a fixed epsilon value to a
    DQN-based policy.
    """

    def __init__(self, eps_test: float):
        """:param eps_test: the epsilon value to apply at the start of each test stage"""
        self.eps_test = eps_test

    def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
        dqn_policy = cast(DQNPolicy, context.policy)
        dqn_policy.set_eps(self.eps_test)
|
|
|
|
|
|
class EpochStopCallbackRewardThreshold(EpochStopCallback):
    """Stops training once the mean rewards reach the given reward threshold or the threshold that
    is specified in the gymnasium environment (i.e. `env.spec.reward_threshold`).
    """

    def __init__(self, threshold: float | None = None):
        """:param threshold: the reward threshold at or beyond which to stop training.
            If it is None, use threshold given by the environment, i.e. `env.spec.reward_threshold`.
        """
        self.threshold = threshold

    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
        """Determines whether the reward threshold has been reached.

        :param mean_rewards: the average undiscounted returns of the testing result
        :param context: the training context; its environment supplies the fallback threshold
        :return: True if `mean_rewards >= threshold`, False otherwise
        :raises ValueError: if no threshold was given and the environment does not define one
        """
        threshold = self.threshold
        if threshold is None:
            # The env may have no spec, or a spec without a threshold. Report this explicitly
            # instead of relying on `assert`, which is stripped under `python -O`.
            spec = getattr(context.envs.env, "spec", None)
            threshold = getattr(spec, "reward_threshold", None)
            if threshold is None:
                raise ValueError(
                    "No reward threshold was given and the environment does not define one "
                    "(env.spec.reward_threshold is unavailable)",
                )
        is_reached = mean_rewards >= threshold
        if is_reached:
            log.info("Reward threshold (%s) exceeded", threshold)
        return is_reached
|