diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py
new file mode 100644
index 0000000..42583d8
--- /dev/null
+++ b/examples/atari/atari_dqn_hl.py
@@ -0,0 +1,128 @@
+#!/usr/bin/env python3
+
+import datetime
+import os
+
+from jsonargparse import CLI
+
+from examples.atari.atari_network import (
+    CriticFactoryAtariDQN,
+    FeatureNetFactoryDQN,
+)
+from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
+from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.experiment import (
+    DQNExperimentBuilder,
+    RLExperimentConfig,
+)
+from tianshou.highlevel.params.policy_params import DQNParams
+from tianshou.highlevel.params.policy_wrapper import (
+    PolicyWrapperFactoryIntrinsicCuriosity,
+)
+from tianshou.highlevel.trainer import TrainerEpochCallback, TrainingContext
+from tianshou.policy import DQNPolicy
+from tianshou.utils import logging
+
+
+def main(
+    experiment_config: RLExperimentConfig,
+    task: str = "PongNoFrameskip-v4",
+    scale_obs: int = 0,
+    eps_test: float = 0.005,
+    eps_train: float = 1.0,
+    eps_train_final: float = 0.05,
+    buffer_size: int = 100000,
+    lr: float = 0.0001,
+    gamma: float = 0.99,
+    n_step: int = 3,
+    target_update_freq: int = 500,
+    epoch: int = 100,
+    step_per_epoch: int = 100000,
+    step_per_collect: int = 10,
+    update_per_step: float = 0.1,
+    batch_size: int = 32,
+    training_num: int = 10,
+    test_num: int = 10,
+    frames_stack: int = 4,
+    save_buffer_name: str | None = None,  # TODO support?
+    icm_lr_scale: float = 0.0,
+    icm_reward_scale: float = 0.01,
+    icm_forward_loss_weight: float = 0.2,
+):
+    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
+    log_name = os.path.join(task, "dqn", str(experiment_config.seed), now)
+
+    sampling_config = RLSamplingConfig(
+        num_epochs=epoch,
+        step_per_epoch=step_per_epoch,
+        batch_size=batch_size,
+        num_train_envs=training_num,
+        num_test_envs=test_num,
+        buffer_size=buffer_size,
+        step_per_collect=step_per_collect,
+        update_per_step=update_per_step,
+        repeat_per_collect=None,
+        replay_buffer_stack_num=frames_stack,
+        replay_buffer_ignore_obs_next=True,
+        replay_buffer_save_only_last_obs=True,
+    )
+
+    env_factory = AtariEnvFactory(
+        task,
+        experiment_config.seed,
+        sampling_config,
+        frames_stack,
+        scale=scale_obs,
+    )
+
+    class TrainEpochCallback(TrainerEpochCallback):
+        def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
+            policy: DQNPolicy = context.policy
+            logger = context.logger.logger
+            # nature DQN setting, linear decay in the first 1M steps
+            if env_step <= 1e6:
+                eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final)
+            else:
+                eps = eps_train_final
+            policy.set_eps(eps)
+            if env_step % 1000 == 0:
+                logger.write("train/env_step", env_step, {"train/eps": eps})
+
+    class TestEpochCallback(TrainerEpochCallback):
+        def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
+            policy: DQNPolicy = context.policy
+            policy.set_eps(eps_test)
+
+    builder = (
+        DQNExperimentBuilder(experiment_config, env_factory, sampling_config)
+        .with_dqn_params(
+            DQNParams(
+                discount_factor=gamma,
+                estimation_step=n_step,
+                lr=lr,
+                target_update_freq=target_update_freq,
+            ),
+        )
+        .with_critic_factory(CriticFactoryAtariDQN())
+        .with_trainer_epoch_callback_train(TrainEpochCallback())
+        .with_trainer_epoch_callback_test(TestEpochCallback())
+        .with_trainer_stop_callback(AtariStopCallback(task))
+    )
+    if icm_lr_scale > 0:
+        builder.with_policy_wrapper_factory(
+            PolicyWrapperFactoryIntrinsicCuriosity(
+                FeatureNetFactoryDQN(),
+                [512],
+                lr,
+                icm_lr_scale,
+                icm_reward_scale,
+                icm_forward_loss_weight,
+            ),
+        )
+
+    experiment = builder.build()
+    experiment.run(log_name)
+
+
+if __name__ == "__main__":
+    logging.run_main(lambda: CLI(main))
diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py
index 7c18b35..0d72fad 100644
--- a/examples/atari/atari_network.py
+++ b/examples/atari/atari_network.py
@@ -8,6 +8,7 @@ from torch import nn
 from tianshou.highlevel.env import Environments
 from tianshou.highlevel.module.actor import ActorFactory
 from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice
+from tianshou.highlevel.module.critic import CriticFactory
 from tianshou.utils.net.common import BaseActor
 from tianshou.utils.net.discrete import Actor, NoisyLinear
 
@@ -226,6 +227,21 @@ class QRDQN(DQN):
         return obs, state
 
+
+class CriticFactoryAtariDQN(CriticFactory):
+    def create_module(
+        self,
+        envs: Environments,
+        device: TDevice,
+        use_action: bool,
+    ) -> torch.nn.Module:
+        assert use_action
+        return DQN(
+            *envs.get_observation_shape(),
+            envs.get_action_shape(),
+            device=device,
+        ).to(device)
+
+
 class ActorFactoryAtariDQN(ActorFactory):
     def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool):
         self.hidden_size = hidden_size
diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py
index 12733bb..b02ff07 100644
--- a/examples/atari/atari_ppo_hl.py
+++ b/examples/atari/atari_ppo_hl.py
@@ -10,7 +10,7 @@ from examples.atari.atari_network import (
     ActorFactoryAtariDQN,
     FeatureNetFactoryDQN,
 )
-from examples.atari.atari_wrapper import AtariEnvFactory
+from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
 from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.experiment import (
     PPOExperimentBuilder,
@@ -98,6 +98,7 @@ def main(
         )
         .with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs))
         .with_critic_factory_use_actor()
+        .with_trainer_stop_callback(AtariStopCallback(task))
     )
     if icm_lr_scale > 0:
         builder.with_policy_wrapper_factory(
diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py
index 347294a..13384d6 100644
--- a/examples/atari/atari_wrapper.py
+++ b/examples/atari/atari_wrapper.py
@@ -11,6 +11,7 @@ import numpy as np
 from tianshou.env import ShmemVectorEnv
 from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
+from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
 
 try:
     import envpool
@@ -374,11 +375,19 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
 
 
 class AtariEnvFactory(EnvFactory):
-    def __init__(self, task: str, seed: int, sampling_config: RLSamplingConfig, frame_stack: int):
+    def __init__(
+        self,
+        task: str,
+        seed: int,
+        sampling_config: RLSamplingConfig,
+        frame_stack: int,
+        scale: int = 0,
+    ):
         self.task = task
         self.sampling_config = sampling_config
         self.seed = seed
         self.frame_stack = frame_stack
+        self.scale = scale
 
     def create_envs(self, config=None) -> DiscreteEnvironments:
         env, train_envs, test_envs = make_atari_env(
@@ -386,7 +395,20 @@ class AtariEnvFactory(EnvFactory):
             seed=self.seed,
             training_num=self.sampling_config.num_train_envs,
             test_num=self.sampling_config.num_test_envs,
-            scale=0,
+            scale=self.scale,
             frame_stack=self.frame_stack,
         )
         return DiscreteEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
+
+
+class AtariStopCallback(TrainerStopCallback):
+    def __init__(self, task: str):
+        self.task = task
+
+    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
+        env = context.envs.env
+        if env.spec.reward_threshold:
+            return mean_rewards >= env.spec.reward_threshold
+        if "Pong" in self.task:
+            return mean_rewards >= 20
+        return False
diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index b97f5ed..88b2fc1 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
 from collections.abc import Callable
 from typing import Generic, TypeVar
 
+import gymnasium
 import torch
 
 from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
@@ -23,6 +24,7 @@ from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.highlevel.params.policy_params import (
     A2CParams,
     DDPGParams,
+    DQNParams,
     Params,
     ParamTransformerData,
     PPOParams,
@@ -30,10 +32,12 @@ from tianshou.highlevel.params.policy_params import (
     TD3Params,
 )
 from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
+from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext
 from tianshou.policy import (
     A2CPolicy,
     BasePolicy,
     DDPGPolicy,
+    DQNPolicy,
     PPOPolicy,
     SACPolicy,
     TD3Policy,
@@ -54,6 +58,7 @@ class AgentFactory(ABC, ToStringMixin):
         self.sampling_config = sampling_config
         self.optim_factory = optim_factory
         self.policy_wrapper_factory: PolicyWrapperFactory | None = None
+        self.trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
 
     def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
         buffer_size = self.sampling_config.buffer_size
@@ -85,6 +90,9 @@ class AgentFactory(ABC, ToStringMixin):
     ) -> None:
         self.policy_wrapper_factory = policy_wrapper_factory
 
+    def set_trainer_callbacks(self, callbacks: TrainerCallbacks):
+        self.trainer_callbacks = callbacks
+
     @abstractmethod
     def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
         pass
@@ -145,6 +153,21 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
         logger: Logger,
     ) -> OnpolicyTrainer:
         sampling_config = self.sampling_config
+        callbacks = self.trainer_callbacks
+        context = TrainingContext(policy, envs, logger)
+        train_fn = (
+            callbacks.epoch_callback_train.get_trainer_fn(context)
+            if callbacks.epoch_callback_train
+            else None
+        )
+        test_fn = (
+            callbacks.epoch_callback_test.get_trainer_fn(context)
+            if callbacks.epoch_callback_test
+            else None
+        )
+        stop_fn = (
+            callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
+        )
         return OnpolicyTrainer(
             policy=policy,
             train_collector=train_collector,
@@ -158,6 +181,9 @@
             save_best_fn=self._create_save_best_fn(envs, logger.log_path),
             logger=logger.logger,
             test_in_train=False,
+            train_fn=train_fn,
+            test_fn=test_fn,
+            stop_fn=stop_fn,
         )
 
 
@@ -171,6 +197,21 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
         logger: Logger,
     ) -> OffpolicyTrainer:
         sampling_config = self.sampling_config
+        callbacks = self.trainer_callbacks
+        context = TrainingContext(policy, envs, logger)
+        train_fn = (
+            callbacks.epoch_callback_train.get_trainer_fn(context)
+            if callbacks.epoch_callback_train
+            else None
+        )
+        test_fn = (
+            callbacks.epoch_callback_test.get_trainer_fn(context)
+            if callbacks.epoch_callback_test
+            else None
+        )
+        stop_fn = (
+            callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
+        )
         return OffpolicyTrainer(
             policy=policy,
             train_collector=train_collector,
@@ -184,6 +225,9 @@
             logger=logger.logger,
             update_per_step=sampling_config.update_per_step,
             test_in_train=False,
+            train_fn=train_fn,
+            test_fn=test_fn,
+            stop_fn=stop_fn,
         )
 
 
@@ -195,6 +239,23 @@ class _ActorMixin:
         return self.actor_module_opt_factory.create_module_opt(envs, device, lr)
 
+
+class _CriticMixin:
+    def __init__(
+        self,
+        critic_factory: CriticFactory,
+        optim_factory: OptimizerFactory,
+        critic_use_action: bool,
+    ):
+        self.critic_module_opt_factory = CriticModuleOptFactory(
+            critic_factory,
+            optim_factory,
+            critic_use_action,
+        )
+
+    def create_critic_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
+        return self.critic_module_opt_factory.create_module_opt(envs, device, lr)
+
+
 class _ActorCriticMixin:
     """Mixin for agents that use an ActorCritic module with a single optimizer."""
 
@@ -241,7 +302,7 @@ class _ActorCriticMixin:
         return ActorCriticModuleOpt(actor_critic, optim)
 
 
-class _ActorAndCriticMixin(_ActorMixin):
+class _ActorAndCriticMixin(_ActorMixin, _CriticMixin):
     def __init__(
         self,
         actor_factory: ActorFactory,
@@ -249,15 +310,8 @@ class _ActorAndCriticMixin(_ActorMixin):
         optim_factory: OptimizerFactory,
         critic_use_action: bool,
     ):
-        super().__init__(actor_factory, optim_factory)
-        self.critic_module_opt_factory = CriticModuleOptFactory(
-            critic_factory,
-            optim_factory,
-            critic_use_action,
-        )
-
-    def create_critic_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
-        return self.critic_module_opt_factory.create_module_opt(envs, device, lr)
+        _ActorMixin.__init__(self, actor_factory, optim_factory)
+        _CriticMixin.__init__(self, critic_factory, optim_factory, critic_use_action)
 
 
 class _ActorAndDualCriticsMixin(_ActorAndCriticMixin):
@@ -385,6 +439,42 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
         return self.create_actor_critic_module_opt(envs, device, self.params.lr)
 
+
+class DQNAgentFactory(OffpolicyAgentFactory):
+    def __init__(
+        self,
+        params: DQNParams,
+        sampling_config: RLSamplingConfig,
+        critic_factory: CriticFactory,
+        optim_factory: OptimizerFactory,
+    ):
+        super().__init__(sampling_config, optim_factory)
+        self.params = params
+        self.critic_factory = critic_factory
+        self.optim_factory = optim_factory
+
+    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
+        critic = self.critic_factory.create_module(envs, device, use_action=True)
+        optim = self.optim_factory.create_optimizer(critic, self.params.lr)
+        kwargs = self.params.create_kwargs(
+            ParamTransformerData(
+                envs=envs,
+                device=device,
+                optim=optim,
+                optim_factory=self.optim_factory,
+            ),
+        )
+        envs.get_type().assert_discrete(self)
+        # noinspection PyTypeChecker
+        action_space: gymnasium.spaces.Discrete = envs.get_action_space()
+        return DQNPolicy(
+            model=critic,
+            optim=optim,
+            action_space=action_space,
+            observation_space=envs.get_observation_space(),
+            **kwargs,
+        )
+
+
 class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
     def __init__(
         self,
diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py
index 500dfe8..6937fc2 100644
--- a/tianshou/highlevel/config.py
+++ b/tianshou/highlevel/config.py
@@ -14,10 +14,14 @@ class RLSamplingConfig:
     buffer_size: int = 4096
     step_per_collect: int = 2048
     repeat_per_collect: int | None = 10
-    update_per_step: int = 1
+    update_per_step: float = 1.0
+    """
+    Only used in off-policy algorithms.
+    How many gradient steps to perform per step in the environment (i.e., per sample added to the buffer).
+ """ start_timesteps: int = 0 start_timesteps_random: bool = False - # TODO can we set the parameters below more intelligently? Perhaps based on env. representation? + # TODO can we set the parameters below intelligently? Perhaps based on env. representation? replay_buffer_ignore_obs_next: bool = False replay_buffer_save_only_last_obs: bool = False replay_buffer_stack_num: int = 1 diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 3d86203..b93a758 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -13,6 +13,7 @@ from tianshou.highlevel.agent import ( A2CAgentFactory, AgentFactory, DDPGAgentFactory, + DQNAgentFactory, PPOAgentFactory, SACAgentFactory, TD3AgentFactory, @@ -30,12 +31,18 @@ from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam from tianshou.highlevel.params.policy_params import ( A2CParams, DDPGParams, + DQNParams, PPOParams, SACParams, TD3Params, ) from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory from tianshou.highlevel.persistence import PersistableConfigProtocol +from tianshou.highlevel.trainer import ( + TrainerCallbacks, + TrainerEpochCallback, + TrainerStopCallback, +) from tianshou.policy import BasePolicy from tianshou.trainer import BaseTrainer from tianshou.utils.string import ToStringMixin @@ -160,6 +167,7 @@ class RLExperimentBuilder: self._optim_factory: OptimizerFactory | None = None self._env_config: PersistableConfigProtocol | None = None self._policy_wrapper_factory: PolicyWrapperFactory | None = None + self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks() def with_env_config(self, config: PersistableConfigProtocol) -> Self: self._env_config = config @@ -193,6 +201,18 @@ class RLExperimentBuilder: self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay) return self + def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallback) -> Self: + self._trainer_callbacks.epoch_callback_train = callback + return self + + def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallback) -> Self: + self._trainer_callbacks.epoch_callback_test = callback + return self + + def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self: + self._trainer_callbacks.stop_callback = callback + return self + @abstractmethod def _create_agent_factory(self) -> AgentFactory: pass @@ -205,6 +225,7 @@ class RLExperimentBuilder: def build(self) -> RLExperiment: agent_factory = self._create_agent_factory() + agent_factory.set_trainer_callbacks(self._trainer_callbacks) if self._policy_wrapper_factory: agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory) experiment = RLExperiment( @@ -302,21 +323,24 @@ class _BuilderMixinCriticsFactory: class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory): def __init__(self): super().__init__(1) - self._critic_use_actor_module = False - def with_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder: - self: TBuilder | "_BuilderMixinSingleCriticFactory" + def with_critic_factory(self, critic_factory: CriticFactory) -> Self: self._with_critic_factory(0, critic_factory) return self def with_critic_factory_default( - self: TBuilder, + self, hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES, - ) -> TBuilder: - self: TBuilder | "_BuilderMixinSingleCriticFactory" + ) -> Self: self._with_critic_factory_default(0, hidden_sizes) return self + +class 
+class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory):
+    def __init__(self):
+        super().__init__()
+        self._critic_use_actor_module = False
+
     def with_critic_factory_use_actor(self) -> Self:
         """Makes the critic use the same network as the actor."""
         self._critic_use_actor_module = True
@@ -372,7 +396,7 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
 class A2CExperimentBuilder(
     RLExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,
-    _BuilderMixinSingleCriticFactory,
+    _BuilderMixinSingleCriticCanUseActorFactory,
 ):
     def __init__(
         self,
@@ -383,7 +407,7 @@ class A2CExperimentBuilder(
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
-        _BuilderMixinSingleCriticFactory.__init__(self)
+        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
         self._params: A2CParams = A2CParams()
         self._env_config = env_config
 
@@ -406,7 +430,7 @@
 class PPOExperimentBuilder(
     RLExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,
-    _BuilderMixinSingleCriticFactory,
+    _BuilderMixinSingleCriticCanUseActorFactory,
 ):
     def __init__(
         self,
@@ -416,7 +440,7 @@ class PPOExperimentBuilder(
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
-        _BuilderMixinSingleCriticFactory.__init__(self)
+        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
         self._params: PPOParams = PPOParams()
 
     def with_ppo_params(self, params: PPOParams) -> Self:
@@ -435,9 +459,8 @@ class PPOExperimentBuilder(
         )
 
 
-class DDPGExperimentBuilder(
+class DQNExperimentBuilder(
     RLExperimentBuilder,
-    _BuilderMixinActorFactory_ContinuousDeterministic,
     _BuilderMixinSingleCriticFactory,
 ):
     def __init__(
@@ -445,13 +468,40 @@ class DDPGExperimentBuilder(
         experiment_config: RLExperimentConfig,
         env_factory: EnvFactory,
         sampling_config: RLSamplingConfig,
-        env_config: PersistableConfigProtocol | None = None,
+    ):
+        super().__init__(experiment_config, env_factory, sampling_config)
+        _BuilderMixinSingleCriticFactory.__init__(self)
+        self._params: DQNParams = DQNParams()
+
+    def with_dqn_params(self, params: DQNParams) -> Self:
+        self._params = params
+        return self
+
+    def _create_agent_factory(self) -> AgentFactory:
+        return DQNAgentFactory(
+            self._params,
+            self._sampling_config,
+            self._get_critic_factory(0),
+            self._get_optim_factory(),
+        )
+
+
+class DDPGExperimentBuilder(
+    RLExperimentBuilder,
+    _BuilderMixinActorFactory_ContinuousDeterministic,
+    _BuilderMixinSingleCriticCanUseActorFactory,
+):
+    def __init__(
+        self,
+        experiment_config: RLExperimentConfig,
+        env_factory: EnvFactory,
+        sampling_config: RLSamplingConfig,
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
-        _BuilderMixinSingleCriticFactory.__init__(self)
+        _BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
         self._params: DDPGParams = DDPGParams()
-        self._env_config = env_config
 
     def with_ddpg_params(self, params: DDPGParams) -> Self:
         self._params = params
         return self
diff --git a/tianshou/highlevel/module/core.py b/tianshou/highlevel/module/core.py
index f0a2b7b..45c8836 100644
--- a/tianshou/highlevel/module/core.py
+++ b/tianshou/highlevel/module/core.py
@@ -36,7 +36,9 @@ class ModuleFactory(ToStringMixin, ABC):
         pass
 
 
-class ModuleFactoryNet(ModuleFactory):
+class ModuleFactoryNet(
+    ModuleFactory,
+):  # TODO This is unused and broken; use it in ActorFactory* and so on?
     def __init__(self, hidden_sizes: int | Sequence[int]):
         self.hidden_sizes = hidden_sizes
 
diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py
index e2e7257..234f91c 100644
--- a/tianshou/highlevel/params/policy_params.py
+++ b/tianshou/highlevel/params/policy_params.py
@@ -356,6 +356,21 @@ class SACParams(Params, ParamsMixinActorAndDualCritics):
         return transformers
 
+
+@dataclass
+class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
+    discount_factor: float = 0.99
+    estimation_step: int = 1
+    target_update_freq: int = 0
+    reward_normalization: bool = False
+    is_double: bool = True
+    clip_loss_grad: bool = False
+
+    def _get_param_transformers(self):
+        transformers = super()._get_param_transformers()
+        transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self))
+        return transformers
+
+
 @dataclass
 class DDPGParams(Params, ParamsMixinActorAndCritic):
     tau: float = 0.005
diff --git a/tianshou/highlevel/trainer.py b/tianshou/highlevel/trainer.py
new file mode 100644
index 0000000..96a8663
--- /dev/null
+++ b/tianshou/highlevel/trainer.py
@@ -0,0 +1,55 @@
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import TypeVar
+
+from tianshou.highlevel.env import Environments
+from tianshou.highlevel.logger import Logger
+from tianshou.policy import BasePolicy
+from tianshou.utils.string import ToStringMixin
+
+TPolicy = TypeVar("TPolicy", bound=BasePolicy)
+
+
+class TrainingContext:
+    def __init__(self, policy: TPolicy, envs: Environments, logger: Logger):
+        self.policy = policy
+        self.envs = envs
+        self.logger = logger
+
+
+class TrainerEpochCallback(ToStringMixin, ABC):
+    """Callback which is called at the beginning of each epoch."""
+
+    @abstractmethod
+    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
+        pass
+
+    def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int], None]:
+        def fn(epoch, env_step):
+            return self.callback(epoch, env_step, context)
+
+        return fn
+
+
+class TrainerStopCallback(ToStringMixin, ABC):
+    """Callback indicating whether training should stop."""
+
+    @abstractmethod
+    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
+        """:param mean_rewards: the average undiscounted returns of the testing result
+        :return: True if the goal has been reached and training should stop, False otherwise
+        """
+
+    def get_trainer_fn(self, context: TrainingContext) -> Callable[[float], bool]:
+        def fn(mean_rewards: float):
+            return self.should_stop(mean_rewards, context)
+
+        return fn
+
+
+@dataclass
+class TrainerCallbacks:
+    epoch_callback_train: TrainerEpochCallback | None = None
+    epoch_callback_test: TrainerEpochCallback | None = None
+    stop_callback: TrainerStopCallback | None = None