Improve type annotations, fix type issues and add checks
This commit is contained in:
parent e6716326bd · commit a161a9cf58
@@ -19,7 +19,11 @@ from tianshou.highlevel.params.policy_params import DQNParams
 from tianshou.highlevel.params.policy_wrapper import (
     PolicyWrapperFactoryIntrinsicCuriosity,
 )
-from tianshou.highlevel.trainer import TrainerEpochCallback, TrainingContext
+from tianshou.highlevel.trainer import (
+    TrainerEpochCallbackTest,
+    TrainerEpochCallbackTrain,
+    TrainingContext,
+)
 from tianshou.policy import DQNPolicy
 from tianshou.utils import logging

@@ -75,7 +79,7 @@ def main(
         scale=scale_obs,
     )

-    class TrainEpochCallback(TrainerEpochCallback):
+    class TrainEpochCallback(TrainerEpochCallbackTrain):
         def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
             policy: DQNPolicy = context.policy
             logger = context.logger.logger

@@ -88,7 +92,7 @@ def main(
             if env_step % 1000 == 0:
                 logger.write("train/env_step", env_step, {"train/eps": eps})

-    class TestEpochCallback(TrainerEpochCallback):
+    class TestEpochCallback(TrainerEpochCallbackTest):
         def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
             policy: DQNPolicy = context.policy
             policy.set_eps(eps_test)
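Note on the change above: the single `TrainerEpochCallback` base is split into train and test variants because the two hooks end up with different signatures (the test hook later receives `env_step: int | None`). A minimal sketch of the pattern, with a hypothetical `Context` stand-in for `TrainingContext`:

```python
from abc import ABC, abstractmethod


class Context:  # hypothetical stand-in for TrainingContext
    ...


class EpochCallbackTrain(ABC):
    """Hook invoked at the beginning of each training epoch."""

    @abstractmethod
    def callback(self, epoch: int, env_step: int, context: Context) -> None: ...


class EpochCallbackTest(ABC):
    """Hook invoked before each test run; the env step may be unknown."""

    @abstractmethod
    def callback(self, epoch: int, env_step: int | None, context: Context) -> None: ...
```

Separate base classes let each signature be checked precisely instead of forcing both hooks through one loosely typed interface.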
@@ -13,7 +13,9 @@ from tianshou.highlevel.experiment import (
     ExperimentConfig,
     PPOExperimentBuilder,
 )
-from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryIndependentGaussians
+from tianshou.highlevel.params.dist_fn import (
+    DistributionFunctionFactoryIndependentGaussians,
+)
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
 from tianshou.highlevel.params.policy_params import PPOParams
 from tianshou.utils import logging
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Callable
-from typing import Generic, TypeVar
+from os import PathLike
+from typing import Any, Generic, TypeVar, cast

 import gymnasium
 import torch

@@ -60,9 +61,14 @@ class AgentFactory(ABC, ToStringMixin):
         self.policy_wrapper_factory: PolicyWrapperFactory | None = None
         self.trainer_callbacks: TrainerCallbacks = TrainerCallbacks()

-    def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
+    def create_train_test_collector(
+        self,
+        policy: BasePolicy,
+        envs: Environments,
+    ) -> tuple[Collector, Collector]:
         buffer_size = self.sampling_config.buffer_size
         train_envs = envs.train_envs
+        buffer: ReplayBuffer
         if len(train_envs) > 1:
             buffer = VectorReplayBuffer(
                 buffer_size,

@@ -90,7 +96,7 @@ class AgentFactory(ABC, ToStringMixin):
     ) -> None:
         self.policy_wrapper_factory = policy_wrapper_factory

-    def set_trainer_callbacks(self, callbacks: TrainerCallbacks):
+    def set_trainer_callbacks(self, callbacks: TrainerCallbacks) -> None:
         self.trainer_callbacks = callbacks

     @abstractmethod

@@ -122,7 +128,12 @@ class AgentFactory(ABC, ToStringMixin):
         return save_best_fn

     @staticmethod
-    def load_checkpoint(policy: torch.nn.Module, path, envs: Environments, device: TDevice):
+    def load_checkpoint(
+        policy: torch.nn.Module,
+        path: str | PathLike,
+        envs: Environments,
+        device: TDevice,
+    ) -> None:
         ckpt = torch.load(path, map_location=device)
         policy.load_state_dict(ckpt[CHECKPOINT_DICT_KEY_MODEL])
         if envs.train_envs:

@@ -280,6 +291,7 @@ class _ActorCriticMixin:
         lr: float,
     ) -> ActorCriticModuleOpt:
         actor = self.actor_factory.create_module(envs, device)
+        critic: torch.nn.Module
         if self.critic_use_actor_module:
             if self.critic_use_action:
                 raise ValueError(

@@ -375,7 +387,7 @@ class ActorCriticAgentFactory(
     def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
         pass

-    def _create_kwargs(self, envs: Environments, device: TDevice):
+    def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]:
         actor_critic = self._create_actor_critic(envs, device)
         kwargs = self.params.create_kwargs(
             ParamTransformerData(

@@ -468,8 +480,7 @@ class DQNAgentFactory(OffpolicyAgentFactory):
             ),
         )
         envs.get_type().assert_discrete(self)
-        # noinspection PyTypeChecker
-        action_space: gymnasium.spaces.Discrete = envs.get_action_space()
+        action_space = cast(gymnasium.spaces.Discrete, envs.get_action_space())
         return DQNPolicy(
             model=model,
             optim=optim,
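The last hunk replaces an annotated assignment plus an IDE-only suppression comment (`# noinspection PyTypeChecker`) with `typing.cast`, which every type checker understands and which is a no-op at runtime. A small illustrative sketch (the `num_actions` helper is hypothetical, not part of the commit):

```python
from typing import cast

import gymnasium


def num_actions(space: gymnasium.Space) -> int:
    # After the caller has asserted the env is discrete, narrow the type
    # explicitly; cast() performs no runtime check and costs nothing.
    discrete = cast(gymnasium.spaces.Discrete, space)
    return int(discrete.n)
```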
@@ -1,37 +1,38 @@
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from enum import Enum
-from typing import Any
+from typing import Any, TypeAlias

 import gymnasium as gym

 from tianshou.env import BaseVectorEnv
 from tianshou.highlevel.persistence import PersistableConfigProtocol
+from tianshou.utils.net.common import TActionShape

-TShape = int | Sequence[int]
+TObservationShape: TypeAlias = int | Sequence[int]


 class EnvType(Enum):
     CONTINUOUS = "continuous"
     DISCRETE = "discrete"

-    def is_discrete(self):
+    def is_discrete(self) -> bool:
         return self == EnvType.DISCRETE

-    def is_continuous(self):
+    def is_continuous(self) -> bool:
         return self == EnvType.CONTINUOUS

-    def assert_continuous(self, requiring_entity: Any):
+    def assert_continuous(self, requiring_entity: Any) -> None:
         if not self.is_continuous():
             raise AssertionError(f"{requiring_entity} requires continuous environments")

-    def assert_discrete(self, requiring_entity: Any):
+    def assert_discrete(self, requiring_entity: Any) -> None:
         if not self.is_discrete():
             raise AssertionError(f"{requiring_entity} requires discrete environments")


 class Environments(ABC):
-    def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
+    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
         self.env = env
         self.train_envs = train_envs
         self.test_envs = test_envs

@@ -43,11 +44,11 @@ class Environments(ABC):
         }

     @abstractmethod
-    def get_action_shape(self) -> TShape:
+    def get_action_shape(self) -> TActionShape:
         pass

     @abstractmethod
-    def get_observation_shape(self) -> TShape:
+    def get_observation_shape(self) -> TObservationShape:
         pass

     def get_action_space(self) -> gym.Space:

@@ -62,11 +63,11 @@ class Environments(ABC):


 class ContinuousEnvironments(Environments):
-    def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
+    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
         super().__init__(env, train_envs, test_envs)
         self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)

-    def info(self):
+    def info(self) -> dict[str, Any]:
         d = super().info()
         d["max_action"] = self.max_action
         return d

@@ -80,17 +81,17 @@ class ContinuousEnvironments(Environments):
                 "Only environments with continuous action space are supported here. "
                 f"But got env with action space: {env.action_space.__class__}.",
             )
-        state_shape = env.observation_space.shape or env.observation_space.n
+        state_shape = env.observation_space.shape or env.observation_space.n  # type: ignore
         if not state_shape:
             raise ValueError("Observation space shape is not defined")
         action_shape = env.action_space.shape
         max_action = env.action_space.high[0]
         return state_shape, action_shape, max_action

-    def get_action_shape(self) -> TShape:
+    def get_action_shape(self) -> TActionShape:
         return self.action_shape

-    def get_observation_shape(self) -> TShape:
+    def get_observation_shape(self) -> TObservationShape:
         return self.state_shape

     def get_type(self) -> EnvType:

@@ -98,15 +99,15 @@ class ContinuousEnvironments(Environments):


 class DiscreteEnvironments(Environments):
-    def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
+    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
         super().__init__(env, train_envs, test_envs)
-        self.observation_shape = env.observation_space.shape or env.observation_space.n
-        self.action_shape = env.action_space.shape or env.action_space.n
+        self.observation_shape = env.observation_space.shape or env.observation_space.n  # type: ignore
+        self.action_shape = env.action_space.shape or env.action_space.n  # type: ignore

-    def get_action_shape(self) -> TShape:
+    def get_action_shape(self) -> TActionShape:
         return self.action_shape

-    def get_observation_shape(self) -> TShape:
+    def get_observation_shape(self) -> TObservationShape:
         return self.observation_shape

     def get_type(self) -> EnvType:
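Here the generic `TShape` alias is split into dedicated observation and action aliases, and both are declared with `TypeAlias` so checkers treat the assignments as types rather than ordinary variables. A self-contained sketch of the idea (the `flat_dim` helper is hypothetical):

```python
from collections.abc import Sequence
from typing import TypeAlias

# Explicit aliases: type checkers now know these are types, not values,
# and the two concepts can diverge later without touching call sites.
TObservationShape: TypeAlias = int | Sequence[int]
TActionShape: TypeAlias = Sequence[int] | int


def flat_dim(shape: TObservationShape) -> int:
    # An int means a flat (discrete) dimension; a sequence is a tensor shape.
    if isinstance(shape, int):
        return shape
    result = 1
    for d in shape:
        result *= d
    return result
```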
@@ -40,7 +40,8 @@ from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
 from tianshou.highlevel.persistence import PersistableConfigProtocol
 from tianshou.highlevel.trainer import (
     TrainerCallbacks,
-    TrainerEpochCallback,
+    TrainerEpochCallbackTest,
+    TrainerEpochCallbackTrain,
     TrainerStopCallback,
 )
 from tianshou.policy import BasePolicy

@@ -135,24 +136,29 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
         result = trainer.run()
         pprint(result)  # TODO logging

+        render = self.config.render
+        if render is None:
+            render = 0.0  # TODO: Perhaps we should have a second render parameter for watch mode?
         self._watch_agent(
             self.config.watch_num_episodes,
             policy,
             test_collector,
-            self.config.render,
+            render,
         )

     @staticmethod
-    def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render) -> None:
+    def _watch_agent(
+        num_episodes: int,
+        policy: BasePolicy,
+        test_collector: Collector,
+        render: float,
+    ) -> None:
         policy.eval()
         test_collector.reset()
         result = test_collector.collect(n_episode=num_episodes, render=render)
         print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')


-TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder")
-
-
 class ExperimentBuilder:
     def __init__(
         self,

@@ -177,7 +183,7 @@ class ExperimentBuilder:
         self._env_config = config
         return self

-    def with_logger_factory(self: TBuilder, logger_factory: LoggerFactory) -> TBuilder:
+    def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
         self._logger_factory = logger_factory
         return self

@@ -185,16 +191,16 @@ class ExperimentBuilder:
         self._policy_wrapper_factory = policy_wrapper_factory
         return self

-    def with_optim_factory(self: TBuilder, optim_factory: OptimizerFactory) -> TBuilder:
+    def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self:
         self._optim_factory = optim_factory
         return self

     def with_optim_factory_default(
-        self: TBuilder,
-        betas=(0.9, 0.999),
-        eps=1e-08,
-        weight_decay=0,
-    ) -> TBuilder:
+        self,
+        betas: tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-08,
+        weight_decay: float = 0,
+    ) -> Self:
         """Configures the use of the default optimizer, Adam, with the given parameters.

         :param betas: coefficients used for computing running averages of gradient and its square

@@ -205,11 +211,11 @@ class ExperimentBuilder:
         self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay)
         return self

-    def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallback) -> Self:
+    def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallbackTrain) -> Self:
         self._trainer_callbacks.epoch_callback_train = callback
         return self

-    def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallback) -> Self:
+    def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallbackTest) -> Self:
         self._trainer_callbacks.epoch_callback_test = callback
         return self

@@ -232,7 +238,7 @@ class ExperimentBuilder:
         agent_factory.set_trainer_callbacks(self._trainer_callbacks)
         if self._policy_wrapper_factory:
             agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
-        experiment = Experiment(
+        experiment: Experiment = Experiment(
             self._config,
             self._env_factory,
             agent_factory,

@@ -248,18 +254,16 @@ class _BuilderMixinActorFactory:
         self._continuous_actor_type = continuous_actor_type
         self._actor_factory: ActorFactory | None = None

-    def with_actor_factory(self: TBuilder, actor_factory: ActorFactory) -> TBuilder:
-        self: TBuilder | _BuilderMixinActorFactory
+    def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
         self._actor_factory = actor_factory
         return self

     def _with_actor_factory_default(
-        self: TBuilder,
+        self,
         hidden_sizes: Sequence[int],
-        continuous_unbounded=False,
-        continuous_conditioned_sigma=False,
-    ) -> TBuilder:
-        self: TBuilder | _BuilderMixinActorFactory
+        continuous_unbounded: bool = False,
+        continuous_conditioned_sigma: bool = False,
+    ) -> Self:
         self._actor_factory = ActorFactoryDefault(
             self._continuous_actor_type,
             hidden_sizes,

@@ -268,7 +272,7 @@ class _BuilderMixinActorFactory:
         )
         return self

-    def _get_actor_factory(self):
+    def _get_actor_factory(self) -> ActorFactory:
         if self._actor_factory is None:
             return ActorFactoryDefault(self._continuous_actor_type)
         else:

@@ -278,14 +282,14 @@ class _BuilderMixinActorFactory:
 class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
     """Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(ContinuousActorType.GAUSSIAN)

     def with_actor_factory_default(
         self,
         hidden_sizes: Sequence[int],
-        continuous_unbounded=False,
-        continuous_conditioned_sigma=False,
+        continuous_unbounded: bool = False,
+        continuous_conditioned_sigma: bool = False,
     ) -> Self:
         return super()._with_actor_factory_default(
             hidden_sizes,

@@ -297,7 +301,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
 class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactory):
     """Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(ContinuousActorType.DETERMINISTIC)

     def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self:

@@ -308,15 +312,15 @@ class _BuilderMixinCriticsFactory:
     def __init__(self, num_critics: int):
         self._critic_factories: list[CriticFactory | None] = [None] * num_critics

-    def _with_critic_factory(self, idx: int, critic_factory: CriticFactory):
+    def _with_critic_factory(self, idx: int, critic_factory: CriticFactory) -> Self:
         self._critic_factories[idx] = critic_factory
         return self

-    def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]):
+    def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]) -> Self:
         self._critic_factories[idx] = CriticFactoryDefault(hidden_sizes)
         return self

-    def _get_critic_factory(self, idx: int):
+    def _get_critic_factory(self, idx: int) -> CriticFactory:
         factory = self._critic_factories[idx]
         if factory is None:
             return CriticFactoryDefault()

@@ -325,7 +329,7 @@ class _BuilderMixinCriticsFactory:


 class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(1)

     def with_critic_factory(self, critic_factory: CriticFactory) -> Self:

@@ -341,7 +345,7 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):


 class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self._critic_use_actor_module = False

@@ -352,11 +356,10 @@ class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFacto


 class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(2)

-    def with_common_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
-        self: TBuilder | "_BuilderMixinDualCriticFactory"
+    def with_common_critic_factory(self, critic_factory: CriticFactory) -> Self:
         for i in range(len(self._critic_factories)):
             self._with_critic_factory(i, critic_factory)
         return self

@@ -364,35 +367,30 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
     def with_common_critic_factory_default(
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
-    ) -> TBuilder:
-        self: TBuilder | "_BuilderMixinDualCriticFactory"
+    ) -> Self:
         for i in range(len(self._critic_factories)):
             self._with_critic_factory_default(i, hidden_sizes)
         return self

-    def with_critic1_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
-        self: TBuilder | "_BuilderMixinDualCriticFactory"
+    def with_critic1_factory(self, critic_factory: CriticFactory) -> Self:
         self._with_critic_factory(0, critic_factory)
         return self

     def with_critic1_factory_default(
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
-    ) -> TBuilder:
-        self: TBuilder | "_BuilderMixinDualCriticFactory"
+    ) -> Self:
         self._with_critic_factory_default(0, hidden_sizes)
         return self

-    def with_critic2_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
-        self: TBuilder | "_BuilderMixinDualCriticFactory"
+    def with_critic2_factory(self, critic_factory: CriticFactory) -> Self:
         self._with_critic_factory(1, critic_factory)
         return self

     def with_critic2_factory_default(
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
-    ) -> TBuilder:
-        self: TBuilder | "_BuilderMixinDualCriticFactory"
+    ) -> Self:
         self._with_critic_factory_default(0, hidden_sizes)
         return self
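The recurring pattern in this file is the replacement of the hand-rolled `TBuilder` TypeVar (plus the awkward `self: TBuilder | ...` re-declarations) with `typing.Self`, which gives the same subclass-preserving fluent interface with far less machinery. A self-contained sketch, with hypothetical builder names:

```python
from typing import Self  # Python 3.11+; earlier versions: typing_extensions.Self


class Builder:
    def __init__(self) -> None:
        self._parts: list[str] = []

    # Returning Self (instead of a TypeVar bound to the base class) means
    # chained calls on a subclass keep the subclass's static type.
    def with_part(self, part: str) -> Self:
        self._parts.append(part)
        return self


class FancyBuilder(Builder):
    def with_fancy(self) -> Self:
        self._parts.append("fancy")
        return self


# Accepted by type checkers: with_part() returns FancyBuilder here,
# so the subsequent with_fancy() call resolves.
b = FancyBuilder().with_part("a").with_fancy()
```

This also explains why the `TBuilder = TypeVar("TBuilder", bound=...)` definition could be deleted outright.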
@@ -19,7 +19,7 @@ class Logger:

 class LoggerFactory(ToStringMixin, ABC):
     @abstractmethod
-    def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
+    def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
         pass


@@ -48,6 +48,7 @@ class DefaultLoggerFactory(LoggerFactory):
                 ),
             ),
         )
+        logger: TLogger
         if self.logger_type == "wandb":
             logger = WandbLogger(
                 save_interval=1,
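The `logger: TLogger` line (like `buffer: ReplayBuffer` and `lr_scheduler: ...` elsewhere in this commit) declares a variable's type once, before branching, so the checker unifies all branch assignments against one declared type. A minimal sketch of the pattern, assuming tianshou's `ReplayBuffer`/`VectorReplayBuffer` constructors:

```python
# Assumption: constructors as in tianshou (ReplayBuffer(size),
# VectorReplayBuffer(total_size, buffer_num)); helper name is hypothetical.
from tianshou.data import ReplayBuffer, VectorReplayBuffer


def make_buffer(buffer_size: int, num_envs: int) -> ReplayBuffer:
    # Declared once up front: each branch only has to assign a value
    # compatible with ReplayBuffer, and mypy checks the join for us.
    buffer: ReplayBuffer
    if num_envs > 1:
        buffer = VectorReplayBuffer(buffer_size, num_envs)
    else:
        buffer = ReplayBuffer(buffer_size)
    return buffer
```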
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
+from enum import Enum

 import torch
 from torch import nn

@@ -11,7 +12,7 @@ from tianshou.utils.net.common import BaseActor, Net
 from tianshou.utils.string import ToStringMixin


-class ContinuousActorType:
+class ContinuousActorType(Enum):
     GAUSSIAN = "gaussian"
     DETERMINISTIC = "deterministic"
     UNSUPPORTED = "unsupported"

@@ -23,7 +24,7 @@ class ActorFactory(ToStringMixin, ABC):
         pass

     @staticmethod
-    def _init_linear(actor: torch.nn.Module):
+    def _init_linear(actor: torch.nn.Module) -> None:
         """Initializes linear layers of an actor module using default mechanisms.

         :param module: the actor module.

@@ -34,7 +35,7 @@ class ActorFactory(ToStringMixin, ABC):
         # do last policy layer scaling, this will make initial actions have (close to)
         # 0 mean and std, and will help boost performances,
         # see https://arxiv.org/abs/2006.05990, Fig.24 for details
-        for m in actor.mu.modules():
+        for m in actor.mu.modules():  # type: ignore
             if isinstance(m, torch.nn.Linear):
                 m.weight.data.copy_(0.01 * m.weight.data)

@@ -48,8 +49,8 @@ class ActorFactoryDefault(ActorFactory):
         self,
         continuous_actor_type: ContinuousActorType,
         hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
-        continuous_unbounded=False,
-        continuous_conditioned_sigma=False,
+        continuous_unbounded: bool = False,
+        continuous_conditioned_sigma: bool = False,
     ):
         self.continuous_actor_type = continuous_actor_type
         self.continuous_unbounded = continuous_unbounded

@@ -58,6 +59,7 @@ class ActorFactoryDefault(ActorFactory):

     def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
         env_type = envs.get_type()
+        factory: ActorFactoryContinuousDeterministicNet | ActorFactoryContinuousGaussianNet | ActorFactoryDiscreteNet
         if env_type == EnvType.CONTINUOUS:
             match self.continuous_actor_type:
                 case ContinuousActorType.GAUSSIAN:

@@ -103,7 +105,12 @@ class ActorFactoryContinuousDeterministicNet(ActorFactoryContinuous):


 class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
-    def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
+    def __init__(
+        self,
+        hidden_sizes: Sequence[int],
+        unbounded: bool = True,
+        conditioned_sigma: bool = False,
+    ):
         self.hidden_sizes = hidden_sizes
         self.unbounded = unbounded
         self.conditioned_sigma = conditioned_sigma
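Turning `ContinuousActorType` from a plain class of string attributes into an `Enum` closes the set of legal values: arbitrary strings no longer type-check where a constant is expected, and `match` statements over it can be checked for exhaustiveness. An illustrative sketch (the `describe` function is hypothetical):

```python
from enum import Enum


class ContinuousActorType(Enum):
    GAUSSIAN = "gaussian"
    DETERMINISTIC = "deterministic"
    UNSUPPORTED = "unsupported"


def describe(t: ContinuousActorType) -> str:
    # With an Enum, a checker can verify that every member is handled
    # and that no branch is unreachable.
    match t:
        case ContinuousActorType.GAUSSIAN:
            return "stochastic Gaussian policy"
        case ContinuousActorType.DETERMINISTIC:
            return "deterministic policy"
        case ContinuousActorType.UNSUPPORTED:
            return "no continuous actor available"
```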
@@ -10,10 +10,10 @@ from tianshou.highlevel.env import Environments
 from tianshou.utils.net.common import Net
 from tianshou.utils.string import ToStringMixin

-TDevice: TypeAlias = str | int | torch.device
+TDevice: TypeAlias = str | torch.device


-def init_linear_orthogonal(module: torch.nn.Module):
+def init_linear_orthogonal(module: torch.nn.Module) -> None:
     """Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.

     :param module: the module whose submodules are to be processed
@@ -27,11 +27,17 @@ class CriticFactoryDefault(CriticFactory):
     def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
         env_type = envs.get_type()
         if env_type == EnvType.CONTINUOUS:
-            factory = CriticFactoryContinuousNet(self.hidden_sizes)
-            return factory.create_module(envs, device, use_action)
+            return CriticFactoryContinuousNet(self.hidden_sizes).create_module(
+                envs,
+                device,
+                use_action,
+            )
         elif env_type == EnvType.DISCRETE:
-            factory = CriticFactoryDiscreteNet(self.hidden_sizes)
-            return factory.create_module(envs, device, use_action)
+            return CriticFactoryDiscreteNet(self.hidden_sizes).create_module(
+                envs,
+                device,
+                use_action,
+            )
         else:
             raise ValueError(f"{env_type} not supported")
@@ -23,11 +23,11 @@ class ActorCriticModuleOpt:
     optim: torch.optim.Optimizer

     @property
-    def actor(self):
+    def actor(self) -> torch.nn.Module:
         return self.actor_critic_module.actor

     @property
-    def critic(self):
+    def critic(self) -> torch.nn.Module:
         return self.actor_critic_module.critic
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any
+from typing import Any, Protocol

 import torch
 from torch.optim import Adam, RMSprop

@@ -7,6 +7,11 @@ from torch.optim import Adam, RMSprop
 from tianshou.utils.string import ToStringMixin


+class OptimizerWithLearningRateProtocol(Protocol):
+    def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer:
+        pass
+
+
 class OptimizerFactory(ABC, ToStringMixin):
     # TODO: Is it OK to assume that all optimizers have a learning rate argument?
     # Right now, the learning rate is typically a configuration parameter.

@@ -18,7 +23,7 @@ class OptimizerFactory(ABC, ToStringMixin):


 class OptimizerFactoryTorch(OptimizerFactory):
-    def __init__(self, optim_class: Any, **kwargs):
+    def __init__(self, optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any):
         """:param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`),
         which will be passed the module parameters, the learning rate as `lr` and the
         kwargs provided.

@@ -32,7 +37,12 @@ class OptimizerFactoryTorch(OptimizerFactory):


 class OptimizerFactoryAdam(OptimizerFactory):
-    def __init__(self, betas=(0.9, 0.999), eps=1e-08, weight_decay=0):
+    def __init__(
+        self,
+        betas: tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-08,
+        weight_decay: float = 0,
+    ):
         self.weight_decay = weight_decay
         self.eps = eps
         self.betas = betas

@@ -48,7 +58,14 @@ class OptimizerFactoryAdam(OptimizerFactory):


 class OptimizerFactoryRMSprop(OptimizerFactory):
-    def __init__(self, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False):
+    def __init__(
+        self,
+        alpha: float = 0.99,
+        eps: float = 1e-08,
+        weight_decay: float = 0,
+        momentum: float = 0,
+        centered: bool = False,
+    ):
         self.alpha = alpha
         self.momentum = momentum
         self.centered = centered
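The new `OptimizerWithLearningRateProtocol` replaces a bare `Any` with a structural (duck) type: anything callable with `(parameters, lr=...)` that yields an optimizer matches, with no inheritance required. A runnable sketch of the same idea (the `build` helper is hypothetical):

```python
from typing import Any, Protocol

import torch


class OptimizerConstructor(Protocol):
    # Structural type: torch.optim.Adam, torch.optim.SGD, or any custom
    # callable with a compatible signature satisfies this implicitly.
    def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer:
        ...


def build(optim_class: OptimizerConstructor, params: Any, lr: float) -> torch.optim.Optimizer:
    return optim_class(params, lr=lr)


# Usage: both constructors satisfy the protocol, no subclassing needed.
model = torch.nn.Linear(4, 2)
adam = build(torch.optim.Adam, model.parameters(), lr=1e-3)
sgd = build(torch.optim.SGD, model.parameters(), lr=1e-2)
```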
@@ -30,7 +30,7 @@ class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
         optim_factory: OptimizerFactory,
         device: TDevice,
     ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
-        target_entropy = -np.prod(envs.get_action_shape())
+        target_entropy = float(-np.prod(envs.get_action_shape()))
         log_alpha = torch.zeros(1, requires_grad=True, device=device)
         alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
         return target_entropy, log_alpha, alpha_optim
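Why the `float(...)` wrapper: `np.prod` returns a NumPy scalar (`np.int64`/`np.float64`), not a builtin `float`, so the declared return type `tuple[float, ...]` fails under strict checking without the conversion. A tiny demonstration:

```python
import numpy as np

# np.prod yields a NumPy scalar, which is not a builtin float.
shape = (3, 2)
raw = -np.prod(shape)        # np.int64(-6)
target_entropy = float(raw)  # builtin float, matches the annotation
assert isinstance(target_entropy, float)
```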
@@ -19,21 +19,21 @@ class DistributionFunctionFactory(ToStringMixin, ABC):

 class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
     def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
-        assert envs.get_type().assert_discrete(self)
+        envs.get_type().assert_discrete(self)
         return self._dist_fn

     @staticmethod
-    def _dist_fn(p):
+    def _dist_fn(p: TDistParams) -> torch.distributions.Distribution:
         return torch.distributions.Categorical(logits=p)


 class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
     def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
-        assert envs.get_type().assert_continuous(self)
+        envs.get_type().assert_continuous(self)
         return self._dist_fn

     @staticmethod
-    def _dist_fn(*p):
+    def _dist_fn(*p: torch.Tensor) -> torch.distributions.Distribution:
         return torch.distributions.Independent(torch.distributions.Normal(*p), 1)
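This hunk fixes a genuine bug, not just an annotation: `assert_discrete()` returns `None`, so wrapping it as `assert envs.get_type().assert_discrete(self)` asserted on that `None` and therefore failed even for valid environments. Calling the method directly is correct, since it raises on mismatch itself. Illustrated with a hypothetical checker:

```python
def assert_positive(x: int) -> None:
    # Raises on violation, returns None on success.
    if x <= 0:
        raise AssertionError("x must be positive")


assert_positive(3)  # correct: passes silently

# Wrong pattern (the one removed above): asserts on the None return value,
# so it raises AssertionError even though the check itself passed.
# assert assert_positive(3)
```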
@@ -3,6 +3,7 @@ from dataclasses import asdict, dataclass
 from typing import Any, Literal, Protocol

 import torch
+from torch.optim.lr_scheduler import LRScheduler

 from tianshou.exploration import BaseNoise
 from tianshou.highlevel.env import Environments

@@ -71,7 +72,7 @@ class ParamTransformerChangeValue(ParamTransformer):
     def __init__(self, key: str):
         self.key = key

-    def transform(self, params: dict[str, Any], data: ParamTransformerData):
+    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
         params[self.key] = self.change_value(params[self.key], data)

     @abstractmethod

@@ -118,6 +119,7 @@ class ParamTransformerMultiLRScheduler(ParamTransformer):
             )
             if lr_scheduler_factory is not None:
                 lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
+        lr_scheduler: LRScheduler | MultipleLRSchedulers | None
         match len(lr_schedulers):
             case 0:
                 lr_scheduler = None

@@ -140,6 +142,7 @@ class ParamTransformerActorAndCriticLRScheduler(ParamTransformer):
         self.key_scheduler = key_scheduler

     def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
+        assert data.actor is not None and data.critic1 is not None
         transformer = ParamTransformerMultiLRScheduler(
             [
                 (data.actor.optim, self.key_factory_actor),

@@ -164,6 +167,7 @@ class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer):
         self.key_scheduler = key_scheduler

     def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
+        assert data.actor is not None and data.critic1 is not None and data.critic2 is not None
         transformer = ParamTransformerMultiLRScheduler(
             [
                 (data.actor.optim, self.key_factory_actor),

@@ -247,7 +251,7 @@ class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
     lr: float = 1e-3
     lr_scheduler_factory: LRSchedulerFactory | None = None

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         return [
             ParamTransformerDrop("lr"),
             ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"),

@@ -261,7 +265,7 @@ class ParamsMixinActorAndCritic(GetParamTransformersProtocol):
     actor_lr_scheduler_factory: LRSchedulerFactory | None = None
     critic_lr_scheduler_factory: LRSchedulerFactory | None = None

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         return [
             ParamTransformerDrop("actor_lr", "critic_lr"),
             ParamTransformerActorAndCriticLRScheduler(

@@ -325,7 +329,7 @@ class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol):
     critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
     critic2_lr_scheduler_factory: LRSchedulerFactory | None = None

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         return [
             ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
             ParamTransformerActorDualCriticsLRScheduler(

@@ -348,7 +352,7 @@ class SACParams(Params, ParamsMixinActorAndDualCritics):
     action_scaling: bool = True
     action_bound_method: Literal["clip"] | None = "clip"

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         transformers = super()._get_param_transformers()
         transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
         transformers.append(ParamTransformerAutoAlpha("alpha"))

@@ -365,7 +369,7 @@ class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
     is_double: bool = True
     clip_loss_grad: bool = False

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         transformers = super()._get_param_transformers()
         transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self))
         return transformers

@@ -380,7 +384,7 @@ class DDPGParams(Params, ParamsMixinActorAndCritic):
     action_scaling: bool = True
     action_bound_method: Literal["clip"] | None = "clip"

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         transformers = super()._get_param_transformers()
         transformers.extend(ParamsMixinActorAndCritic._get_param_transformers(self))
         transformers.append(ParamTransformerNoiseFactory("exploration_noise"))

@@ -399,7 +403,7 @@ class TD3Params(Params, ParamsMixinActorAndDualCritics):
     action_scaling: bool = True
     action_bound_method: Literal["clip"] | None = "clip"

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         transformers = super()._get_param_transformers()
         transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
         transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
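The added `assert data.actor is not None ...` lines are the "add checks" part of this commit: they both document an invariant and narrow Optional fields for the type checker, so the attribute accesses that follow no longer flag as possibly-None. A minimal sketch with hypothetical field types:

```python
from dataclasses import dataclass


@dataclass
class Data:  # hypothetical stand-in for ParamTransformerData
    actor: str | None = None
    critic1: str | None = None


def use(data: Data) -> str:
    # Without this assert, mypy reports that data.actor and data.critic1
    # may be None on the line below; with it, both are narrowed to str.
    assert data.actor is not None and data.critic1 is not None
    return data.actor + data.critic1
```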
@@ -36,7 +36,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
         lr: float,
         lr_scale: float,
         reward_scale: float,
-        forward_loss_weight,
+        forward_loss_weight: float,
     ):
         self.feature_net_factory = feature_net_factory
         self.hidden_sizes = hidden_sizes

@@ -54,6 +54,8 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
     ) -> ICMPolicy:
         feature_net = self.feature_net_factory.create_module(envs, device)
         action_dim = envs.get_action_shape()
+        if type(action_dim) != int:
+            raise ValueError(f"Environment action shape must be an integer, got {action_dim}")
         feature_dim = feature_net.output_dim
         icm_net = IntrinsicCuriosityModule(
             feature_net.module,
@@ -18,7 +18,7 @@ class TrainingContext:
         self.logger = logger


-class TrainerEpochCallback(ToStringMixin, ABC):
+class TrainerEpochCallbackTrain(ToStringMixin, ABC):
     """Callback which is called at the beginning of each epoch."""

     @abstractmethod

@@ -26,7 +26,21 @@ class TrainerEpochCallback(ToStringMixin, ABC):
         pass

     def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int], None]:
-        def fn(epoch, env_step):
+        def fn(epoch: int, env_step: int) -> None:
             return self.callback(epoch, env_step, context)

         return fn


+class TrainerEpochCallbackTest(ToStringMixin, ABC):
+    """Callback which is called at the beginning of each epoch."""
+
+    @abstractmethod
+    def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
+        pass
+
+    def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int | None], None]:
+        def fn(epoch: int, env_step: int | None) -> None:
+            return self.callback(epoch, env_step, context)
+
+        return fn

@@ -42,7 +56,7 @@ class TrainerStopCallback(ToStringMixin, ABC):
     """

     def get_trainer_fn(self, context: TrainingContext) -> Callable[[float], bool]:
-        def fn(mean_rewards: float):
+        def fn(mean_rewards: float) -> bool:
             return self.should_stop(mean_rewards, context)

         return fn

@@ -50,6 +64,6 @@ class TrainerStopCallback(ToStringMixin, ABC):

 @dataclass
 class TrainerCallbacks:
-    epoch_callback_train: TrainerEpochCallback | None = None
-    epoch_callback_test: TrainerEpochCallback | None = None
+    epoch_callback_train: TrainerEpochCallbackTrain | None = None
+    epoch_callback_test: TrainerEpochCallbackTest | None = None
     stop_callback: TrainerStopCallback | None = None
@@ -1,6 +1,6 @@
 import warnings
 from collections.abc import Callable
-from typing import Any, Literal, cast
+from typing import Any, Literal, TypeAlias, cast

 import gymnasium as gym
 import numpy as np

@@ -17,7 +17,7 @@ from tianshou.policy import BasePolicy
 from tianshou.policy.base import TLearningRateScheduler
 from tianshou.utils import RunningMeanStd

-TDistParams = torch.Tensor | tuple[torch.Tensor]
+TDistParams: TypeAlias = torch.Tensor | tuple[torch.Tensor, torch.Tensor]


 class PGPolicy(BasePolicy):
@@ -15,7 +15,7 @@ class MultipleLRSchedulers:
         policy = PPOPolicy(..., lr_scheduler=scheduler)
     """

-    def __init__(self, *args: torch.optim.lr_scheduler.LambdaLR):
+    def __init__(self, *args: torch.optim.lr_scheduler.LRScheduler):
         self.schedulers = args

     def step(self) -> None:
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Sequence
-from typing import Any, no_type_check
+from typing import Any, TypeAlias, no_type_check

 import numpy as np
 import torch

@@ -11,6 +11,7 @@ from tianshou.data.types import RecurrentStateBatch

 ModuleType = type[nn.Module]
 ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]]
+TActionShape: TypeAlias = Sequence[int] | int


 def miniblock(
@@ -6,7 +6,7 @@ import numpy as np
 import torch
 from torch import nn

-from tianshou.utils.net.common import MLP, BaseActor
+from tianshou.utils.net.common import MLP, BaseActor, TActionShape

 SIGMA_MIN = -20
 SIGMA_MAX = 2

@@ -40,7 +40,7 @@ class Actor(BaseActor):
     def __init__(
         self,
         preprocess_net: nn.Module,
-        action_shape: Sequence[int],
+        action_shape: TActionShape,
         hidden_sizes: Sequence[int] = (),
         max_action: float = 1.0,
         device: str | int | torch.device = "cpu",

@@ -180,7 +180,7 @@ class ActorProb(BaseActor):
     def __init__(
         self,
         preprocess_net: nn.Module,
-        action_shape: Sequence[int],
+        action_shape: TActionShape,
         hidden_sizes: Sequence[int] = (),
         max_action: float = 1.0,
         device: str | int | torch.device = "cpu",
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from torch import nn

 from tianshou.data import Batch, to_torch
-from tianshou.utils.net.common import MLP, BaseActor
+from tianshou.utils.net.common import MLP, BaseActor, TActionShape


 class Actor(BaseActor):

@@ -39,7 +39,7 @@ class Actor(BaseActor):
     def __init__(
         self,
         preprocess_net: nn.Module,
-        action_shape: Sequence[int],
+        action_shape: TActionShape,
         hidden_sizes: Sequence[int] = (),
         softmax_output: bool = True,
         preprocess_net_output_dim: int | None = None,