diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 8aed0fa..63d5a51 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -19,7 +19,11 @@ from tianshou.highlevel.params.policy_params import DQNParams from tianshou.highlevel.params.policy_wrapper import ( PolicyWrapperFactoryIntrinsicCuriosity, ) -from tianshou.highlevel.trainer import TrainerEpochCallback, TrainingContext +from tianshou.highlevel.trainer import ( + TrainerEpochCallbackTest, + TrainerEpochCallbackTrain, + TrainingContext, +) from tianshou.policy import DQNPolicy from tianshou.utils import logging @@ -75,7 +79,7 @@ def main( scale=scale_obs, ) - class TrainEpochCallback(TrainerEpochCallback): + class TrainEpochCallback(TrainerEpochCallbackTrain): def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: policy: DQNPolicy = context.policy logger = context.logger.logger @@ -88,7 +92,7 @@ def main( if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) - class TestEpochCallback(TrainerEpochCallback): + class TestEpochCallback(TrainerEpochCallbackTest): def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: policy: DQNPolicy = context.policy policy.set_eps(eps_test) diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index bd1ca70..25d9a68 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -13,7 +13,9 @@ from tianshou.highlevel.experiment import ( ExperimentConfig, PPOExperimentBuilder, ) -from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryIndependentGaussians +from tianshou.highlevel.params.dist_fn import ( + DistributionFunctionFactoryIndependentGaussians, +) from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.policy_params import PPOParams from tianshou.utils import logging diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index 6fa308c..2c2206f 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable -from typing import Generic, TypeVar +from os import PathLike +from typing import Any, Generic, TypeVar, cast import gymnasium import torch @@ -60,9 +61,14 @@ class AgentFactory(ABC, ToStringMixin): self.policy_wrapper_factory: PolicyWrapperFactory | None = None self.trainer_callbacks: TrainerCallbacks = TrainerCallbacks() - def create_train_test_collector(self, policy: BasePolicy, envs: Environments): + def create_train_test_collector( + self, + policy: BasePolicy, + envs: Environments, + ) -> tuple[Collector, Collector]: buffer_size = self.sampling_config.buffer_size train_envs = envs.train_envs + buffer: ReplayBuffer if len(train_envs) > 1: buffer = VectorReplayBuffer( buffer_size, @@ -90,7 +96,7 @@ class AgentFactory(ABC, ToStringMixin): ) -> None: self.policy_wrapper_factory = policy_wrapper_factory - def set_trainer_callbacks(self, callbacks: TrainerCallbacks): + def set_trainer_callbacks(self, callbacks: TrainerCallbacks) -> None: self.trainer_callbacks = callbacks @abstractmethod @@ -122,7 +128,12 @@ class AgentFactory(ABC, ToStringMixin): return save_best_fn @staticmethod - def load_checkpoint(policy: torch.nn.Module, path, envs: Environments, device: TDevice): + def load_checkpoint( + policy: torch.nn.Module, + path: str | PathLike, + envs: Environments, + device: TDevice, + ) -> None: ckpt = torch.load(path, 
map_location=device) policy.load_state_dict(ckpt[CHECKPOINT_DICT_KEY_MODEL]) if envs.train_envs: @@ -280,6 +291,7 @@ class _ActorCriticMixin: lr: float, ) -> ActorCriticModuleOpt: actor = self.actor_factory.create_module(envs, device) + critic: torch.nn.Module if self.critic_use_actor_module: if self.critic_use_action: raise ValueError( @@ -375,7 +387,7 @@ class ActorCriticAgentFactory( def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt: pass - def _create_kwargs(self, envs: Environments, device: TDevice): + def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]: actor_critic = self._create_actor_critic(envs, device) kwargs = self.params.create_kwargs( ParamTransformerData( @@ -468,8 +480,7 @@ class DQNAgentFactory(OffpolicyAgentFactory): ), ) envs.get_type().assert_discrete(self) - # noinspection PyTypeChecker - action_space: gymnasium.spaces.Discrete = envs.get_action_space() + action_space = cast(gymnasium.spaces.Discrete, envs.get_action_space()) return DQNPolicy( model=model, optim=optim, diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index 5b28bfc..c48fa90 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -1,37 +1,38 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from enum import Enum -from typing import Any +from typing import Any, TypeAlias import gymnasium as gym from tianshou.env import BaseVectorEnv from tianshou.highlevel.persistence import PersistableConfigProtocol +from tianshou.utils.net.common import TActionShape -TShape = int | Sequence[int] +TObservationShape: TypeAlias = int | Sequence[int] class EnvType(Enum): CONTINUOUS = "continuous" DISCRETE = "discrete" - def is_discrete(self): + def is_discrete(self) -> bool: return self == EnvType.DISCRETE - def is_continuous(self): + def is_continuous(self) -> bool: return self == EnvType.CONTINUOUS - def assert_continuous(self, requiring_entity: Any): + def assert_continuous(self, requiring_entity: Any) -> None: if not self.is_continuous(): raise AssertionError(f"{requiring_entity} requires continuous environments") - def assert_discrete(self, requiring_entity: Any): + def assert_discrete(self, requiring_entity: Any) -> None: if not self.is_discrete(): raise AssertionError(f"{requiring_entity} requires discrete environments") class Environments(ABC): - def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): + def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): self.env = env self.train_envs = train_envs self.test_envs = test_envs @@ -43,11 +44,11 @@ class Environments(ABC): } @abstractmethod - def get_action_shape(self) -> TShape: + def get_action_shape(self) -> TActionShape: pass @abstractmethod - def get_observation_shape(self) -> TShape: + def get_observation_shape(self) -> TObservationShape: pass def get_action_space(self) -> gym.Space: @@ -62,11 +63,11 @@ class Environments(ABC): class ContinuousEnvironments(Environments): - def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): + def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): super().__init__(env, train_envs, test_envs) self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env) - def info(self): + def info(self) -> dict[str, Any]: d = super().info() d["max_action"] = self.max_action return d @@ -80,17 +81,17 @@ class ContinuousEnvironments(Environments): 
"Only environments with continuous action space are supported here. " f"But got env with action space: {env.action_space.__class__}.", ) - state_shape = env.observation_space.shape or env.observation_space.n + state_shape = env.observation_space.shape or env.observation_space.n # type: ignore if not state_shape: raise ValueError("Observation space shape is not defined") action_shape = env.action_space.shape max_action = env.action_space.high[0] return state_shape, action_shape, max_action - def get_action_shape(self) -> TShape: + def get_action_shape(self) -> TActionShape: return self.action_shape - def get_observation_shape(self) -> TShape: + def get_observation_shape(self) -> TObservationShape: return self.state_shape def get_type(self) -> EnvType: @@ -98,15 +99,15 @@ class ContinuousEnvironments(Environments): class DiscreteEnvironments(Environments): - def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): + def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): super().__init__(env, train_envs, test_envs) - self.observation_shape = env.observation_space.shape or env.observation_space.n - self.action_shape = env.action_space.shape or env.action_space.n + self.observation_shape = env.observation_space.shape or env.observation_space.n # type: ignore + self.action_shape = env.action_space.shape or env.action_space.n # type: ignore - def get_action_shape(self) -> TShape: + def get_action_shape(self) -> TActionShape: return self.action_shape - def get_observation_shape(self) -> TShape: + def get_observation_shape(self) -> TObservationShape: return self.observation_shape def get_type(self) -> EnvType: diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 7f64c53..b303430 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -40,7 +40,8 @@ from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory from tianshou.highlevel.persistence import PersistableConfigProtocol from tianshou.highlevel.trainer import ( TrainerCallbacks, - TrainerEpochCallback, + TrainerEpochCallbackTest, + TrainerEpochCallbackTrain, TrainerStopCallback, ) from tianshou.policy import BasePolicy @@ -135,24 +136,29 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin): result = trainer.run() pprint(result) # TODO logging + render = self.config.render + if render is None: + render = 0.0 # TODO: Perhaps we should have a second render parameter for watch mode? 
self._watch_agent( self.config.watch_num_episodes, policy, test_collector, - self.config.render, + render, ) @staticmethod - def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render) -> None: + def _watch_agent( + num_episodes: int, + policy: BasePolicy, + test_collector: Collector, + render: float, + ) -> None: policy.eval() test_collector.reset() result = test_collector.collect(n_episode=num_episodes, render=render) print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') -TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder") - - class ExperimentBuilder: def __init__( self, @@ -177,7 +183,7 @@ class ExperimentBuilder: self._env_config = config return self - def with_logger_factory(self: TBuilder, logger_factory: LoggerFactory) -> TBuilder: + def with_logger_factory(self, logger_factory: LoggerFactory) -> Self: self._logger_factory = logger_factory return self @@ -185,16 +191,16 @@ class ExperimentBuilder: self._policy_wrapper_factory = policy_wrapper_factory return self - def with_optim_factory(self: TBuilder, optim_factory: OptimizerFactory) -> TBuilder: + def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self: self._optim_factory = optim_factory return self def with_optim_factory_default( - self: TBuilder, - betas=(0.9, 0.999), - eps=1e-08, - weight_decay=0, - ) -> TBuilder: + self, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-08, + weight_decay: float = 0, + ) -> Self: """Configures the use of the default optimizer, Adam, with the given parameters. :param betas: coefficients used for computing running averages of gradient and its square @@ -205,11 +211,11 @@ class ExperimentBuilder: self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay) return self - def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallback) -> Self: + def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallbackTrain) -> Self: self._trainer_callbacks.epoch_callback_train = callback return self - def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallback) -> Self: + def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallbackTest) -> Self: self._trainer_callbacks.epoch_callback_test = callback return self @@ -232,7 +238,7 @@ class ExperimentBuilder: agent_factory.set_trainer_callbacks(self._trainer_callbacks) if self._policy_wrapper_factory: agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory) - experiment = Experiment( + experiment: Experiment = Experiment( self._config, self._env_factory, agent_factory, @@ -248,18 +254,16 @@ class _BuilderMixinActorFactory: self._continuous_actor_type = continuous_actor_type self._actor_factory: ActorFactory | None = None - def with_actor_factory(self: TBuilder, actor_factory: ActorFactory) -> TBuilder: - self: TBuilder | _BuilderMixinActorFactory + def with_actor_factory(self, actor_factory: ActorFactory) -> Self: self._actor_factory = actor_factory return self def _with_actor_factory_default( - self: TBuilder, + self, hidden_sizes: Sequence[int], - continuous_unbounded=False, - continuous_conditioned_sigma=False, - ) -> TBuilder: - self: TBuilder | _BuilderMixinActorFactory + continuous_unbounded: bool = False, + continuous_conditioned_sigma: bool = False, + ) -> Self: self._actor_factory = ActorFactoryDefault( self._continuous_actor_type, hidden_sizes, @@ -268,7 +272,7 @@ class _BuilderMixinActorFactory: ) return self - def _get_actor_factory(self): + def _get_actor_factory(self) 
-> ActorFactory: if self._actor_factory is None: return ActorFactoryDefault(self._continuous_actor_type) else: @@ -278,14 +282,14 @@ class _BuilderMixinActorFactory: class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory): """Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy.""" - def __init__(self): + def __init__(self) -> None: super().__init__(ContinuousActorType.GAUSSIAN) def with_actor_factory_default( self, hidden_sizes: Sequence[int], - continuous_unbounded=False, - continuous_conditioned_sigma=False, + continuous_unbounded: bool = False, + continuous_conditioned_sigma: bool = False, ) -> Self: return super()._with_actor_factory_default( hidden_sizes, @@ -297,7 +301,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory): class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactory): """Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy.""" - def __init__(self): + def __init__(self) -> None: super().__init__(ContinuousActorType.DETERMINISTIC) def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self: @@ -308,15 +312,15 @@ class _BuilderMixinCriticsFactory: def __init__(self, num_critics: int): self._critic_factories: list[CriticFactory | None] = [None] * num_critics - def _with_critic_factory(self, idx: int, critic_factory: CriticFactory): + def _with_critic_factory(self, idx: int, critic_factory: CriticFactory) -> Self: self._critic_factories[idx] = critic_factory return self - def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]): + def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]) -> Self: self._critic_factories[idx] = CriticFactoryDefault(hidden_sizes) return self - def _get_critic_factory(self, idx: int): + def _get_critic_factory(self, idx: int) -> CriticFactory: factory = self._critic_factories[idx] if factory is None: return CriticFactoryDefault() @@ -325,7 +329,7 @@ class _BuilderMixinCriticsFactory: class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory): - def __init__(self): + def __init__(self) -> None: super().__init__(1) def with_critic_factory(self, critic_factory: CriticFactory) -> Self: @@ -341,7 +345,7 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory): class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory): - def __init__(self): + def __init__(self) -> None: super().__init__() self._critic_use_actor_module = False @@ -352,11 +356,10 @@ class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFacto class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory): - def __init__(self): + def __init__(self) -> None: super().__init__(2) - def with_common_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder: - self: TBuilder | "_BuilderMixinDualCriticFactory" + def with_common_critic_factory(self, critic_factory: CriticFactory) -> Self: for i in range(len(self._critic_factories)): self._with_critic_factory(i, critic_factory) return self @@ -364,35 +367,30 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory): def with_common_critic_factory_default( self, hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES, - ) -> TBuilder: - self: TBuilder | "_BuilderMixinDualCriticFactory" + ) -> Self: for i in range(len(self._critic_factories)): self._with_critic_factory_default(i, hidden_sizes) return self - 
def with_critic1_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder: - self: TBuilder | "_BuilderMixinDualCriticFactory" + def with_critic1_factory(self, critic_factory: CriticFactory) -> Self: self._with_critic_factory(0, critic_factory) return self def with_critic1_factory_default( self, hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES, - ) -> TBuilder: - self: TBuilder | "_BuilderMixinDualCriticFactory" + ) -> Self: self._with_critic_factory_default(0, hidden_sizes) return self - def with_critic2_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder: - self: TBuilder | "_BuilderMixinDualCriticFactory" + def with_critic2_factory(self, critic_factory: CriticFactory) -> Self: self._with_critic_factory(1, critic_factory) return self def with_critic2_factory_default( self, hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES, - ) -> TBuilder: - self: TBuilder | "_BuilderMixinDualCriticFactory" + ) -> Self: self._with_critic_factory_default(0, hidden_sizes) return self diff --git a/tianshou/highlevel/logger.py b/tianshou/highlevel/logger.py index 8de2033..c913a9b 100644 --- a/tianshou/highlevel/logger.py +++ b/tianshou/highlevel/logger.py @@ -19,7 +19,7 @@ class Logger: class LoggerFactory(ToStringMixin, ABC): @abstractmethod - def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger: + def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger: pass @@ -48,6 +48,7 @@ class DefaultLoggerFactory(LoggerFactory): ), ), ) + logger: TLogger if self.logger_type == "wandb": logger = WandbLogger( save_interval=1, diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index 2c4fa5c..967ebf2 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from collections.abc import Sequence +from enum import Enum import torch from torch import nn @@ -11,7 +12,7 @@ from tianshou.utils.net.common import BaseActor, Net from tianshou.utils.string import ToStringMixin -class ContinuousActorType: +class ContinuousActorType(Enum): GAUSSIAN = "gaussian" DETERMINISTIC = "deterministic" UNSUPPORTED = "unsupported" @@ -23,7 +24,7 @@ class ActorFactory(ToStringMixin, ABC): pass @staticmethod - def _init_linear(actor: torch.nn.Module): + def _init_linear(actor: torch.nn.Module) -> None: """Initializes linear layers of an actor module using default mechanisms. :param module: the actor module. 
@@ -34,7 +35,7 @@ class ActorFactory(ToStringMixin, ABC): # do last policy layer scaling, this will make initial actions have (close to) # 0 mean and std, and will help boost performances, # see https://arxiv.org/abs/2006.05990, Fig.24 for details - for m in actor.mu.modules(): + for m in actor.mu.modules(): # type: ignore if isinstance(m, torch.nn.Linear): m.weight.data.copy_(0.01 * m.weight.data) @@ -48,8 +49,8 @@ class ActorFactoryDefault(ActorFactory): self, continuous_actor_type: ContinuousActorType, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES, - continuous_unbounded=False, - continuous_conditioned_sigma=False, + continuous_unbounded: bool = False, + continuous_conditioned_sigma: bool = False, ): self.continuous_actor_type = continuous_actor_type self.continuous_unbounded = continuous_unbounded @@ -58,6 +59,7 @@ class ActorFactoryDefault(ActorFactory): def create_module(self, envs: Environments, device: TDevice) -> BaseActor: env_type = envs.get_type() + factory: ActorFactoryContinuousDeterministicNet | ActorFactoryContinuousGaussianNet | ActorFactoryDiscreteNet if env_type == EnvType.CONTINUOUS: match self.continuous_actor_type: case ContinuousActorType.GAUSSIAN: @@ -103,7 +105,12 @@ class ActorFactoryContinuousDeterministicNet(ActorFactoryContinuous): class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous): - def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False): + def __init__( + self, + hidden_sizes: Sequence[int], + unbounded: bool = True, + conditioned_sigma: bool = False, + ): self.hidden_sizes = hidden_sizes self.unbounded = unbounded self.conditioned_sigma = conditioned_sigma diff --git a/tianshou/highlevel/module/core.py b/tianshou/highlevel/module/core.py index 45c8836..80ca012 100644 --- a/tianshou/highlevel/module/core.py +++ b/tianshou/highlevel/module/core.py @@ -10,10 +10,10 @@ from tianshou.highlevel.env import Environments from tianshou.utils.net.common import Net from tianshou.utils.string import ToStringMixin -TDevice: TypeAlias = str | int | torch.device +TDevice: TypeAlias = str | torch.device -def init_linear_orthogonal(module: torch.nn.Module): +def init_linear_orthogonal(module: torch.nn.Module) -> None: """Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0. 
:param module: the module whose submodules are to be processed diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py index 83cb797..ad68415 100644 --- a/tianshou/highlevel/module/critic.py +++ b/tianshou/highlevel/module/critic.py @@ -27,11 +27,17 @@ class CriticFactoryDefault(CriticFactory): def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: env_type = envs.get_type() if env_type == EnvType.CONTINUOUS: - factory = CriticFactoryContinuousNet(self.hidden_sizes) - return factory.create_module(envs, device, use_action) + return CriticFactoryContinuousNet(self.hidden_sizes).create_module( + envs, + device, + use_action, + ) elif env_type == EnvType.DISCRETE: - factory = CriticFactoryDiscreteNet(self.hidden_sizes) - return factory.create_module(envs, device, use_action) + return CriticFactoryDiscreteNet(self.hidden_sizes).create_module( + envs, + device, + use_action, + ) else: raise ValueError(f"{env_type} not supported") diff --git a/tianshou/highlevel/module/module_opt.py b/tianshou/highlevel/module/module_opt.py index 802d305..1300a59 100644 --- a/tianshou/highlevel/module/module_opt.py +++ b/tianshou/highlevel/module/module_opt.py @@ -23,11 +23,11 @@ class ActorCriticModuleOpt: optim: torch.optim.Optimizer @property - def actor(self): + def actor(self) -> torch.nn.Module: return self.actor_critic_module.actor @property - def critic(self): + def critic(self) -> torch.nn.Module: return self.actor_critic_module.critic diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py index de5c434..008321f 100644 --- a/tianshou/highlevel/optim.py +++ b/tianshou/highlevel/optim.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Protocol import torch from torch.optim import Adam, RMSprop @@ -7,6 +7,11 @@ from torch.optim import Adam, RMSprop from tianshou.utils.string import ToStringMixin +class OptimizerWithLearningRateProtocol(Protocol): + def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer: + pass + + class OptimizerFactory(ABC, ToStringMixin): # TODO: Is it OK to assume that all optimizers have a learning rate argument? # Right now, the learning rate is typically a configuration parameter. @@ -18,7 +23,7 @@ class OptimizerFactory(ABC, ToStringMixin): class OptimizerFactoryTorch(OptimizerFactory): - def __init__(self, optim_class: Any, **kwargs): + def __init__(self, optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any): """:param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`), which will be passed the module parameters, the learning rate as `lr` and the kwargs provided. 
@@ -32,7 +37,12 @@ class OptimizerFactoryTorch(OptimizerFactory): class OptimizerFactoryAdam(OptimizerFactory): - def __init__(self, betas=(0.9, 0.999), eps=1e-08, weight_decay=0): + def __init__( + self, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-08, + weight_decay: float = 0, + ): self.weight_decay = weight_decay self.eps = eps self.betas = betas @@ -48,7 +58,14 @@ class OptimizerFactoryAdam(OptimizerFactory): class OptimizerFactoryRMSprop(OptimizerFactory): - def __init__(self, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False): + def __init__( + self, + alpha: float = 0.99, + eps: float = 1e-08, + weight_decay: float = 0, + momentum: float = 0, + centered: bool = False, + ): self.alpha = alpha self.momentum = momentum self.centered = centered diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py index e3d1888..878ae4b 100644 --- a/tianshou/highlevel/params/alpha.py +++ b/tianshou/highlevel/params/alpha.py @@ -30,7 +30,7 @@ class AutoAlphaFactoryDefault(AutoAlphaFactory): # TODO better name? optim_factory: OptimizerFactory, device: TDevice, ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]: - target_entropy = -np.prod(envs.get_action_shape()) + target_entropy = float(-np.prod(envs.get_action_shape())) log_alpha = torch.zeros(1, requires_grad=True, device=device) alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr) return target_entropy, log_alpha, alpha_optim diff --git a/tianshou/highlevel/params/dist_fn.py b/tianshou/highlevel/params/dist_fn.py index 6e3d9d1..d21acca 100644 --- a/tianshou/highlevel/params/dist_fn.py +++ b/tianshou/highlevel/params/dist_fn.py @@ -19,21 +19,21 @@ class DistributionFunctionFactory(ToStringMixin, ABC): class DistributionFunctionFactoryCategorical(DistributionFunctionFactory): def create_dist_fn(self, envs: Environments) -> TDistributionFunction: - assert envs.get_type().assert_discrete(self) + envs.get_type().assert_discrete(self) return self._dist_fn @staticmethod - def _dist_fn(p): + def _dist_fn(p: TDistParams) -> torch.distributions.Distribution: return torch.distributions.Categorical(logits=p) class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory): def create_dist_fn(self, envs: Environments) -> TDistributionFunction: - assert envs.get_type().assert_continuous(self) + envs.get_type().assert_continuous(self) return self._dist_fn @staticmethod - def _dist_fn(*p): + def _dist_fn(*p: torch.Tensor) -> torch.distributions.Distribution: return torch.distributions.Independent(torch.distributions.Normal(*p), 1) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 234f91c..c2027c7 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -3,6 +3,7 @@ from dataclasses import asdict, dataclass from typing import Any, Literal, Protocol import torch +from torch.optim.lr_scheduler import LRScheduler from tianshou.exploration import BaseNoise from tianshou.highlevel.env import Environments @@ -71,7 +72,7 @@ class ParamTransformerChangeValue(ParamTransformer): def __init__(self, key: str): self.key = key - def transform(self, params: dict[str, Any], data: ParamTransformerData): + def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: params[self.key] = self.change_value(params[self.key], data) @abstractmethod @@ -118,6 +119,7 @@ class ParamTransformerMultiLRScheduler(ParamTransformer): ) if lr_scheduler_factory is not None: 
lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim)) + lr_scheduler: LRScheduler | MultipleLRSchedulers | None match len(lr_schedulers): case 0: lr_scheduler = None @@ -140,6 +142,7 @@ class ParamTransformerActorAndCriticLRScheduler(ParamTransformer): self.key_scheduler = key_scheduler def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: + assert data.actor is not None and data.critic1 is not None transformer = ParamTransformerMultiLRScheduler( [ (data.actor.optim, self.key_factory_actor), @@ -164,6 +167,7 @@ class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer): self.key_scheduler = key_scheduler def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: + assert data.actor is not None and data.critic1 is not None and data.critic2 is not None transformer = ParamTransformerMultiLRScheduler( [ (data.actor.optim, self.key_factory_actor), @@ -247,7 +251,7 @@ class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol): lr: float = 1e-3 lr_scheduler_factory: LRSchedulerFactory | None = None - def _get_param_transformers(self): + def _get_param_transformers(self) -> list[ParamTransformer]: return [ ParamTransformerDrop("lr"), ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"), @@ -261,7 +265,7 @@ class ParamsMixinActorAndCritic(GetParamTransformersProtocol): actor_lr_scheduler_factory: LRSchedulerFactory | None = None critic_lr_scheduler_factory: LRSchedulerFactory | None = None - def _get_param_transformers(self): + def _get_param_transformers(self) -> list[ParamTransformer]: return [ ParamTransformerDrop("actor_lr", "critic_lr"), ParamTransformerActorAndCriticLRScheduler( @@ -325,7 +329,7 @@ class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol): critic1_lr_scheduler_factory: LRSchedulerFactory | None = None critic2_lr_scheduler_factory: LRSchedulerFactory | None = None - def _get_param_transformers(self): + def _get_param_transformers(self) -> list[ParamTransformer]: return [ ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"), ParamTransformerActorDualCriticsLRScheduler( @@ -348,7 +352,7 @@ class SACParams(Params, ParamsMixinActorAndDualCritics): action_scaling: bool = True action_bound_method: Literal["clip"] | None = "clip" - def _get_param_transformers(self): + def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self)) transformers.append(ParamTransformerAutoAlpha("alpha")) @@ -365,7 +369,7 @@ class DQNParams(Params, ParamsMixinLearningRateWithScheduler): is_double: bool = True clip_loss_grad: bool = False - def _get_param_transformers(self): + def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self)) return transformers @@ -380,7 +384,7 @@ class DDPGParams(Params, ParamsMixinActorAndCritic): action_scaling: bool = True action_bound_method: Literal["clip"] | None = "clip" - def _get_param_transformers(self): + def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinActorAndCritic._get_param_transformers(self)) transformers.append(ParamTransformerNoiseFactory("exploration_noise")) @@ -399,7 +403,7 @@ class TD3Params(Params, ParamsMixinActorAndDualCritics): action_scaling: bool = True action_bound_method: 
Literal["clip"] | None = "clip" - def _get_param_transformers(self): + def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self)) transformers.append(ParamTransformerNoiseFactory("exploration_noise")) diff --git a/tianshou/highlevel/params/policy_wrapper.py b/tianshou/highlevel/params/policy_wrapper.py index 9c9c87a..09e5065 100644 --- a/tianshou/highlevel/params/policy_wrapper.py +++ b/tianshou/highlevel/params/policy_wrapper.py @@ -36,7 +36,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity( lr: float, lr_scale: float, reward_scale: float, - forward_loss_weight, + forward_loss_weight: float, ): self.feature_net_factory = feature_net_factory self.hidden_sizes = hidden_sizes @@ -54,6 +54,8 @@ class PolicyWrapperFactoryIntrinsicCuriosity( ) -> ICMPolicy: feature_net = self.feature_net_factory.create_module(envs, device) action_dim = envs.get_action_shape() + if type(action_dim) != int: + raise ValueError(f"Environment action shape must be an integer, got {action_dim}") feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( feature_net.module, diff --git a/tianshou/highlevel/trainer.py b/tianshou/highlevel/trainer.py index 96a8663..876752b 100644 --- a/tianshou/highlevel/trainer.py +++ b/tianshou/highlevel/trainer.py @@ -18,7 +18,7 @@ class TrainingContext: self.logger = logger -class TrainerEpochCallback(ToStringMixin, ABC): +class TrainerEpochCallbackTrain(ToStringMixin, ABC): """Callback which is called at the beginning of each epoch.""" @abstractmethod @@ -26,7 +26,21 @@ class TrainerEpochCallback(ToStringMixin, ABC): pass def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int], None]: - def fn(epoch, env_step): + def fn(epoch: int, env_step: int) -> None: + return self.callback(epoch, env_step, context) + + return fn + + +class TrainerEpochCallbackTest(ToStringMixin, ABC): + """Callback which is called at the beginning of each epoch.""" + + @abstractmethod + def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None: + pass + + def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int | None], None]: + def fn(epoch: int, env_step: int | None) -> None: return self.callback(epoch, env_step, context) return fn @@ -42,7 +56,7 @@ class TrainerStopCallback(ToStringMixin, ABC): """ def get_trainer_fn(self, context: TrainingContext) -> Callable[[float], bool]: - def fn(mean_rewards: float): + def fn(mean_rewards: float) -> bool: return self.should_stop(mean_rewards, context) return fn @@ -50,6 +64,6 @@ class TrainerStopCallback(ToStringMixin, ABC): @dataclass class TrainerCallbacks: - epoch_callback_train: TrainerEpochCallback | None = None - epoch_callback_test: TrainerEpochCallback | None = None + epoch_callback_train: TrainerEpochCallbackTrain | None = None + epoch_callback_test: TrainerEpochCallbackTest | None = None stop_callback: TrainerStopCallback | None = None diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 47a139d..cb24278 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -1,6 +1,6 @@ import warnings from collections.abc import Callable -from typing import Any, Literal, cast +from typing import Any, Literal, TypeAlias, cast import gymnasium as gym import numpy as np @@ -17,7 +17,7 @@ from tianshou.policy import BasePolicy from tianshou.policy.base import TLearningRateScheduler from tianshou.utils import 
RunningMeanStd -TDistParams = torch.Tensor | tuple[torch.Tensor] +TDistParams: TypeAlias = torch.Tensor | tuple[torch.Tensor, torch.Tensor] class PGPolicy(BasePolicy): diff --git a/tianshou/utils/lr_scheduler.py b/tianshou/utils/lr_scheduler.py index 1f7a1f2..59890c7 100644 --- a/tianshou/utils/lr_scheduler.py +++ b/tianshou/utils/lr_scheduler.py @@ -15,7 +15,7 @@ class MultipleLRSchedulers: policy = PPOPolicy(..., lr_scheduler=scheduler) """ - def __init__(self, *args: torch.optim.lr_scheduler.LambdaLR): + def __init__(self, *args: torch.optim.lr_scheduler.LRScheduler): self.schedulers = args def step(self) -> None: diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 4c263ed..fa6ec00 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Sequence -from typing import Any, no_type_check +from typing import Any, TypeAlias, no_type_check import numpy as np import torch @@ -11,6 +11,7 @@ from tianshou.data.types import RecurrentStateBatch ModuleType = type[nn.Module] ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]] +TActionShape: TypeAlias = Sequence[int] | int def miniblock( diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index f6009a4..d70ff02 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -6,7 +6,7 @@ import numpy as np import torch from torch import nn -from tianshou.utils.net.common import MLP, BaseActor +from tianshou.utils.net.common import MLP, BaseActor, TActionShape SIGMA_MIN = -20 SIGMA_MAX = 2 @@ -40,7 +40,7 @@ class Actor(BaseActor): def __init__( self, preprocess_net: nn.Module, - action_shape: Sequence[int], + action_shape: TActionShape, hidden_sizes: Sequence[int] = (), max_action: float = 1.0, device: str | int | torch.device = "cpu", @@ -180,7 +180,7 @@ class ActorProb(BaseActor): def __init__( self, preprocess_net: nn.Module, - action_shape: Sequence[int], + action_shape: TActionShape, hidden_sizes: Sequence[int] = (), max_action: float = 1.0, device: str | int | torch.device = "cpu", diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index a5bc6af..083cac5 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from torch import nn from tianshou.data import Batch, to_torch -from tianshou.utils.net.common import MLP, BaseActor +from tianshou.utils.net.common import MLP, BaseActor, TActionShape class Actor(BaseActor): @@ -39,7 +39,7 @@ class Actor(BaseActor): def __init__( self, preprocess_net: nn.Module, - action_shape: Sequence[int], + action_shape: TActionShape, hidden_sizes: Sequence[int] = (), softmax_output: bool = True, preprocess_net_output_dim: int | None = None,
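
Usage note (not part of the patch): the snippet below is a minimal, illustrative sketch of how the split callback classes introduced here, TrainerEpochCallbackTrain and TrainerEpochCallbackTest, are meant to be used, following the pattern of examples/atari/atari_dqn_hl.py above. The epsilon constants and the `builder` object are placeholders assumed for illustration, not names defined in this diff.

from tianshou.highlevel.trainer import (
    TrainerEpochCallbackTest,
    TrainerEpochCallbackTrain,
    TrainingContext,
)
from tianshou.policy import DQNPolicy

eps_train, eps_test = 0.1, 0.0  # placeholder exploration rates


class TrainEpochCallback(TrainerEpochCallbackTrain):
    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        # set the training-time exploration rate and log it, as in the atari example
        policy: DQNPolicy = context.policy
        logger = context.logger.logger
        policy.set_eps(eps_train)
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps_train})


class TestEpochCallback(TrainerEpochCallbackTest):
    def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
        # disable exploration during evaluation
        policy: DQNPolicy = context.policy
        policy.set_eps(eps_test)


# With an experiment builder instance (assumed here to be named `builder`), the
# renamed registration methods now take the matching callback type:
#   builder.with_trainer_epoch_callback_train(TrainEpochCallback())
#   builder.with_trainer_epoch_callback_test(TestEpochCallback())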