Improve type annotations, fix type issues and add checks

This commit is contained in:
Dominik Jain 2023-10-09 17:22:52 +02:00
parent e6716326bd
commit a161a9cf58
21 changed files with 191 additions and 123 deletions

View File

@ -19,7 +19,11 @@ from tianshou.highlevel.params.policy_params import DQNParams
from tianshou.highlevel.params.policy_wrapper import (
PolicyWrapperFactoryIntrinsicCuriosity,
)
from tianshou.highlevel.trainer import TrainerEpochCallback, TrainingContext
from tianshou.highlevel.trainer import (
TrainerEpochCallbackTest,
TrainerEpochCallbackTrain,
TrainingContext,
)
from tianshou.policy import DQNPolicy
from tianshou.utils import logging
@ -75,7 +79,7 @@ def main(
scale=scale_obs,
)
class TrainEpochCallback(TrainerEpochCallback):
class TrainEpochCallback(TrainerEpochCallbackTrain):
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
policy: DQNPolicy = context.policy
logger = context.logger.logger
@ -88,7 +92,7 @@ def main(
if env_step % 1000 == 0:
logger.write("train/env_step", env_step, {"train/eps": eps})
class TestEpochCallback(TrainerEpochCallback):
class TestEpochCallback(TrainerEpochCallbackTest):
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
policy: DQNPolicy = context.policy
policy.set_eps(eps_test)
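
A minimal, self-contained sketch of the two callback classes this renaming introduces, using only the signatures visible in this diff; the epsilon schedule and the class names are purely illustrative:

    from tianshou.highlevel.trainer import (
        TrainerEpochCallbackTest,
        TrainerEpochCallbackTrain,
        TrainingContext,
    )
    from tianshou.policy import DQNPolicy

    class SetTrainEps(TrainerEpochCallbackTrain):
        # Called at the start of each training epoch: anneal exploration epsilon.
        def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
            policy: DQNPolicy = context.policy
            policy.set_eps(max(0.05, 0.5 - 0.01 * epoch))

    class SetTestEps(TrainerEpochCallbackTest):
        # Called before each test phase: use a small, fixed epsilon.
        def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
            policy: DQNPolicy = context.policy
            policy.set_eps(0.005)

With the builder API changed further down in this commit, such callbacks would be registered via with_trainer_epoch_callback_train(SetTrainEps()) and with_trainer_epoch_callback_test(SetTestEps()).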

View File

@ -13,7 +13,9 @@ from tianshou.highlevel.experiment import (
ExperimentConfig,
PPOExperimentBuilder,
)
from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryIndependentGaussians
from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils import logging

View File

@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Generic, TypeVar
from os import PathLike
from typing import Any, Generic, TypeVar, cast
import gymnasium
import torch
@ -60,9 +61,14 @@ class AgentFactory(ABC, ToStringMixin):
self.policy_wrapper_factory: PolicyWrapperFactory | None = None
self.trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
def create_train_test_collector(
self,
policy: BasePolicy,
envs: Environments,
) -> tuple[Collector, Collector]:
buffer_size = self.sampling_config.buffer_size
train_envs = envs.train_envs
buffer: ReplayBuffer
if len(train_envs) > 1:
buffer = VectorReplayBuffer(
buffer_size,
@ -90,7 +96,7 @@ class AgentFactory(ABC, ToStringMixin):
) -> None:
self.policy_wrapper_factory = policy_wrapper_factory
def set_trainer_callbacks(self, callbacks: TrainerCallbacks):
def set_trainer_callbacks(self, callbacks: TrainerCallbacks) -> None:
self.trainer_callbacks = callbacks
@abstractmethod
@ -122,7 +128,12 @@ class AgentFactory(ABC, ToStringMixin):
return save_best_fn
@staticmethod
def load_checkpoint(policy: torch.nn.Module, path, envs: Environments, device: TDevice):
def load_checkpoint(
policy: torch.nn.Module,
path: str | PathLike,
envs: Environments,
device: TDevice,
) -> None:
ckpt = torch.load(path, map_location=device)
policy.load_state_dict(ckpt[CHECKPOINT_DICT_KEY_MODEL])
if envs.train_envs:
@ -280,6 +291,7 @@ class _ActorCriticMixin:
lr: float,
) -> ActorCriticModuleOpt:
actor = self.actor_factory.create_module(envs, device)
critic: torch.nn.Module
if self.critic_use_actor_module:
if self.critic_use_action:
raise ValueError(
@ -375,7 +387,7 @@ class ActorCriticAgentFactory(
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
pass
def _create_kwargs(self, envs: Environments, device: TDevice):
def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]:
actor_critic = self._create_actor_critic(envs, device)
kwargs = self.params.create_kwargs(
ParamTransformerData(
@ -468,8 +480,7 @@ class DQNAgentFactory(OffpolicyAgentFactory):
),
)
envs.get_type().assert_discrete(self)
# noinspection PyTypeChecker
action_space: gymnasium.spaces.Discrete = envs.get_action_space()
action_space = cast(gymnasium.spaces.Discrete, envs.get_action_space())
return DQNPolicy(
model=model,
optim=optim,
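
The cast above replaces a type-checker suppression comment: typing.cast performs no runtime conversion or validation, it only narrows the static type. A small sketch of the pattern (the helper function is hypothetical); the runtime guarantee still comes from the assert_discrete check performed just before the cast:

    from typing import cast

    import gymnasium

    def num_discrete_actions(space: gymnasium.Space) -> int:
        # cast() returns its argument unchanged; it merely tells the type
        # checker to treat it as a Discrete space from here on.
        discrete = cast(gymnasium.spaces.Discrete, space)
        return int(discrete.n)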

View File

@ -1,37 +1,38 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from enum import Enum
from typing import Any
from typing import Any, TypeAlias
import gymnasium as gym
from tianshou.env import BaseVectorEnv
from tianshou.highlevel.persistence import PersistableConfigProtocol
from tianshou.utils.net.common import TActionShape
TShape = int | Sequence[int]
TObservationShape: TypeAlias = int | Sequence[int]
class EnvType(Enum):
CONTINUOUS = "continuous"
DISCRETE = "discrete"
def is_discrete(self):
def is_discrete(self) -> bool:
return self == EnvType.DISCRETE
def is_continuous(self):
def is_continuous(self) -> bool:
return self == EnvType.CONTINUOUS
def assert_continuous(self, requiring_entity: Any):
def assert_continuous(self, requiring_entity: Any) -> None:
if not self.is_continuous():
raise AssertionError(f"{requiring_entity} requires continuous environments")
def assert_discrete(self, requiring_entity: Any):
def assert_discrete(self, requiring_entity: Any) -> None:
if not self.is_discrete():
raise AssertionError(f"{requiring_entity} requires discrete environments")
class Environments(ABC):
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
self.env = env
self.train_envs = train_envs
self.test_envs = test_envs
@ -43,11 +44,11 @@ class Environments(ABC):
}
@abstractmethod
def get_action_shape(self) -> TShape:
def get_action_shape(self) -> TActionShape:
pass
@abstractmethod
def get_observation_shape(self) -> TShape:
def get_observation_shape(self) -> TObservationShape:
pass
def get_action_space(self) -> gym.Space:
@ -62,11 +63,11 @@ class Environments(ABC):
class ContinuousEnvironments(Environments):
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
super().__init__(env, train_envs, test_envs)
self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
def info(self):
def info(self) -> dict[str, Any]:
d = super().info()
d["max_action"] = self.max_action
return d
@ -80,17 +81,17 @@ class ContinuousEnvironments(Environments):
"Only environments with continuous action space are supported here. "
f"But got env with action space: {env.action_space.__class__}.",
)
state_shape = env.observation_space.shape or env.observation_space.n
state_shape = env.observation_space.shape or env.observation_space.n # type: ignore
if not state_shape:
raise ValueError("Observation space shape is not defined")
action_shape = env.action_space.shape
max_action = env.action_space.high[0]
return state_shape, action_shape, max_action
def get_action_shape(self) -> TShape:
def get_action_shape(self) -> TActionShape:
return self.action_shape
def get_observation_shape(self) -> TShape:
def get_observation_shape(self) -> TObservationShape:
return self.state_shape
def get_type(self) -> EnvType:
@ -98,15 +99,15 @@ class ContinuousEnvironments(Environments):
class DiscreteEnvironments(Environments):
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
super().__init__(env, train_envs, test_envs)
self.observation_shape = env.observation_space.shape or env.observation_space.n
self.action_shape = env.action_space.shape or env.action_space.n
self.observation_shape = env.observation_space.shape or env.observation_space.n # type: ignore
self.action_shape = env.action_space.shape or env.action_space.n # type: ignore
def get_action_shape(self) -> TShape:
def get_action_shape(self) -> TActionShape:
return self.action_shape
def get_observation_shape(self) -> TShape:
def get_observation_shape(self) -> TObservationShape:
return self.observation_shape
def get_type(self) -> EnvType:
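
The TObservationShape alias introduced here (and TActionShape, added to tianshou.utils.net.common later in this commit) both admit either a bare int or a sequence of ints. A tiny sketch, with a hypothetical helper, of what code consuming such a shape has to handle:

    from collections.abc import Sequence
    from typing import TypeAlias

    TObservationShape: TypeAlias = int | Sequence[int]

    def flat_dim(shape: TObservationShape) -> int:
        # A bare int (e.g. 4) and a sequence (e.g. (84, 84, 4)) are both valid shapes.
        if isinstance(shape, int):
            return shape
        result = 1
        for d in shape:
            result *= d
        return result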

View File

@ -40,7 +40,8 @@ from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.highlevel.persistence import PersistableConfigProtocol
from tianshou.highlevel.trainer import (
TrainerCallbacks,
TrainerEpochCallback,
TrainerEpochCallbackTest,
TrainerEpochCallbackTrain,
TrainerStopCallback,
)
from tianshou.policy import BasePolicy
@ -135,24 +136,29 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
result = trainer.run()
pprint(result) # TODO logging
render = self.config.render
if render is None:
render = 0.0 # TODO: Perhaps we should have a second render parameter for watch mode?
self._watch_agent(
self.config.watch_num_episodes,
policy,
test_collector,
self.config.render,
render,
)
@staticmethod
def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render) -> None:
def _watch_agent(
num_episodes: int,
policy: BasePolicy,
test_collector: Collector,
render: float,
) -> None:
policy.eval()
test_collector.reset()
result = test_collector.collect(n_episode=num_episodes, render=render)
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder")
class ExperimentBuilder:
def __init__(
self,
@ -177,7 +183,7 @@ class ExperimentBuilder:
self._env_config = config
return self
def with_logger_factory(self: TBuilder, logger_factory: LoggerFactory) -> TBuilder:
def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
self._logger_factory = logger_factory
return self
@ -185,16 +191,16 @@ class ExperimentBuilder:
self._policy_wrapper_factory = policy_wrapper_factory
return self
def with_optim_factory(self: TBuilder, optim_factory: OptimizerFactory) -> TBuilder:
def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self:
self._optim_factory = optim_factory
return self
def with_optim_factory_default(
self: TBuilder,
betas=(0.9, 0.999),
eps=1e-08,
weight_decay=0,
) -> TBuilder:
self,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-08,
weight_decay: float = 0,
) -> Self:
"""Configures the use of the default optimizer, Adam, with the given parameters.
:param betas: coefficients used for computing running averages of gradient and its square
@ -205,11 +211,11 @@ class ExperimentBuilder:
self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay)
return self
def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallback) -> Self:
def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallbackTrain) -> Self:
self._trainer_callbacks.epoch_callback_train = callback
return self
def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallback) -> Self:
def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallbackTest) -> Self:
self._trainer_callbacks.epoch_callback_test = callback
return self
@ -232,7 +238,7 @@ class ExperimentBuilder:
agent_factory.set_trainer_callbacks(self._trainer_callbacks)
if self._policy_wrapper_factory:
agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
experiment = Experiment(
experiment: Experiment = Experiment(
self._config,
self._env_factory,
agent_factory,
@ -248,18 +254,16 @@ class _BuilderMixinActorFactory:
self._continuous_actor_type = continuous_actor_type
self._actor_factory: ActorFactory | None = None
def with_actor_factory(self: TBuilder, actor_factory: ActorFactory) -> TBuilder:
self: TBuilder | _BuilderMixinActorFactory
def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
self._actor_factory = actor_factory
return self
def _with_actor_factory_default(
self: TBuilder,
self,
hidden_sizes: Sequence[int],
continuous_unbounded=False,
continuous_conditioned_sigma=False,
) -> TBuilder:
self: TBuilder | _BuilderMixinActorFactory
continuous_unbounded: bool = False,
continuous_conditioned_sigma: bool = False,
) -> Self:
self._actor_factory = ActorFactoryDefault(
self._continuous_actor_type,
hidden_sizes,
@ -268,7 +272,7 @@ class _BuilderMixinActorFactory:
)
return self
def _get_actor_factory(self):
def _get_actor_factory(self) -> ActorFactory:
if self._actor_factory is None:
return ActorFactoryDefault(self._continuous_actor_type)
else:
@ -278,14 +282,14 @@ class _BuilderMixinActorFactory:
class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
"""Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""
def __init__(self):
def __init__(self) -> None:
super().__init__(ContinuousActorType.GAUSSIAN)
def with_actor_factory_default(
self,
hidden_sizes: Sequence[int],
continuous_unbounded=False,
continuous_conditioned_sigma=False,
continuous_unbounded: bool = False,
continuous_conditioned_sigma: bool = False,
) -> Self:
return super()._with_actor_factory_default(
hidden_sizes,
@ -297,7 +301,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactory):
"""Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""
def __init__(self):
def __init__(self) -> None:
super().__init__(ContinuousActorType.DETERMINISTIC)
def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self:
@ -308,15 +312,15 @@ class _BuilderMixinCriticsFactory:
def __init__(self, num_critics: int):
self._critic_factories: list[CriticFactory | None] = [None] * num_critics
def _with_critic_factory(self, idx: int, critic_factory: CriticFactory):
def _with_critic_factory(self, idx: int, critic_factory: CriticFactory) -> Self:
self._critic_factories[idx] = critic_factory
return self
def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]):
def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]) -> Self:
self._critic_factories[idx] = CriticFactoryDefault(hidden_sizes)
return self
def _get_critic_factory(self, idx: int):
def _get_critic_factory(self, idx: int) -> CriticFactory:
factory = self._critic_factories[idx]
if factory is None:
return CriticFactoryDefault()
@ -325,7 +329,7 @@ class _BuilderMixinCriticsFactory:
class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
def __init__(self):
def __init__(self) -> None:
super().__init__(1)
def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
@ -341,7 +345,7 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self._critic_use_actor_module = False
@ -352,11 +356,10 @@ class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFacto
class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
def __init__(self):
def __init__(self) -> None:
super().__init__(2)
def with_common_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory"
def with_common_critic_factory(self, critic_factory: CriticFactory) -> Self:
for i in range(len(self._critic_factories)):
self._with_critic_factory(i, critic_factory)
return self
@ -364,35 +367,30 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
def with_common_critic_factory_default(
self,
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory"
) -> Self:
for i in range(len(self._critic_factories)):
self._with_critic_factory_default(i, hidden_sizes)
return self
def with_critic1_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory"
def with_critic1_factory(self, critic_factory: CriticFactory) -> Self:
self._with_critic_factory(0, critic_factory)
return self
def with_critic1_factory_default(
self,
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory"
) -> Self:
self._with_critic_factory_default(0, hidden_sizes)
return self
def with_critic2_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory"
def with_critic2_factory(self, critic_factory: CriticFactory) -> Self:
self._with_critic_factory(1, critic_factory)
return self
def with_critic2_factory_default(
self,
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory"
) -> Self:
self._with_critic_factory_default(1, hidden_sizes)
return self
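
The switch from a TBuilder TypeVar to typing.Self keeps the concrete builder type through chained calls across the mixins. A minimal sketch, assuming Python 3.11+ (or typing_extensions.Self); the builder classes here are hypothetical:

    from typing import Self

    class BaseBuilder:
        def with_seed(self, seed: int) -> Self:
            self._seed = seed
            return self

    class PPOLikeBuilder(BaseBuilder):
        def with_clip_range(self, eps: float) -> Self:
            self._clip_range = eps
            return self

    # Because with_seed() is annotated with Self, the checker still sees a
    # PPOLikeBuilder after the first call, so the whole chain type-checks:
    builder = PPOLikeBuilder().with_seed(42).with_clip_range(0.2)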

View File

@ -19,7 +19,7 @@ class Logger:
class LoggerFactory(ToStringMixin, ABC):
@abstractmethod
def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
pass
@ -48,6 +48,7 @@ class DefaultLoggerFactory(LoggerFactory):
),
),
)
logger: TLogger
if self.logger_type == "wandb":
logger = WandbLogger(
save_interval=1,

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from enum import Enum
import torch
from torch import nn
@ -11,7 +12,7 @@ from tianshou.utils.net.common import BaseActor, Net
from tianshou.utils.string import ToStringMixin
class ContinuousActorType:
class ContinuousActorType(Enum):
GAUSSIAN = "gaussian"
DETERMINISTIC = "deterministic"
UNSUPPORTED = "unsupported"
@ -23,7 +24,7 @@ class ActorFactory(ToStringMixin, ABC):
pass
@staticmethod
def _init_linear(actor: torch.nn.Module):
def _init_linear(actor: torch.nn.Module) -> None:
"""Initializes linear layers of an actor module using default mechanisms.
:param actor: the actor module.
@ -34,7 +35,7 @@ class ActorFactory(ToStringMixin, ABC):
# do last policy layer scaling, this will make initial actions have (close to)
# 0 mean and std, and will help boost performances,
# see https://arxiv.org/abs/2006.05990, Fig.24 for details
for m in actor.mu.modules():
for m in actor.mu.modules(): # type: ignore
if isinstance(m, torch.nn.Linear):
m.weight.data.copy_(0.01 * m.weight.data)
@ -48,8 +49,8 @@ class ActorFactoryDefault(ActorFactory):
self,
continuous_actor_type: ContinuousActorType,
hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
continuous_unbounded=False,
continuous_conditioned_sigma=False,
continuous_unbounded: bool = False,
continuous_conditioned_sigma: bool = False,
):
self.continuous_actor_type = continuous_actor_type
self.continuous_unbounded = continuous_unbounded
@ -58,6 +59,7 @@ class ActorFactoryDefault(ActorFactory):
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
env_type = envs.get_type()
factory: ActorFactoryContinuousDeterministicNet | ActorFactoryContinuousGaussianNet | ActorFactoryDiscreteNet
if env_type == EnvType.CONTINUOUS:
match self.continuous_actor_type:
case ContinuousActorType.GAUSSIAN:
@ -103,7 +105,12 @@ class ActorFactoryContinuousDeterministicNet(ActorFactoryContinuous):
class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
def __init__(
self,
hidden_sizes: Sequence[int],
unbounded: bool = True,
conditioned_sigma: bool = False,
):
self.hidden_sizes = hidden_sizes
self.unbounded = unbounded
self.conditioned_sigma = conditioned_sigma
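
Turning ContinuousActorType into an Enum gives the match statement above proper value patterns over a closed set of members. A self-contained sketch of that style (the describe function is hypothetical):

    from enum import Enum

    class ContinuousActorType(Enum):
        GAUSSIAN = "gaussian"
        DETERMINISTIC = "deterministic"
        UNSUPPORTED = "unsupported"

    def describe(actor_type: ContinuousActorType) -> str:
        # Enum members compare by identity, and the cases below enumerate them.
        match actor_type:
            case ContinuousActorType.GAUSSIAN:
                return "stochastic Gaussian actor"
            case ContinuousActorType.DETERMINISTIC:
                return "deterministic actor"
            case ContinuousActorType.UNSUPPORTED:
                return "environment has no continuous actor"
            case _:
                raise AssertionError(f"Unhandled actor type: {actor_type}")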

View File

@ -10,10 +10,10 @@ from tianshou.highlevel.env import Environments
from tianshou.utils.net.common import Net
from tianshou.utils.string import ToStringMixin
TDevice: TypeAlias = str | int | torch.device
TDevice: TypeAlias = str | torch.device
def init_linear_orthogonal(module: torch.nn.Module):
def init_linear_orthogonal(module: torch.nn.Module) -> None:
"""Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.
:param module: the module whose submodules are to be processed

View File

@ -27,11 +27,17 @@ class CriticFactoryDefault(CriticFactory):
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
env_type = envs.get_type()
if env_type == EnvType.CONTINUOUS:
factory = CriticFactoryContinuousNet(self.hidden_sizes)
return factory.create_module(envs, device, use_action)
return CriticFactoryContinuousNet(self.hidden_sizes).create_module(
envs,
device,
use_action,
)
elif env_type == EnvType.DISCRETE:
factory = CriticFactoryDiscreteNet(self.hidden_sizes)
return factory.create_module(envs, device, use_action)
return CriticFactoryDiscreteNet(self.hidden_sizes).create_module(
envs,
device,
use_action,
)
else:
raise ValueError(f"{env_type} not supported")

View File

@ -23,11 +23,11 @@ class ActorCriticModuleOpt:
optim: torch.optim.Optimizer
@property
def actor(self):
def actor(self) -> torch.nn.Module:
return self.actor_critic_module.actor
@property
def critic(self):
def critic(self) -> torch.nn.Module:
return self.actor_critic_module.critic

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Protocol
import torch
from torch.optim import Adam, RMSprop
@ -7,6 +7,11 @@ from torch.optim import Adam, RMSprop
from tianshou.utils.string import ToStringMixin
class OptimizerWithLearningRateProtocol(Protocol):
def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer:
pass
class OptimizerFactory(ABC, ToStringMixin):
# TODO: Is it OK to assume that all optimizers have a learning rate argument?
# Right now, the learning rate is typically a configuration parameter.
@ -18,7 +23,7 @@ class OptimizerFactory(ABC, ToStringMixin):
class OptimizerFactoryTorch(OptimizerFactory):
def __init__(self, optim_class: Any, **kwargs):
def __init__(self, optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any):
""":param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`),
which will be passed the module parameters, the learning rate as `lr` and the
kwargs provided.
@ -32,7 +37,12 @@ class OptimizerFactoryTorch(OptimizerFactory):
class OptimizerFactoryAdam(OptimizerFactory):
def __init__(self, betas=(0.9, 0.999), eps=1e-08, weight_decay=0):
def __init__(
self,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-08,
weight_decay: float = 0,
):
self.weight_decay = weight_decay
self.eps = eps
self.betas = betas
@ -48,7 +58,14 @@ class OptimizerFactoryAdam(OptimizerFactory):
class OptimizerFactoryRMSprop(OptimizerFactory):
def __init__(self, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False):
def __init__(
self,
alpha: float = 0.99,
eps: float = 1e-08,
weight_decay: float = 0,
momentum: float = 0,
centered: bool = False,
):
self.alpha = alpha
self.momentum = momentum
self.centered = centered
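
The new Protocol types the optim_class argument structurally: anything that, when called with the parameters and an lr keyword, yields a torch optimizer satisfies it. A short sketch; the Protocol is quoted from this diff, while build_optimizer is a hypothetical helper:

    from typing import Any, Protocol

    import torch

    class OptimizerWithLearningRateProtocol(Protocol):
        def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer:
            ...

    def build_optimizer(
        optim_class: OptimizerWithLearningRateProtocol,
        parameters: Any,
        lr: float,
    ) -> torch.optim.Optimizer:
        # torch.optim.SGD, Adam, RMSprop, ... all match: calling the class with
        # the parameter iterable and lr constructs the optimizer.
        return optim_class(parameters, lr=lr)

    model = torch.nn.Linear(4, 2)
    optimizer = build_optimizer(torch.optim.SGD, model.parameters(), lr=1e-2)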

View File

@ -30,7 +30,7 @@ class AutoAlphaFactoryDefault(AutoAlphaFactory): # TODO better name?
optim_factory: OptimizerFactory,
device: TDevice,
) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
target_entropy = -np.prod(envs.get_action_shape())
target_entropy = float(-np.prod(envs.get_action_shape()))
log_alpha = torch.zeros(1, requires_grad=True, device=device)
alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
return target_entropy, log_alpha, alpha_optim
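
The float(...) wrapper above matters because np.prod returns a NumPy scalar rather than a builtin float, which the annotated tuple[float, torch.Tensor, torch.optim.Optimizer] return type now makes explicit. A two-line illustration:

    import numpy as np

    raw = -np.prod((3,))          # numpy.int64(-3), not a builtin number
    target_entropy = float(raw)   # plain Python float -3.0, matching the annotation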

View File

@ -19,21 +19,21 @@ class DistributionFunctionFactory(ToStringMixin, ABC):
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
assert envs.get_type().assert_discrete(self)
envs.get_type().assert_discrete(self)
return self._dist_fn
@staticmethod
def _dist_fn(p):
def _dist_fn(p: TDistParams) -> torch.distributions.Distribution:
return torch.distributions.Categorical(logits=p)
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
assert envs.get_type().assert_continuous(self)
envs.get_type().assert_continuous(self)
return self._dist_fn
@staticmethod
def _dist_fn(*p):
def _dist_fn(*p: torch.Tensor) -> torch.distributions.Distribution:
return torch.distributions.Independent(torch.distributions.Normal(*p), 1)
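
For context on why the outer assert was dropped in the two create_dist_fn methods above: the assert_* helpers raise on mismatch and return None on success, so wrapping them in assert would fail even for a valid environment. A minimal stand-in illustrating the pattern (assert_positive is hypothetical):

    def assert_positive(x: int) -> None:
        # Mirrors EnvType.assert_discrete: raise on failure, return None on success.
        if x <= 0:
            raise AssertionError("expected a positive value")

    assert_positive(3)            # fine: returns None
    # assert assert_positive(3)   # would raise, because the asserted value is None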

View File

@ -3,6 +3,7 @@ from dataclasses import asdict, dataclass
from typing import Any, Literal, Protocol
import torch
from torch.optim.lr_scheduler import LRScheduler
from tianshou.exploration import BaseNoise
from tianshou.highlevel.env import Environments
@ -71,7 +72,7 @@ class ParamTransformerChangeValue(ParamTransformer):
def __init__(self, key: str):
self.key = key
def transform(self, params: dict[str, Any], data: ParamTransformerData):
def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
params[self.key] = self.change_value(params[self.key], data)
@abstractmethod
@ -118,6 +119,7 @@ class ParamTransformerMultiLRScheduler(ParamTransformer):
)
if lr_scheduler_factory is not None:
lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
lr_scheduler: LRScheduler | MultipleLRSchedulers | None
match len(lr_schedulers):
case 0:
lr_scheduler = None
@ -140,6 +142,7 @@ class ParamTransformerActorAndCriticLRScheduler(ParamTransformer):
self.key_scheduler = key_scheduler
def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
assert data.actor is not None and data.critic1 is not None
transformer = ParamTransformerMultiLRScheduler(
[
(data.actor.optim, self.key_factory_actor),
@ -164,6 +167,7 @@ class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer):
self.key_scheduler = key_scheduler
def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
assert data.actor is not None and data.critic1 is not None and data.critic2 is not None
transformer = ParamTransformerMultiLRScheduler(
[
(data.actor.optim, self.key_factory_actor),
@ -247,7 +251,7 @@ class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
lr: float = 1e-3
lr_scheduler_factory: LRSchedulerFactory | None = None
def _get_param_transformers(self):
def _get_param_transformers(self) -> list[ParamTransformer]:
return [
ParamTransformerDrop("lr"),
ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"),
@ -261,7 +265,7 @@ class ParamsMixinActorAndCritic(GetParamTransformersProtocol):
actor_lr_scheduler_factory: LRSchedulerFactory | None = None
critic_lr_scheduler_factory: LRSchedulerFactory | None = None
def _get_param_transformers(self):
def _get_param_transformers(self) -> list[ParamTransformer]:
return [
ParamTransformerDrop("actor_lr", "critic_lr"),
ParamTransformerActorAndCriticLRScheduler(
@ -325,7 +329,7 @@ class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol):
critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
def _get_param_transformers(self):
def _get_param_transformers(self) -> list[ParamTransformer]:
return [
ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
ParamTransformerActorDualCriticsLRScheduler(
@ -348,7 +352,7 @@ class SACParams(Params, ParamsMixinActorAndDualCritics):
action_scaling: bool = True
action_bound_method: Literal["clip"] | None = "clip"
def _get_param_transformers(self):
def _get_param_transformers(self) -> list[ParamTransformer]:
transformers = super()._get_param_transformers()
transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
transformers.append(ParamTransformerAutoAlpha("alpha"))
@ -365,7 +369,7 @@ class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
is_double: bool = True
clip_loss_grad: bool = False
def _get_param_transformers(self):
def _get_param_transformers(self) -> list[ParamTransformer]:
transformers = super()._get_param_transformers()
transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self))
return transformers
@ -380,7 +384,7 @@ class DDPGParams(Params, ParamsMixinActorAndCritic):
action_scaling: bool = True
action_bound_method: Literal["clip"] | None = "clip"
def _get_param_transformers(self):
def _get_param_transformers(self) -> list[ParamTransformer]:
transformers = super()._get_param_transformers()
transformers.extend(ParamsMixinActorAndCritic._get_param_transformers(self))
transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
@ -399,7 +403,7 @@ class TD3Params(Params, ParamsMixinActorAndDualCritics):
action_scaling: bool = True
action_bound_method: Literal["clip"] | None = "clip"
def _get_param_transformers(self):
def _get_param_transformers(self) -> list[ParamTransformer]:
transformers = super()._get_param_transformers()
transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
transformers.append(ParamTransformerNoiseFactory("exploration_noise"))

View File

@ -36,7 +36,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
lr: float,
lr_scale: float,
reward_scale: float,
forward_loss_weight,
forward_loss_weight: float,
):
self.feature_net_factory = feature_net_factory
self.hidden_sizes = hidden_sizes
@ -54,6 +54,8 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
) -> ICMPolicy:
feature_net = self.feature_net_factory.create_module(envs, device)
action_dim = envs.get_action_shape()
if type(action_dim) != int:
raise ValueError(f"Environment action shape must be an integer, got {action_dim}")
feature_dim = feature_net.output_dim
icm_net = IntrinsicCuriosityModule(
feature_net.module,

View File

@ -18,7 +18,7 @@ class TrainingContext:
self.logger = logger
class TrainerEpochCallback(ToStringMixin, ABC):
class TrainerEpochCallbackTrain(ToStringMixin, ABC):
"""Callback which is called at the beginning of each epoch."""
@abstractmethod
@ -26,7 +26,21 @@ class TrainerEpochCallback(ToStringMixin, ABC):
pass
def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int], None]:
def fn(epoch, env_step):
def fn(epoch: int, env_step: int) -> None:
return self.callback(epoch, env_step, context)
return fn
class TrainerEpochCallbackTest(ToStringMixin, ABC):
"""Callback which is called at the beginning of each epoch."""
@abstractmethod
def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
pass
def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int | None], None]:
def fn(epoch: int, env_step: int | None) -> None:
return self.callback(epoch, env_step, context)
return fn
@ -42,7 +56,7 @@ class TrainerStopCallback(ToStringMixin, ABC):
"""
def get_trainer_fn(self, context: TrainingContext) -> Callable[[float], bool]:
def fn(mean_rewards: float):
def fn(mean_rewards: float) -> bool:
return self.should_stop(mean_rewards, context)
return fn
@ -50,6 +64,6 @@ class TrainerStopCallback(ToStringMixin, ABC):
@dataclass
class TrainerCallbacks:
epoch_callback_train: TrainerEpochCallback | None = None
epoch_callback_test: TrainerEpochCallback | None = None
epoch_callback_train: TrainerEpochCallbackTrain | None = None
epoch_callback_test: TrainerEpochCallbackTest | None = None
stop_callback: TrainerStopCallback | None = None

View File

@ -1,6 +1,6 @@
import warnings
from collections.abc import Callable
from typing import Any, Literal, cast
from typing import Any, Literal, TypeAlias, cast
import gymnasium as gym
import numpy as np
@ -17,7 +17,7 @@ from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.utils import RunningMeanStd
TDistParams = torch.Tensor | tuple[torch.Tensor]
TDistParams: TypeAlias = torch.Tensor | tuple[torch.Tensor, torch.Tensor]
class PGPolicy(BasePolicy):

View File

@ -15,7 +15,7 @@ class MultipleLRSchedulers:
policy = PPOPolicy(..., lr_scheduler=scheduler)
"""
def __init__(self, *args: torch.optim.lr_scheduler.LambdaLR):
def __init__(self, *args: torch.optim.lr_scheduler.LRScheduler):
self.schedulers = args
def step(self) -> None:

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from typing import Any, no_type_check
from typing import Any, TypeAlias, no_type_check
import numpy as np
import torch
@ -11,6 +11,7 @@ from tianshou.data.types import RecurrentStateBatch
ModuleType = type[nn.Module]
ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]]
TActionShape: TypeAlias = Sequence[int] | int
def miniblock(

View File

@ -6,7 +6,7 @@ import numpy as np
import torch
from torch import nn
from tianshou.utils.net.common import MLP, BaseActor
from tianshou.utils.net.common import MLP, BaseActor, TActionShape
SIGMA_MIN = -20
SIGMA_MAX = 2
@ -40,7 +40,7 @@ class Actor(BaseActor):
def __init__(
self,
preprocess_net: nn.Module,
action_shape: Sequence[int],
action_shape: TActionShape,
hidden_sizes: Sequence[int] = (),
max_action: float = 1.0,
device: str | int | torch.device = "cpu",
@ -180,7 +180,7 @@ class ActorProb(BaseActor):
def __init__(
self,
preprocess_net: nn.Module,
action_shape: Sequence[int],
action_shape: TActionShape,
hidden_sizes: Sequence[int] = (),
max_action: float = 1.0,
device: str | int | torch.device = "cpu",

View File

@ -7,7 +7,7 @@ import torch.nn.functional as F
from torch import nn
from tianshou.data import Batch, to_torch
from tianshou.utils.net.common import MLP, BaseActor
from tianshou.utils.net.common import MLP, BaseActor, TActionShape
class Actor(BaseActor):
@ -39,7 +39,7 @@ class Actor(BaseActor):
def __init__(
self,
preprocess_net: nn.Module,
action_shape: Sequence[int],
action_shape: TActionShape,
hidden_sizes: Sequence[int] = (),
softmax_output: bool = True,
preprocess_net_output_dim: int | None = None,