Improve type annotations, fix type issues and add checks

Dominik Jain 2023-10-09 17:22:52 +02:00
parent e6716326bd
commit a161a9cf58
21 changed files with 191 additions and 123 deletions

View File

@@ -19,7 +19,11 @@ from tianshou.highlevel.params.policy_params import DQNParams
 from tianshou.highlevel.params.policy_wrapper import (
     PolicyWrapperFactoryIntrinsicCuriosity,
 )
-from tianshou.highlevel.trainer import TrainerEpochCallback, TrainingContext
+from tianshou.highlevel.trainer import (
+    TrainerEpochCallbackTest,
+    TrainerEpochCallbackTrain,
+    TrainingContext,
+)
 from tianshou.policy import DQNPolicy
 from tianshou.utils import logging
@@ -75,7 +79,7 @@ def main(
         scale=scale_obs,
     )

-    class TrainEpochCallback(TrainerEpochCallback):
+    class TrainEpochCallback(TrainerEpochCallbackTrain):
         def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
             policy: DQNPolicy = context.policy
             logger = context.logger.logger
@@ -88,7 +92,7 @@ def main(
             if env_step % 1000 == 0:
                 logger.write("train/env_step", env_step, {"train/eps": eps})

-    class TestEpochCallback(TrainerEpochCallback):
+    class TestEpochCallback(TrainerEpochCallbackTest):
         def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
             policy: DQNPolicy = context.policy
             policy.set_eps(eps_test)
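For reference, a minimal sketch of how the two callback classes above are registered with the high-level builder API (builder name and argument order assumed from the surrounding example code, not shown in this diff):

    builder = (
        DQNExperimentBuilder(experiment_config, env_factory, sampling_config)
        .with_trainer_epoch_callback_train(TrainEpochCallback())
        .with_trainer_epoch_callback_test(TestEpochCallback())
    )
    experiment = builder.build()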

View File

@@ -13,7 +13,9 @@ from tianshou.highlevel.experiment import (
     ExperimentConfig,
     PPOExperimentBuilder,
 )
-from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryIndependentGaussians
+from tianshou.highlevel.params.dist_fn import (
+    DistributionFunctionFactoryIndependentGaussians,
+)
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
 from tianshou.highlevel.params.policy_params import PPOParams
 from tianshou.utils import logging

View File

@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Callable
-from typing import Generic, TypeVar
+from os import PathLike
+from typing import Any, Generic, TypeVar, cast

 import gymnasium
 import torch
@@ -60,9 +61,14 @@ class AgentFactory(ABC, ToStringMixin):
         self.policy_wrapper_factory: PolicyWrapperFactory | None = None
         self.trainer_callbacks: TrainerCallbacks = TrainerCallbacks()

-    def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
+    def create_train_test_collector(
+        self,
+        policy: BasePolicy,
+        envs: Environments,
+    ) -> tuple[Collector, Collector]:
         buffer_size = self.sampling_config.buffer_size
         train_envs = envs.train_envs
+        buffer: ReplayBuffer
         if len(train_envs) > 1:
             buffer = VectorReplayBuffer(
                 buffer_size,
@@ -90,7 +96,7 @@ class AgentFactory(ABC, ToStringMixin):
     ) -> None:
         self.policy_wrapper_factory = policy_wrapper_factory

-    def set_trainer_callbacks(self, callbacks: TrainerCallbacks):
+    def set_trainer_callbacks(self, callbacks: TrainerCallbacks) -> None:
         self.trainer_callbacks = callbacks

     @abstractmethod
@@ -122,7 +128,12 @@ class AgentFactory(ABC, ToStringMixin):
         return save_best_fn

     @staticmethod
-    def load_checkpoint(policy: torch.nn.Module, path, envs: Environments, device: TDevice):
+    def load_checkpoint(
+        policy: torch.nn.Module,
+        path: str | PathLike,
+        envs: Environments,
+        device: TDevice,
+    ) -> None:
         ckpt = torch.load(path, map_location=device)
         policy.load_state_dict(ckpt[CHECKPOINT_DICT_KEY_MODEL])
         if envs.train_envs:
@@ -280,6 +291,7 @@ class _ActorCriticMixin:
         lr: float,
     ) -> ActorCriticModuleOpt:
         actor = self.actor_factory.create_module(envs, device)
+        critic: torch.nn.Module
         if self.critic_use_actor_module:
             if self.critic_use_action:
                 raise ValueError(
@@ -375,7 +387,7 @@ class ActorCriticAgentFactory(
     def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
         pass

-    def _create_kwargs(self, envs: Environments, device: TDevice):
+    def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]:
         actor_critic = self._create_actor_critic(envs, device)
         kwargs = self.params.create_kwargs(
             ParamTransformerData(
@@ -468,8 +480,7 @@ class DQNAgentFactory(OffpolicyAgentFactory):
             ),
         )
         envs.get_type().assert_discrete(self)
-        # noinspection PyTypeChecker
-        action_space: gymnasium.spaces.Discrete = envs.get_action_space()
+        action_space = cast(gymnasium.spaces.Discrete, envs.get_action_space())
         return DQNPolicy(
             model=model,
             optim=optim,
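Note on the `cast` change above: unlike the removed `# noinspection` comment, `typing.cast` is understood by mypy and performs no runtime conversion. A standalone sketch of the pattern (illustrative function, not from this commit):

    from typing import cast

    import gymnasium

    def num_actions(space: gymnasium.Space) -> int:
        # cast() has no runtime effect; it only tells the type checker
        # to treat the value as a Discrete space from here on.
        discrete_space = cast(gymnasium.spaces.Discrete, space)
        return int(discrete_space.n)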

View File

@@ -1,37 +1,38 @@
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from enum import Enum
-from typing import Any
+from typing import Any, TypeAlias

 import gymnasium as gym

 from tianshou.env import BaseVectorEnv
 from tianshou.highlevel.persistence import PersistableConfigProtocol
+from tianshou.utils.net.common import TActionShape

-TShape = int | Sequence[int]
+TObservationShape: TypeAlias = int | Sequence[int]


 class EnvType(Enum):
     CONTINUOUS = "continuous"
     DISCRETE = "discrete"

-    def is_discrete(self):
+    def is_discrete(self) -> bool:
         return self == EnvType.DISCRETE

-    def is_continuous(self):
+    def is_continuous(self) -> bool:
         return self == EnvType.CONTINUOUS

-    def assert_continuous(self, requiring_entity: Any):
+    def assert_continuous(self, requiring_entity: Any) -> None:
         if not self.is_continuous():
             raise AssertionError(f"{requiring_entity} requires continuous environments")

-    def assert_discrete(self, requiring_entity: Any):
+    def assert_discrete(self, requiring_entity: Any) -> None:
         if not self.is_discrete():
             raise AssertionError(f"{requiring_entity} requires discrete environments")


 class Environments(ABC):
-    def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
+    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
         self.env = env
         self.train_envs = train_envs
         self.test_envs = test_envs
@@ -43,11 +44,11 @@ class Environments(ABC):
         }

     @abstractmethod
-    def get_action_shape(self) -> TShape:
+    def get_action_shape(self) -> TActionShape:
         pass

     @abstractmethod
-    def get_observation_shape(self) -> TShape:
+    def get_observation_shape(self) -> TObservationShape:
         pass

     def get_action_space(self) -> gym.Space:
@@ -62,11 +63,11 @@ class Environments(ABC):
 class ContinuousEnvironments(Environments):
-    def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
+    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
         super().__init__(env, train_envs, test_envs)
         self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)

-    def info(self):
+    def info(self) -> dict[str, Any]:
         d = super().info()
         d["max_action"] = self.max_action
         return d
@@ -80,17 +81,17 @@ class ContinuousEnvironments(Environments):
                 "Only environments with continuous action space are supported here. "
                 f"But got env with action space: {env.action_space.__class__}.",
             )
-        state_shape = env.observation_space.shape or env.observation_space.n
+        state_shape = env.observation_space.shape or env.observation_space.n  # type: ignore
         if not state_shape:
             raise ValueError("Observation space shape is not defined")
         action_shape = env.action_space.shape
         max_action = env.action_space.high[0]
         return state_shape, action_shape, max_action

-    def get_action_shape(self) -> TShape:
+    def get_action_shape(self) -> TActionShape:
         return self.action_shape

-    def get_observation_shape(self) -> TShape:
+    def get_observation_shape(self) -> TObservationShape:
         return self.state_shape

     def get_type(self) -> EnvType:
@@ -98,15 +99,15 @@ class ContinuousEnvironments(Environments):
 class DiscreteEnvironments(Environments):
-    def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
+    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
         super().__init__(env, train_envs, test_envs)
-        self.observation_shape = env.observation_space.shape or env.observation_space.n
-        self.action_shape = env.action_space.shape or env.action_space.n
+        self.observation_shape = env.observation_space.shape or env.observation_space.n  # type: ignore
+        self.action_shape = env.action_space.shape or env.action_space.n  # type: ignore

-    def get_action_shape(self) -> TShape:
+    def get_action_shape(self) -> TActionShape:
         return self.action_shape

-    def get_observation_shape(self) -> TShape:
+    def get_observation_shape(self) -> TObservationShape:
         return self.observation_shape

     def get_type(self) -> EnvType:
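The `TShape` alias is split here into `TActionShape` (shared with `tianshou.utils.net.common`) and an explicit `TObservationShape: TypeAlias`; the `TypeAlias` marker tells checkers the assignment defines a type, not a module-level variable. A small sketch of how such an alias is consumed (the helper function is illustrative only, not from this commit):

    from collections.abc import Sequence
    from typing import TypeAlias

    TObservationShape: TypeAlias = int | Sequence[int]

    def flat_dim(shape: TObservationShape) -> int:
        # An int shape is already a flat dimension count.
        if isinstance(shape, int):
            return shape
        result = 1
        for d in shape:
            result *= d
        return result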

View File

@@ -40,7 +40,8 @@ from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
 from tianshou.highlevel.persistence import PersistableConfigProtocol
 from tianshou.highlevel.trainer import (
     TrainerCallbacks,
-    TrainerEpochCallback,
+    TrainerEpochCallbackTest,
+    TrainerEpochCallbackTrain,
     TrainerStopCallback,
 )
 from tianshou.policy import BasePolicy
@@ -135,24 +136,29 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
         result = trainer.run()
         pprint(result)  # TODO logging

+        render = self.config.render
+        if render is None:
+            render = 0.0  # TODO: Perhaps we should have a second render parameter for watch mode?
         self._watch_agent(
             self.config.watch_num_episodes,
             policy,
             test_collector,
-            self.config.render,
+            render,
         )

     @staticmethod
-    def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render) -> None:
+    def _watch_agent(
+        num_episodes: int,
+        policy: BasePolicy,
+        test_collector: Collector,
+        render: float,
+    ) -> None:
         policy.eval()
         test_collector.reset()
         result = test_collector.collect(n_episode=num_episodes, render=render)
         print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')


-TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder")


 class ExperimentBuilder:
     def __init__(
         self,
@@ -177,7 +183,7 @@ class ExperimentBuilder:
         self._env_config = config
         return self

-    def with_logger_factory(self: TBuilder, logger_factory: LoggerFactory) -> TBuilder:
+    def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
         self._logger_factory = logger_factory
         return self
@@ -185,16 +191,16 @@ class ExperimentBuilder:
         self._policy_wrapper_factory = policy_wrapper_factory
         return self

-    def with_optim_factory(self: TBuilder, optim_factory: OptimizerFactory) -> TBuilder:
+    def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self:
         self._optim_factory = optim_factory
         return self

     def with_optim_factory_default(
-        self: TBuilder,
-        betas=(0.9, 0.999),
-        eps=1e-08,
-        weight_decay=0,
-    ) -> TBuilder:
+        self,
+        betas: tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-08,
+        weight_decay: float = 0,
+    ) -> Self:
         """Configures the use of the default optimizer, Adam, with the given parameters.

         :param betas: coefficients used for computing running averages of gradient and its square
@@ -205,11 +211,11 @@ class ExperimentBuilder:
         self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay)
         return self

-    def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallback) -> Self:
+    def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallbackTrain) -> Self:
         self._trainer_callbacks.epoch_callback_train = callback
         return self

-    def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallback) -> Self:
+    def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallbackTest) -> Self:
         self._trainer_callbacks.epoch_callback_test = callback
         return self
@@ -232,7 +238,7 @@ class ExperimentBuilder:
         agent_factory.set_trainer_callbacks(self._trainer_callbacks)
         if self._policy_wrapper_factory:
             agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
-        experiment = Experiment(
+        experiment: Experiment = Experiment(
             self._config,
             self._env_factory,
             agent_factory,
@@ -248,18 +254,16 @@ class _BuilderMixinActorFactory:
         self._continuous_actor_type = continuous_actor_type
         self._actor_factory: ActorFactory | None = None

-    def with_actor_factory(self: TBuilder, actor_factory: ActorFactory) -> TBuilder:
-        self: TBuilder | _BuilderMixinActorFactory
+    def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
         self._actor_factory = actor_factory
         return self

     def _with_actor_factory_default(
-        self: TBuilder,
+        self,
         hidden_sizes: Sequence[int],
-        continuous_unbounded=False,
-        continuous_conditioned_sigma=False,
-    ) -> TBuilder:
-        self: TBuilder | _BuilderMixinActorFactory
+        continuous_unbounded: bool = False,
+        continuous_conditioned_sigma: bool = False,
+    ) -> Self:
         self._actor_factory = ActorFactoryDefault(
             self._continuous_actor_type,
             hidden_sizes,
@@ -268,7 +272,7 @@ class _BuilderMixinActorFactory:
         )
         return self

-    def _get_actor_factory(self):
+    def _get_actor_factory(self) -> ActorFactory:
         if self._actor_factory is None:
             return ActorFactoryDefault(self._continuous_actor_type)
         else:
@@ -278,14 +282,14 @@ class _BuilderMixinActorFactory:
 class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
     """Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(ContinuousActorType.GAUSSIAN)

     def with_actor_factory_default(
         self,
         hidden_sizes: Sequence[int],
-        continuous_unbounded=False,
-        continuous_conditioned_sigma=False,
+        continuous_unbounded: bool = False,
+        continuous_conditioned_sigma: bool = False,
     ) -> Self:
         return super()._with_actor_factory_default(
             hidden_sizes,
@@ -297,7 +301,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
 class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactory):
     """Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(ContinuousActorType.DETERMINISTIC)

     def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self:
@@ -308,15 +312,15 @@ class _BuilderMixinCriticsFactory:
     def __init__(self, num_critics: int):
         self._critic_factories: list[CriticFactory | None] = [None] * num_critics

-    def _with_critic_factory(self, idx: int, critic_factory: CriticFactory):
+    def _with_critic_factory(self, idx: int, critic_factory: CriticFactory) -> Self:
         self._critic_factories[idx] = critic_factory
         return self

-    def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]):
+    def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]) -> Self:
         self._critic_factories[idx] = CriticFactoryDefault(hidden_sizes)
         return self

-    def _get_critic_factory(self, idx: int):
+    def _get_critic_factory(self, idx: int) -> CriticFactory:
         factory = self._critic_factories[idx]
         if factory is None:
             return CriticFactoryDefault()
@@ -325,7 +329,7 @@ class _BuilderMixinCriticsFactory:
 class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(1)

     def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
@@ -341,7 +345,7 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
 class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self._critic_use_actor_module = False
@@ -352,11 +356,10 @@ class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory):
 class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(2)

-    def with_common_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
-        self: TBuilder | "_BuilderMixinDualCriticFactory"
+    def with_common_critic_factory(self, critic_factory: CriticFactory) -> Self:
         for i in range(len(self._critic_factories)):
             self._with_critic_factory(i, critic_factory)
         return self
@@ -364,35 +367,30 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
     def with_common_critic_factory_default(
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
-    ) -> TBuilder:
-        self: TBuilder | "_BuilderMixinDualCriticFactory"
+    ) -> Self:
         for i in range(len(self._critic_factories)):
             self._with_critic_factory_default(i, hidden_sizes)
         return self

-    def with_critic1_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
-        self: TBuilder | "_BuilderMixinDualCriticFactory"
+    def with_critic1_factory(self, critic_factory: CriticFactory) -> Self:
         self._with_critic_factory(0, critic_factory)
         return self

     def with_critic1_factory_default(
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
-    ) -> TBuilder:
-        self: TBuilder | "_BuilderMixinDualCriticFactory"
+    ) -> Self:
         self._with_critic_factory_default(0, hidden_sizes)
         return self

-    def with_critic2_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
-        self: TBuilder | "_BuilderMixinDualCriticFactory"
+    def with_critic2_factory(self, critic_factory: CriticFactory) -> Self:
         self._with_critic_factory(1, critic_factory)
         return self

     def with_critic2_factory_default(
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
-    ) -> TBuilder:
-        self: TBuilder | "_BuilderMixinDualCriticFactory"
+    ) -> Self:
         self._with_critic_factory_default(0, hidden_sizes)
         return self

View File

@@ -19,7 +19,7 @@ class Logger:

 class LoggerFactory(ToStringMixin, ABC):
     @abstractmethod
-    def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
+    def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
         pass
@@ -48,6 +48,7 @@ class DefaultLoggerFactory(LoggerFactory):
                 ),
             ),
         )
+        logger: TLogger
         if self.logger_type == "wandb":
             logger = WandbLogger(
                 save_interval=1,

View File

@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
+from enum import Enum

 import torch
 from torch import nn
@@ -11,7 +12,7 @@ from tianshou.utils.net.common import BaseActor, Net
 from tianshou.utils.string import ToStringMixin


-class ContinuousActorType:
+class ContinuousActorType(Enum):
     GAUSSIAN = "gaussian"
     DETERMINISTIC = "deterministic"
     UNSUPPORTED = "unsupported"
@@ -23,7 +24,7 @@ class ActorFactory(ToStringMixin, ABC):
         pass

     @staticmethod
-    def _init_linear(actor: torch.nn.Module):
+    def _init_linear(actor: torch.nn.Module) -> None:
         """Initializes linear layers of an actor module using default mechanisms.

         :param module: the actor module.
@@ -34,7 +35,7 @@ class ActorFactory(ToStringMixin, ABC):
         # do last policy layer scaling, this will make initial actions have (close to)
         # 0 mean and std, and will help boost performances,
         # see https://arxiv.org/abs/2006.05990, Fig.24 for details
-        for m in actor.mu.modules():
+        for m in actor.mu.modules():  # type: ignore
             if isinstance(m, torch.nn.Linear):
                 m.weight.data.copy_(0.01 * m.weight.data)
@@ -48,8 +49,8 @@ class ActorFactoryDefault(ActorFactory):
         self,
         continuous_actor_type: ContinuousActorType,
         hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
-        continuous_unbounded=False,
-        continuous_conditioned_sigma=False,
+        continuous_unbounded: bool = False,
+        continuous_conditioned_sigma: bool = False,
     ):
         self.continuous_actor_type = continuous_actor_type
         self.continuous_unbounded = continuous_unbounded
@@ -58,6 +59,7 @@ class ActorFactoryDefault(ActorFactory):
     def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
         env_type = envs.get_type()
+        factory: ActorFactoryContinuousDeterministicNet | ActorFactoryContinuousGaussianNet | ActorFactoryDiscreteNet
         if env_type == EnvType.CONTINUOUS:
             match self.continuous_actor_type:
                 case ContinuousActorType.GAUSSIAN:
@@ -103,7 +105,12 @@ class ActorFactoryContinuousDeterministicNet(ActorFactoryContinuous):
 class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
-    def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
+    def __init__(
+        self,
+        hidden_sizes: Sequence[int],
+        unbounded: bool = True,
+        conditioned_sigma: bool = False,
+    ):
         self.hidden_sizes = hidden_sizes
         self.unbounded = unbounded
         self.conditioned_sigma = conditioned_sigma
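Making `ContinuousActorType` an `Enum` (rather than a plain class with string attributes) turns its members into a closed set that type checkers can reason about, e.g. in the `match` statement of `create_module` above. Illustrative sketch (the helper function is not from this commit):

    from enum import Enum

    class ContinuousActorType(Enum):
        GAUSSIAN = "gaussian"
        DETERMINISTIC = "deterministic"
        UNSUPPORTED = "unsupported"

    def describe(actor_type: ContinuousActorType) -> str:
        # With an Enum and a declared return type, mypy can flag a
        # non-exhaustive match as a missing-return error.
        match actor_type:
            case ContinuousActorType.GAUSSIAN:
                return "stochastic actor with Gaussian output"
            case ContinuousActorType.DETERMINISTIC:
                return "deterministic actor"
            case ContinuousActorType.UNSUPPORTED:
                return "no continuous actor available"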

View File

@@ -10,10 +10,10 @@ from tianshou.highlevel.env import Environments
 from tianshou.utils.net.common import Net
 from tianshou.utils.string import ToStringMixin

-TDevice: TypeAlias = str | int | torch.device
+TDevice: TypeAlias = str | torch.device


-def init_linear_orthogonal(module: torch.nn.Module):
+def init_linear_orthogonal(module: torch.nn.Module) -> None:
     """Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.

     :param module: the module whose submodules are to be processed

View File

@@ -27,11 +27,17 @@ class CriticFactoryDefault(CriticFactory):
     def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
         env_type = envs.get_type()
         if env_type == EnvType.CONTINUOUS:
-            factory = CriticFactoryContinuousNet(self.hidden_sizes)
-            return factory.create_module(envs, device, use_action)
+            return CriticFactoryContinuousNet(self.hidden_sizes).create_module(
+                envs,
+                device,
+                use_action,
+            )
         elif env_type == EnvType.DISCRETE:
-            factory = CriticFactoryDiscreteNet(self.hidden_sizes)
-            return factory.create_module(envs, device, use_action)
+            return CriticFactoryDiscreteNet(self.hidden_sizes).create_module(
+                envs,
+                device,
+                use_action,
+            )
         else:
             raise ValueError(f"{env_type} not supported")
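The refactoring here (and the `buffer: ReplayBuffer` / `logger: TLogger` declarations elsewhere in this commit) works around the same mypy behaviour: a variable's type is fixed by its first assignment, so binding values of two different types in sibling branches is reported as incompatible. Minimal illustration (not from this commit):

    def pick(flag: bool) -> str:
        if flag:
            x = 1      # mypy fixes x as int here
        else:
            x = "a"    # error: incompatible types in assignment
        return str(x)

    # Fixes: pre-declare the type (x: int | str), or avoid the shared
    # variable entirely by returning from each branch, as the diff does.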

View File

@@ -23,11 +23,11 @@ class ActorCriticModuleOpt:
     optim: torch.optim.Optimizer

     @property
-    def actor(self):
+    def actor(self) -> torch.nn.Module:
         return self.actor_critic_module.actor

     @property
-    def critic(self):
+    def critic(self) -> torch.nn.Module:
         return self.actor_critic_module.critic

View File

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any
+from typing import Any, Protocol

 import torch
 from torch.optim import Adam, RMSprop
@@ -7,6 +7,11 @@ from torch.optim import Adam, RMSprop
 from tianshou.utils.string import ToStringMixin


+class OptimizerWithLearningRateProtocol(Protocol):
+    def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer:
+        pass
+
+
 class OptimizerFactory(ABC, ToStringMixin):
     # TODO: Is it OK to assume that all optimizers have a learning rate argument?
     # Right now, the learning rate is typically a configuration parameter.
@@ -18,7 +23,7 @@ class OptimizerFactory(ABC, ToStringMixin):
 class OptimizerFactoryTorch(OptimizerFactory):
-    def __init__(self, optim_class: Any, **kwargs):
+    def __init__(self, optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any):
         """:param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`),
         which will be passed the module parameters, the learning rate as `lr` and the
         kwargs provided.
@@ -32,7 +37,12 @@ class OptimizerFactoryTorch(OptimizerFactory):
 class OptimizerFactoryAdam(OptimizerFactory):
-    def __init__(self, betas=(0.9, 0.999), eps=1e-08, weight_decay=0):
+    def __init__(
+        self,
+        betas: tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-08,
+        weight_decay: float = 0,
+    ):
         self.weight_decay = weight_decay
         self.eps = eps
         self.betas = betas
@@ -48,7 +58,14 @@ class OptimizerFactoryAdam(OptimizerFactory):
 class OptimizerFactoryRMSprop(OptimizerFactory):
-    def __init__(self, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False):
+    def __init__(
+        self,
+        alpha: float = 0.99,
+        eps: float = 1e-08,
+        weight_decay: float = 0,
+        momentum: float = 0,
+        centered: bool = False,
+    ):
         self.alpha = alpha
         self.momentum = momentum
         self.centered = centered
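The new `Protocol` gives a structural type to what was previously `optim_class: Any`: any callable taking the parameters plus an `lr` keyword and returning an optimizer matches, including the `torch.optim` classes themselves, with no inheritance required. Runnable sketch (illustrative `build` helper, not from this commit):

    from typing import Any, Protocol

    import torch

    class OptimizerWithLearningRateProtocol(Protocol):
        def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer:
            ...

    def build(
        factory: OptimizerWithLearningRateProtocol,
        params: Any,
        lr: float,
    ) -> torch.optim.Optimizer:
        # The factory is only required to have a compatible call signature.
        return factory(params, lr=lr)

    model = torch.nn.Linear(4, 2)
    optimizer = build(torch.optim.SGD, model.parameters(), lr=0.01)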

View File

@@ -30,7 +30,7 @@ class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
         optim_factory: OptimizerFactory,
         device: TDevice,
     ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
-        target_entropy = -np.prod(envs.get_action_shape())
+        target_entropy = float(-np.prod(envs.get_action_shape()))
         log_alpha = torch.zeros(1, requires_grad=True, device=device)
         alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
         return target_entropy, log_alpha, alpha_optim
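`np.prod` returns a NumPy scalar rather than a builtin `float`, so the previous code did not actually satisfy the declared `tuple[float, ...]` return type; the added `float(...)` conversion makes value and annotation agree. Minimal illustration:

    import numpy as np

    action_shape = (2, 3)
    raw = -np.prod(action_shape)                    # numpy.int64, not a builtin float
    target_entropy = float(-np.prod(action_shape))  # plain float, matches the annotation
    assert isinstance(target_entropy, float)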

View File

@@ -19,21 +19,21 @@ class DistributionFunctionFactory(ToStringMixin, ABC):

 class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
     def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
-        assert envs.get_type().assert_discrete(self)
+        envs.get_type().assert_discrete(self)
         return self._dist_fn

     @staticmethod
-    def _dist_fn(p):
+    def _dist_fn(p: TDistParams) -> torch.distributions.Distribution:
         return torch.distributions.Categorical(logits=p)


 class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
     def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
-        assert envs.get_type().assert_continuous(self)
+        envs.get_type().assert_continuous(self)
         return self._dist_fn

     @staticmethod
-    def _dist_fn(*p):
+    def _dist_fn(*p: torch.Tensor) -> torch.distributions.Distribution:
         return torch.distributions.Independent(torch.distributions.Normal(*p), 1)
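The removed `assert` statements were genuine bugs, not just style: `assert_discrete()`/`assert_continuous()` return `None` and raise on failure, so `assert envs.get_type().assert_discrete(self)` evaluated to `assert None` and failed even for valid environments (whenever assertions were enabled). Compressed illustration:

    def check() -> None:
        # raises AssertionError itself when the condition is violated
        pass

    assert check()  # BUG: asserts on the None return value, always fails
    check()         # correct: just call the checking function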

View File

@@ -3,6 +3,7 @@ from dataclasses import asdict, dataclass
 from typing import Any, Literal, Protocol

 import torch
+from torch.optim.lr_scheduler import LRScheduler

 from tianshou.exploration import BaseNoise
 from tianshou.highlevel.env import Environments
@@ -71,7 +72,7 @@ class ParamTransformerChangeValue(ParamTransformer):
     def __init__(self, key: str):
         self.key = key

-    def transform(self, params: dict[str, Any], data: ParamTransformerData):
+    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
         params[self.key] = self.change_value(params[self.key], data)

     @abstractmethod
@@ -118,6 +119,7 @@ class ParamTransformerMultiLRScheduler(ParamTransformer):
             )
             if lr_scheduler_factory is not None:
                 lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
+        lr_scheduler: LRScheduler | MultipleLRSchedulers | None
         match len(lr_schedulers):
             case 0:
                 lr_scheduler = None
@@ -140,6 +142,7 @@ class ParamTransformerActorAndCriticLRScheduler(ParamTransformer):
         self.key_scheduler = key_scheduler

     def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
+        assert data.actor is not None and data.critic1 is not None
         transformer = ParamTransformerMultiLRScheduler(
             [
                 (data.actor.optim, self.key_factory_actor),
@@ -164,6 +167,7 @@ class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer):
         self.key_scheduler = key_scheduler

     def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
+        assert data.actor is not None and data.critic1 is not None and data.critic2 is not None
         transformer = ParamTransformerMultiLRScheduler(
             [
                 (data.actor.optim, self.key_factory_actor),
@@ -247,7 +251,7 @@ class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
     lr: float = 1e-3
     lr_scheduler_factory: LRSchedulerFactory | None = None

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         return [
             ParamTransformerDrop("lr"),
             ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"),
@@ -261,7 +265,7 @@ class ParamsMixinActorAndCritic(GetParamTransformersProtocol):
     actor_lr_scheduler_factory: LRSchedulerFactory | None = None
     critic_lr_scheduler_factory: LRSchedulerFactory | None = None

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         return [
             ParamTransformerDrop("actor_lr", "critic_lr"),
             ParamTransformerActorAndCriticLRScheduler(
@@ -325,7 +329,7 @@ class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol):
     critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
     critic2_lr_scheduler_factory: LRSchedulerFactory | None = None

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         return [
             ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
             ParamTransformerActorDualCriticsLRScheduler(
@@ -348,7 +352,7 @@ class SACParams(Params, ParamsMixinActorAndDualCritics):
     action_scaling: bool = True
     action_bound_method: Literal["clip"] | None = "clip"

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         transformers = super()._get_param_transformers()
         transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
         transformers.append(ParamTransformerAutoAlpha("alpha"))
@@ -365,7 +369,7 @@ class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
     is_double: bool = True
     clip_loss_grad: bool = False

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         transformers = super()._get_param_transformers()
         transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self))
         return transformers
@@ -380,7 +384,7 @@ class DDPGParams(Params, ParamsMixinActorAndCritic):
     action_scaling: bool = True
     action_bound_method: Literal["clip"] | None = "clip"

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         transformers = super()._get_param_transformers()
         transformers.extend(ParamsMixinActorAndCritic._get_param_transformers(self))
         transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
@@ -399,7 +403,7 @@ class TD3Params(Params, ParamsMixinActorAndDualCritics):
     action_scaling: bool = True
     action_bound_method: Literal["clip"] | None = "clip"

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         transformers = super()._get_param_transformers()
         transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
         transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
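The `assert data.actor is not None ...` additions are narrowing guards: the corresponding fields are `Optional`, and the assert removes `None` from their types before `.optim` is accessed. Minimal illustration of the pattern (not from this commit):

    from dataclasses import dataclass

    import torch

    @dataclass
    class Data:
        optim: torch.optim.Optimizer | None = None

    def use(data: Data) -> None:
        assert data.optim is not None
        # mypy now treats data.optim as Optimizer, not Optimizer | None,
        # so the attribute access below type-checks.
        data.optim.step()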

View File

@@ -36,7 +36,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
         lr: float,
         lr_scale: float,
         reward_scale: float,
-        forward_loss_weight,
+        forward_loss_weight: float,
     ):
         self.feature_net_factory = feature_net_factory
         self.hidden_sizes = hidden_sizes
@@ -54,6 +54,8 @@
     ) -> ICMPolicy:
         feature_net = self.feature_net_factory.create_module(envs, device)
         action_dim = envs.get_action_shape()
+        if type(action_dim) != int:
+            raise ValueError(f"Environment action shape must be an integer, got {action_dim}")
         feature_dim = feature_net.output_dim
         icm_net = IntrinsicCuriosityModule(
             feature_net.module,
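The added check validates the environment at runtime and simultaneously narrows the type for mypy: after the `raise`, `action_dim` can only be an `int`. Sketch of the narrowing behaviour (illustrative function, not from this commit):

    from collections.abc import Sequence

    def require_int_shape(action_shape: int | Sequence[int]) -> int:
        if type(action_shape) != int:
            raise ValueError(f"Expected an integer action shape, got {action_shape}")
        # Past the raise, the checker narrows action_shape to int.
        return action_shape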

View File

@@ -18,7 +18,7 @@ class TrainingContext:
         self.logger = logger


-class TrainerEpochCallback(ToStringMixin, ABC):
+class TrainerEpochCallbackTrain(ToStringMixin, ABC):
     """Callback which is called at the beginning of each epoch."""

     @abstractmethod
@@ -26,7 +26,21 @@ class TrainerEpochCallback(ToStringMixin, ABC):
         pass

     def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int], None]:
-        def fn(epoch, env_step):
+        def fn(epoch: int, env_step: int) -> None:
+            return self.callback(epoch, env_step, context)
+
+        return fn
+
+
+class TrainerEpochCallbackTest(ToStringMixin, ABC):
+    """Callback which is called at the beginning of each epoch."""
+
+    @abstractmethod
+    def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
+        pass
+
+    def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int | None], None]:
+        def fn(epoch: int, env_step: int | None) -> None:
             return self.callback(epoch, env_step, context)

         return fn
@@ -42,7 +56,7 @@ class TrainerStopCallback(ToStringMixin, ABC):
         """

     def get_trainer_fn(self, context: TrainingContext) -> Callable[[float], bool]:
-        def fn(mean_rewards: float):
+        def fn(mean_rewards: float) -> bool:
             return self.should_stop(mean_rewards, context)

         return fn
@@ -50,6 +64,6 @@
 @dataclass
 class TrainerCallbacks:
-    epoch_callback_train: TrainerEpochCallback | None = None
-    epoch_callback_test: TrainerEpochCallback | None = None
+    epoch_callback_train: TrainerEpochCallbackTrain | None = None
+    epoch_callback_test: TrainerEpochCallbackTest | None = None
     stop_callback: TrainerStopCallback | None = None
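The split into `TrainerEpochCallbackTrain` and `TrainerEpochCallbackTest` reflects a real signature difference: the trainer can invoke the test function before any training step has run, so the test callback receives `env_step: int | None`. A subclass sketch (assumes a DQN-style policy with `set_eps`; epsilon values hypothetical):

    class EpsDecayCallback(TrainerEpochCallbackTrain):
        def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
            context.policy.set_eps(max(0.01, 0.5 * 0.9**epoch))

    class EvalEpsCallback(TrainerEpochCallbackTest):
        def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
            # env_step may be None when testing before training has started.
            context.policy.set_eps(0.0)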

View File

@@ -1,6 +1,6 @@
 import warnings
 from collections.abc import Callable
-from typing import Any, Literal, cast
+from typing import Any, Literal, TypeAlias, cast

 import gymnasium as gym
 import numpy as np
@@ -17,7 +17,7 @@ from tianshou.policy import BasePolicy
 from tianshou.policy.base import TLearningRateScheduler
 from tianshou.utils import RunningMeanStd

-TDistParams = torch.Tensor | tuple[torch.Tensor]
+TDistParams: TypeAlias = torch.Tensor | tuple[torch.Tensor, torch.Tensor]


 class PGPolicy(BasePolicy):

View File

@@ -15,7 +15,7 @@ class MultipleLRSchedulers:
         policy = PPOPolicy(..., lr_scheduler=scheduler)
     """

-    def __init__(self, *args: torch.optim.lr_scheduler.LambdaLR):
+    def __init__(self, *args: torch.optim.lr_scheduler.LRScheduler):
         self.schedulers = args

     def step(self) -> None:
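`LambdaLR` was needlessly specific: `torch.optim.lr_scheduler.LRScheduler` is the common base class of the built-in schedulers, so the widened annotation lets `MultipleLRSchedulers` aggregate any mix of them. Sketch (the tianshou import path is assumed, not shown in this diff):

    import torch

    from tianshou.utils import MultipleLRSchedulers  # import path assumed

    model = torch.nn.Linear(2, 2)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)

    # Both schedulers are LRScheduler subclasses and are now accepted.
    s1 = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lambda epoch: 0.95**epoch)
    s2 = torch.optim.lr_scheduler.StepLR(optim, step_size=10)

    scheduler = MultipleLRSchedulers(s1, s2)
    scheduler.step()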

View File

@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Sequence
-from typing import Any, no_type_check
+from typing import Any, TypeAlias, no_type_check

 import numpy as np
 import torch
@@ -11,6 +11,7 @@ from tianshou.data.types import RecurrentStateBatch

 ModuleType = type[nn.Module]
 ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]]
+TActionShape: TypeAlias = Sequence[int] | int


 def miniblock(

View File

@@ -6,7 +6,7 @@ import numpy as np
 import torch
 from torch import nn

-from tianshou.utils.net.common import MLP, BaseActor
+from tianshou.utils.net.common import MLP, BaseActor, TActionShape

 SIGMA_MIN = -20
 SIGMA_MAX = 2
@@ -40,7 +40,7 @@ class Actor(BaseActor):
     def __init__(
         self,
         preprocess_net: nn.Module,
-        action_shape: Sequence[int],
+        action_shape: TActionShape,
         hidden_sizes: Sequence[int] = (),
         max_action: float = 1.0,
         device: str | int | torch.device = "cpu",
@@ -180,7 +180,7 @@ class ActorProb(BaseActor):
     def __init__(
         self,
         preprocess_net: nn.Module,
-        action_shape: Sequence[int],
+        action_shape: TActionShape,
         hidden_sizes: Sequence[int] = (),
         max_action: float = 1.0,
         device: str | int | torch.device = "cpu",

View File

@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from torch import nn

 from tianshou.data import Batch, to_torch
-from tianshou.utils.net.common import MLP, BaseActor
+from tianshou.utils.net.common import MLP, BaseActor, TActionShape


 class Actor(BaseActor):
@@ -39,7 +39,7 @@ class Actor(BaseActor):
     def __init__(
         self,
         preprocess_net: nn.Module,
-        action_shape: Sequence[int],
+        action_shape: TActionShape,
         hidden_sizes: Sequence[int] = (),
         softmax_output: bool = True,
         preprocess_net_output_dim: int | None = None,