Improve type annotations, fix type issues and add checks
commit a161a9cf58 (parent e6716326bd)
@@ -19,7 +19,11 @@ from tianshou.highlevel.params.policy_params import DQNParams
 from tianshou.highlevel.params.policy_wrapper import (
     PolicyWrapperFactoryIntrinsicCuriosity,
 )
-from tianshou.highlevel.trainer import TrainerEpochCallback, TrainingContext
+from tianshou.highlevel.trainer import (
+    TrainerEpochCallbackTest,
+    TrainerEpochCallbackTrain,
+    TrainingContext,
+)
 from tianshou.policy import DQNPolicy
 from tianshou.utils import logging

@@ -75,7 +79,7 @@ def main(
         scale=scale_obs,
     )

-    class TrainEpochCallback(TrainerEpochCallback):
+    class TrainEpochCallback(TrainerEpochCallbackTrain):
         def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
             policy: DQNPolicy = context.policy
             logger = context.logger.logger
@@ -88,7 +92,7 @@ def main(
             if env_step % 1000 == 0:
                 logger.write("train/env_step", env_step, {"train/eps": eps})

-    class TestEpochCallback(TrainerEpochCallback):
+    class TestEpochCallback(TrainerEpochCallbackTest):
         def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
             policy: DQNPolicy = context.policy
             policy.set_eps(eps_test)
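Splitting the single TrainerEpochCallback into TrainerEpochCallbackTrain and TrainerEpochCallbackTest lets each variant carry its own signature (the test callback may receive env_step=None). A minimal sketch of subclassing the new classes, mirroring the example above; the epsilon constants are illustrative, not taken from this commit:

from tianshou.highlevel.trainer import (
    TrainerEpochCallbackTest,
    TrainerEpochCallbackTrain,
    TrainingContext,
)
from tianshou.policy import DQNPolicy


class TrainEpochCallback(TrainerEpochCallbackTrain):
    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        policy: DQNPolicy = context.policy  # annotation mirrors the example above
        policy.set_eps(0.1)  # illustrative constant; the example derives eps from env_step


class TestEpochCallback(TrainerEpochCallbackTest):
    def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
        policy: DQNPolicy = context.policy
        policy.set_eps(0.01)  # illustrative eps_test value

Per the experiment-builder hunks below, such callbacks are registered via with_trainer_epoch_callback_train(...) and with_trainer_epoch_callback_test(...).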
@@ -13,7 +13,9 @@ from tianshou.highlevel.experiment import (
     ExperimentConfig,
     PPOExperimentBuilder,
 )
-from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryIndependentGaussians
+from tianshou.highlevel.params.dist_fn import (
+    DistributionFunctionFactoryIndependentGaussians,
+)
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
 from tianshou.highlevel.params.policy_params import PPOParams
 from tianshou.utils import logging
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Callable
-from typing import Generic, TypeVar
+from os import PathLike
+from typing import Any, Generic, TypeVar, cast

 import gymnasium
 import torch
@@ -60,9 +61,14 @@ class AgentFactory(ABC, ToStringMixin):
         self.policy_wrapper_factory: PolicyWrapperFactory | None = None
         self.trainer_callbacks: TrainerCallbacks = TrainerCallbacks()

-    def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
+    def create_train_test_collector(
+        self,
+        policy: BasePolicy,
+        envs: Environments,
+    ) -> tuple[Collector, Collector]:
         buffer_size = self.sampling_config.buffer_size
         train_envs = envs.train_envs
+        buffer: ReplayBuffer
         if len(train_envs) > 1:
             buffer = VectorReplayBuffer(
                 buffer_size,
@@ -90,7 +96,7 @@ class AgentFactory(ABC, ToStringMixin):
     ) -> None:
         self.policy_wrapper_factory = policy_wrapper_factory

-    def set_trainer_callbacks(self, callbacks: TrainerCallbacks):
+    def set_trainer_callbacks(self, callbacks: TrainerCallbacks) -> None:
         self.trainer_callbacks = callbacks

     @abstractmethod
@@ -122,7 +128,12 @@ class AgentFactory(ABC, ToStringMixin):
         return save_best_fn

     @staticmethod
-    def load_checkpoint(policy: torch.nn.Module, path, envs: Environments, device: TDevice):
+    def load_checkpoint(
+        policy: torch.nn.Module,
+        path: str | PathLike,
+        envs: Environments,
+        device: TDevice,
+    ) -> None:
         ckpt = torch.load(path, map_location=device)
         policy.load_state_dict(ckpt[CHECKPOINT_DICT_KEY_MODEL])
         if envs.train_envs:
@@ -280,6 +291,7 @@ class _ActorCriticMixin:
         lr: float,
     ) -> ActorCriticModuleOpt:
         actor = self.actor_factory.create_module(envs, device)
+        critic: torch.nn.Module
         if self.critic_use_actor_module:
             if self.critic_use_action:
                 raise ValueError(
@@ -375,7 +387,7 @@ class ActorCriticAgentFactory(
     def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
         pass

-    def _create_kwargs(self, envs: Environments, device: TDevice):
+    def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]:
         actor_critic = self._create_actor_critic(envs, device)
         kwargs = self.params.create_kwargs(
             ParamTransformerData(
@@ -468,8 +480,7 @@ class DQNAgentFactory(OffpolicyAgentFactory):
             ),
         )
         envs.get_type().assert_discrete(self)
-        # noinspection PyTypeChecker
-        action_space: gymnasium.spaces.Discrete = envs.get_action_space()
+        action_space = cast(gymnasium.spaces.Discrete, envs.get_action_space())
         return DQNPolicy(
             model=model,
             optim=optim,
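The cast in DQNAgentFactory replaces a comment-silenced misassignment: get_action_space() is declared to return the base gym.Space, and cast() narrows the static type without any runtime effect. A standalone sketch of the pattern (hypothetical environment, not from the commit):

from typing import cast

import gymnasium as gym

env = gym.make("CartPole-v1")
# env.action_space is typed as the base gym.Space; cast() narrows it for the
# checker only -- it performs no runtime conversion, so it should follow a
# runtime guarantee such as the assert_discrete() call above.
action_space = cast(gym.spaces.Discrete, env.action_space)
print(action_space.n)  # .n exists only on Discrete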
@@ -1,37 +1,38 @@
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from enum import Enum
-from typing import Any
+from typing import Any, TypeAlias

 import gymnasium as gym

 from tianshou.env import BaseVectorEnv
 from tianshou.highlevel.persistence import PersistableConfigProtocol
+from tianshou.utils.net.common import TActionShape

-TShape = int | Sequence[int]
+TObservationShape: TypeAlias = int | Sequence[int]


 class EnvType(Enum):
     CONTINUOUS = "continuous"
     DISCRETE = "discrete"

-    def is_discrete(self):
+    def is_discrete(self) -> bool:
         return self == EnvType.DISCRETE

-    def is_continuous(self):
+    def is_continuous(self) -> bool:
         return self == EnvType.CONTINUOUS

-    def assert_continuous(self, requiring_entity: Any):
+    def assert_continuous(self, requiring_entity: Any) -> None:
         if not self.is_continuous():
             raise AssertionError(f"{requiring_entity} requires continuous environments")

-    def assert_discrete(self, requiring_entity: Any):
+    def assert_discrete(self, requiring_entity: Any) -> None:
         if not self.is_discrete():
             raise AssertionError(f"{requiring_entity} requires discrete environments")


 class Environments(ABC):
-    def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
+    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
         self.env = env
         self.train_envs = train_envs
         self.test_envs = test_envs
@@ -43,11 +44,11 @@ class Environments(ABC):
         }

     @abstractmethod
-    def get_action_shape(self) -> TShape:
+    def get_action_shape(self) -> TActionShape:
         pass

     @abstractmethod
-    def get_observation_shape(self) -> TShape:
+    def get_observation_shape(self) -> TObservationShape:
         pass

     def get_action_space(self) -> gym.Space:
@@ -62,11 +63,11 @@ class Environments(ABC):


 class ContinuousEnvironments(Environments):
-    def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
+    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
         super().__init__(env, train_envs, test_envs)
         self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)

-    def info(self):
+    def info(self) -> dict[str, Any]:
         d = super().info()
         d["max_action"] = self.max_action
         return d
@@ -80,17 +81,17 @@ class ContinuousEnvironments(Environments):
                 "Only environments with continuous action space are supported here. "
                 f"But got env with action space: {env.action_space.__class__}.",
             )
-        state_shape = env.observation_space.shape or env.observation_space.n
+        state_shape = env.observation_space.shape or env.observation_space.n  # type: ignore
         if not state_shape:
             raise ValueError("Observation space shape is not defined")
         action_shape = env.action_space.shape
         max_action = env.action_space.high[0]
         return state_shape, action_shape, max_action

-    def get_action_shape(self) -> TShape:
+    def get_action_shape(self) -> TActionShape:
         return self.action_shape

-    def get_observation_shape(self) -> TShape:
+    def get_observation_shape(self) -> TObservationShape:
         return self.state_shape

     def get_type(self) -> EnvType:
@@ -98,15 +99,15 @@ class ContinuousEnvironments(Environments):


 class DiscreteEnvironments(Environments):
-    def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
+    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
         super().__init__(env, train_envs, test_envs)
-        self.observation_shape = env.observation_space.shape or env.observation_space.n
-        self.action_shape = env.action_space.shape or env.action_space.n
+        self.observation_shape = env.observation_space.shape or env.observation_space.n  # type: ignore
+        self.action_shape = env.action_space.shape or env.action_space.n  # type: ignore

-    def get_action_shape(self) -> TShape:
+    def get_action_shape(self) -> TActionShape:
         return self.action_shape

-    def get_observation_shape(self) -> TShape:
+    def get_observation_shape(self) -> TObservationShape:
         return self.observation_shape

     def get_type(self) -> EnvType:
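Splitting the old TShape alias into TActionShape (shared via tianshou.utils.net.common) and TObservationShape gives each union a single point of definition. A sketch of the TypeAlias idiom, independent of tianshou:

from collections.abc import Sequence
from typing import TypeAlias

# An explicit TypeAlias marks this as a type definition rather than an
# ordinary assignment, so checkers treat it correctly at module scope.
TObservationShape: TypeAlias = int | Sequence[int]


def flat_dim(shape: TObservationShape) -> int:
    # Each branch of the union must be handled before use.
    if isinstance(shape, int):
        return shape
    result = 1
    for d in shape:
        result *= d
    return result


print(flat_dim(4), flat_dim((3, 84, 84)))  # 4 21168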
@@ -40,7 +40,8 @@ from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
 from tianshou.highlevel.persistence import PersistableConfigProtocol
 from tianshou.highlevel.trainer import (
     TrainerCallbacks,
-    TrainerEpochCallback,
+    TrainerEpochCallbackTest,
+    TrainerEpochCallbackTrain,
     TrainerStopCallback,
 )
 from tianshou.policy import BasePolicy
@@ -135,24 +136,29 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
         result = trainer.run()
         pprint(result)  # TODO logging

+        render = self.config.render
+        if render is None:
+            render = 0.0  # TODO: Perhaps we should have a second render parameter for watch mode?
         self._watch_agent(
             self.config.watch_num_episodes,
             policy,
             test_collector,
-            self.config.render,
+            render,
         )

     @staticmethod
-    def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render) -> None:
+    def _watch_agent(
+        num_episodes: int,
+        policy: BasePolicy,
+        test_collector: Collector,
+        render: float,
+    ) -> None:
         policy.eval()
         test_collector.reset()
         result = test_collector.collect(n_episode=num_episodes, render=render)
         print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')


-TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder")


 class ExperimentBuilder:
     def __init__(
         self,
@@ -177,7 +183,7 @@ class ExperimentBuilder:
         self._env_config = config
         return self

-    def with_logger_factory(self: TBuilder, logger_factory: LoggerFactory) -> TBuilder:
+    def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
         self._logger_factory = logger_factory
         return self

@@ -185,16 +191,16 @@ class ExperimentBuilder:
         self._policy_wrapper_factory = policy_wrapper_factory
         return self

-    def with_optim_factory(self: TBuilder, optim_factory: OptimizerFactory) -> TBuilder:
+    def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self:
         self._optim_factory = optim_factory
         return self

     def with_optim_factory_default(
-        self: TBuilder,
-        betas=(0.9, 0.999),
-        eps=1e-08,
-        weight_decay=0,
-    ) -> TBuilder:
+        self,
+        betas: tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-08,
+        weight_decay: float = 0,
+    ) -> Self:
         """Configures the use of the default optimizer, Adam, with the given parameters.

         :param betas: coefficients used for computing running averages of gradient and its square
@@ -205,11 +211,11 @@ class ExperimentBuilder:
         self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay)
         return self

-    def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallback) -> Self:
+    def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallbackTrain) -> Self:
         self._trainer_callbacks.epoch_callback_train = callback
         return self

-    def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallback) -> Self:
+    def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallbackTest) -> Self:
         self._trainer_callbacks.epoch_callback_test = callback
         return self

@@ -232,7 +238,7 @@ class ExperimentBuilder:
         agent_factory.set_trainer_callbacks(self._trainer_callbacks)
         if self._policy_wrapper_factory:
             agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
-        experiment = Experiment(
+        experiment: Experiment = Experiment(
             self._config,
             self._env_factory,
             agent_factory,
@@ -248,18 +254,16 @@ class _BuilderMixinActorFactory:
         self._continuous_actor_type = continuous_actor_type
         self._actor_factory: ActorFactory | None = None

-    def with_actor_factory(self: TBuilder, actor_factory: ActorFactory) -> TBuilder:
-        self: TBuilder | _BuilderMixinActorFactory
+    def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
         self._actor_factory = actor_factory
         return self

     def _with_actor_factory_default(
-        self: TBuilder,
+        self,
         hidden_sizes: Sequence[int],
-        continuous_unbounded=False,
-        continuous_conditioned_sigma=False,
-    ) -> TBuilder:
-        self: TBuilder | _BuilderMixinActorFactory
+        continuous_unbounded: bool = False,
+        continuous_conditioned_sigma: bool = False,
+    ) -> Self:
         self._actor_factory = ActorFactoryDefault(
             self._continuous_actor_type,
             hidden_sizes,
@@ -268,7 +272,7 @@ class _BuilderMixinActorFactory:
         )
         return self

-    def _get_actor_factory(self):
+    def _get_actor_factory(self) -> ActorFactory:
         if self._actor_factory is None:
             return ActorFactoryDefault(self._continuous_actor_type)
         else:
@@ -278,14 +282,14 @@ class _BuilderMixinActorFactory:
 class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
     """Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(ContinuousActorType.GAUSSIAN)

     def with_actor_factory_default(
         self,
         hidden_sizes: Sequence[int],
-        continuous_unbounded=False,
-        continuous_conditioned_sigma=False,
+        continuous_unbounded: bool = False,
+        continuous_conditioned_sigma: bool = False,
     ) -> Self:
         return super()._with_actor_factory_default(
             hidden_sizes,
@@ -297,7 +301,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
 class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactory):
     """Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(ContinuousActorType.DETERMINISTIC)

     def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self:
@@ -308,15 +312,15 @@ class _BuilderMixinCriticsFactory:
     def __init__(self, num_critics: int):
         self._critic_factories: list[CriticFactory | None] = [None] * num_critics

-    def _with_critic_factory(self, idx: int, critic_factory: CriticFactory):
+    def _with_critic_factory(self, idx: int, critic_factory: CriticFactory) -> Self:
         self._critic_factories[idx] = critic_factory
         return self

-    def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]):
+    def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]) -> Self:
         self._critic_factories[idx] = CriticFactoryDefault(hidden_sizes)
         return self

-    def _get_critic_factory(self, idx: int):
+    def _get_critic_factory(self, idx: int) -> CriticFactory:
         factory = self._critic_factories[idx]
         if factory is None:
             return CriticFactoryDefault()
@@ -325,7 +329,7 @@ class _BuilderMixinCriticsFactory:


 class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(1)

     def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
@@ -341,7 +345,7 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):


 class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self._critic_use_actor_module = False

@@ -352,11 +356,10 @@ class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFacto


 class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(2)

-    def with_common_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
-        self: TBuilder | "_BuilderMixinDualCriticFactory"
+    def with_common_critic_factory(self, critic_factory: CriticFactory) -> Self:
         for i in range(len(self._critic_factories)):
             self._with_critic_factory(i, critic_factory)
         return self
@@ -364,35 +367,30 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
     def with_common_critic_factory_default(
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
-    ) -> TBuilder:
-        self: TBuilder | "_BuilderMixinDualCriticFactory"
+    ) -> Self:
         for i in range(len(self._critic_factories)):
             self._with_critic_factory_default(i, hidden_sizes)
         return self

-    def with_critic1_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
-        self: TBuilder | "_BuilderMixinDualCriticFactory"
+    def with_critic1_factory(self, critic_factory: CriticFactory) -> Self:
         self._with_critic_factory(0, critic_factory)
         return self

     def with_critic1_factory_default(
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
-    ) -> TBuilder:
-        self: TBuilder | "_BuilderMixinDualCriticFactory"
+    ) -> Self:
         self._with_critic_factory_default(0, hidden_sizes)
         return self

-    def with_critic2_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
-        self: TBuilder | "_BuilderMixinDualCriticFactory"
+    def with_critic2_factory(self, critic_factory: CriticFactory) -> Self:
         self._with_critic_factory(1, critic_factory)
         return self

     def with_critic2_factory_default(
         self,
         hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
-    ) -> TBuilder:
-        self: TBuilder | "_BuilderMixinDualCriticFactory"
+    ) -> Self:
         self._with_critic_factory_default(0, hidden_sizes)
         return self
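Replacing the TBuilder TypeVar plus the `self: TBuilder | ...` re-declarations with typing.Self (PEP 673, Python 3.11+; typing_extensions.Self on older versions) keeps fluent chaining precisely typed: each with_* method is inferred to return the subclass it was called on. A minimal sketch, independent of tianshou:

from typing import Self


class Builder:
    def __init__(self) -> None:
        self.parts: list[str] = []

    def with_part(self, part: str) -> Self:
        # Self binds to the runtime class, so chaining from a subclass
        # keeps the subclass type without any TypeVar bookkeeping.
        self.parts.append(part)
        return self


class FancyBuilder(Builder):
    def with_sparkles(self) -> Self:
        self.parts.append("sparkles")
        return self


# Type checks: with_part() on a FancyBuilder returns FancyBuilder, so
# with_sparkles() is still available after the call.
b = FancyBuilder().with_part("base").with_sparkles()
print(b.parts)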
@@ -19,7 +19,7 @@ class Logger:

 class LoggerFactory(ToStringMixin, ABC):
     @abstractmethod
-    def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
+    def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
         pass


@@ -48,6 +48,7 @@ class DefaultLoggerFactory(LoggerFactory):
                 ),
             ),
         )
+        logger: TLogger
         if self.logger_type == "wandb":
             logger = WandbLogger(
                 save_interval=1,
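The bare `logger: TLogger` line (like `buffer: ReplayBuffer` and `critic: torch.nn.Module` earlier in the commit) is an annotation-only declaration: it gives the variable one declared type up front, so branch-dependent assignments to different concrete classes all check against the same supertype. A small sketch of the pattern:

class Animal: ...
class Dog(Animal): ...
class Cat(Animal): ...


def pick(kind: str) -> Animal:
    pet: Animal  # declaration only; no value is assigned yet
    if kind == "dog":
        pet = Dog()  # each branch may assign a different subclass...
    else:
        pet = Cat()
    return pet  # ...and the checker sees one consistent type: Animal


print(type(pick("dog")).__name__)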
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
+from enum import Enum

 import torch
 from torch import nn
@@ -11,7 +12,7 @@ from tianshou.utils.net.common import BaseActor, Net
 from tianshou.utils.string import ToStringMixin


-class ContinuousActorType:
+class ContinuousActorType(Enum):
     GAUSSIAN = "gaussian"
     DETERMINISTIC = "deterministic"
     UNSUPPORTED = "unsupported"
@@ -23,7 +24,7 @@ class ActorFactory(ToStringMixin, ABC):
         pass

     @staticmethod
-    def _init_linear(actor: torch.nn.Module):
+    def _init_linear(actor: torch.nn.Module) -> None:
         """Initializes linear layers of an actor module using default mechanisms.

         :param module: the actor module.
@@ -34,7 +35,7 @@ class ActorFactory(ToStringMixin, ABC):
         # do last policy layer scaling, this will make initial actions have (close to)
         # 0 mean and std, and will help boost performances,
         # see https://arxiv.org/abs/2006.05990, Fig.24 for details
-        for m in actor.mu.modules():
+        for m in actor.mu.modules():  # type: ignore
             if isinstance(m, torch.nn.Linear):
                 m.weight.data.copy_(0.01 * m.weight.data)

@@ -48,8 +49,8 @@ class ActorFactoryDefault(ActorFactory):
         self,
         continuous_actor_type: ContinuousActorType,
         hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
-        continuous_unbounded=False,
-        continuous_conditioned_sigma=False,
+        continuous_unbounded: bool = False,
+        continuous_conditioned_sigma: bool = False,
     ):
         self.continuous_actor_type = continuous_actor_type
         self.continuous_unbounded = continuous_unbounded
@@ -58,6 +59,7 @@ class ActorFactoryDefault(ActorFactory):

     def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
         env_type = envs.get_type()
+        factory: ActorFactoryContinuousDeterministicNet | ActorFactoryContinuousGaussianNet | ActorFactoryDiscreteNet
         if env_type == EnvType.CONTINUOUS:
             match self.continuous_actor_type:
                 case ContinuousActorType.GAUSSIAN:
@@ -103,7 +105,12 @@ class ActorFactoryContinuousDeterministicNet(ActorFactoryContinuous):


 class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
-    def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
+    def __init__(
+        self,
+        hidden_sizes: Sequence[int],
+        unbounded: bool = True,
+        conditioned_sigma: bool = False,
+    ):
         self.hidden_sizes = hidden_sizes
         self.unbounded = unbounded
         self.conditioned_sigma = conditioned_sigma
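Turning ContinuousActorType from a plain class holding string constants into an Enum makes the values a closed, type-checkable set: a parameter annotated with the enum rejects raw strings, and members are singletons that compare by identity. A small sketch of the difference:

from enum import Enum


class ContinuousActorType(Enum):
    GAUSSIAN = "gaussian"
    DETERMINISTIC = "deterministic"


def describe(actor_type: ContinuousActorType) -> str:
    # Identity comparison is safe because enum members are singletons.
    if actor_type is ContinuousActorType.GAUSSIAN:
        return "stochastic actor"
    return "deterministic actor"


print(describe(ContinuousActorType.GAUSSIAN))
# describe("gaussian") would now be rejected by the type checker.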
@@ -10,10 +10,10 @@ from tianshou.highlevel.env import Environments
 from tianshou.utils.net.common import Net
 from tianshou.utils.string import ToStringMixin

-TDevice: TypeAlias = str | int | torch.device
+TDevice: TypeAlias = str | torch.device


-def init_linear_orthogonal(module: torch.nn.Module):
+def init_linear_orthogonal(module: torch.nn.Module) -> None:
     """Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.

     :param module: the module whose submodules are to be processed
@@ -27,11 +27,17 @@ class CriticFactoryDefault(CriticFactory):
     def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
         env_type = envs.get_type()
         if env_type == EnvType.CONTINUOUS:
-            factory = CriticFactoryContinuousNet(self.hidden_sizes)
-            return factory.create_module(envs, device, use_action)
+            return CriticFactoryContinuousNet(self.hidden_sizes).create_module(
+                envs,
+                device,
+                use_action,
+            )
         elif env_type == EnvType.DISCRETE:
-            factory = CriticFactoryDiscreteNet(self.hidden_sizes)
-            return factory.create_module(envs, device, use_action)
+            return CriticFactoryDiscreteNet(self.hidden_sizes).create_module(
+                envs,
+                device,
+                use_action,
+            )
         else:
             raise ValueError(f"{env_type} not supported")
@@ -23,11 +23,11 @@ class ActorCriticModuleOpt:
     optim: torch.optim.Optimizer

     @property
-    def actor(self):
+    def actor(self) -> torch.nn.Module:
         return self.actor_critic_module.actor

     @property
-    def critic(self):
+    def critic(self) -> torch.nn.Module:
         return self.actor_critic_module.critic
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any
+from typing import Any, Protocol

 import torch
 from torch.optim import Adam, RMSprop
@@ -7,6 +7,11 @@ from torch.optim import Adam, RMSprop
 from tianshou.utils.string import ToStringMixin


+class OptimizerWithLearningRateProtocol(Protocol):
+    def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer:
+        pass
+
+
 class OptimizerFactory(ABC, ToStringMixin):
     # TODO: Is it OK to assume that all optimizers have a learning rate argument?
     # Right now, the learning rate is typically a configuration parameter.
@@ -18,7 +23,7 @@ class OptimizerFactory(ABC, ToStringMixin):


 class OptimizerFactoryTorch(OptimizerFactory):
-    def __init__(self, optim_class: Any, **kwargs):
+    def __init__(self, optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any):
         """:param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`),
             which will be passed the module parameters, the learning rate as `lr` and the
             kwargs provided.
@@ -32,7 +37,12 @@ class OptimizerFactoryTorch(OptimizerFactory):


 class OptimizerFactoryAdam(OptimizerFactory):
-    def __init__(self, betas=(0.9, 0.999), eps=1e-08, weight_decay=0):
+    def __init__(
+        self,
+        betas: tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-08,
+        weight_decay: float = 0,
+    ):
         self.weight_decay = weight_decay
         self.eps = eps
         self.betas = betas
@@ -48,7 +58,14 @@ class OptimizerFactoryAdam(OptimizerFactory):


 class OptimizerFactoryRMSprop(OptimizerFactory):
-    def __init__(self, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False):
+    def __init__(
+        self,
+        alpha: float = 0.99,
+        eps: float = 1e-08,
+        weight_decay: float = 0,
+        momentum: float = 0,
+        centered: bool = False,
+    ):
         self.alpha = alpha
         self.momentum = momentum
         self.centered = centered
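OptimizerWithLearningRateProtocol replaces Any with structural typing: any callable that accepts parameters plus an lr keyword and returns an optimizer matches, including the torch.optim classes themselves, with no inheritance required. A sketch of the idea under those assumptions (mypy's matching of constructor signatures against callable protocols has edge cases around parameter names):

from typing import Any, Protocol

import torch
from torch.optim import SGD


class OptimizerWithLearningRateProtocol(Protocol):
    def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer:
        ...


def make_optim(
    optim_class: OptimizerWithLearningRateProtocol,
    params: Any,
    lr: float,
) -> torch.optim.Optimizer:
    return optim_class(params, lr=lr)


# SGD's constructor takes (params, lr, ...), so the class itself satisfies
# the protocol structurally -- no registration or subclassing needed.
model = torch.nn.Linear(4, 2)
opt = make_optim(SGD, model.parameters(), lr=0.01)
print(type(opt).__name__)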
@@ -30,7 +30,7 @@ class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
         optim_factory: OptimizerFactory,
         device: TDevice,
     ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
-        target_entropy = -np.prod(envs.get_action_shape())
+        target_entropy = float(-np.prod(envs.get_action_shape()))
         log_alpha = torch.zeros(1, requires_grad=True, device=device)
         alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
         return target_entropy, log_alpha, alpha_optim
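np.prod returns a NumPy scalar, not a Python float, so the declared tuple[float, ...] return type was not actually honored; wrapping the expression in float() makes the value match the annotation. A two-line illustration:

import numpy as np

raw = -np.prod((2, 3))  # a NumPy scalar (np.int64), not a Python number
fixed = float(-np.prod((2, 3)))  # -6.0: a genuine Python float, matching the annotation
print(type(raw).__name__, type(fixed).__name__)  # int64 float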
@@ -19,21 +19,21 @@ class DistributionFunctionFactory(ToStringMixin, ABC):

 class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
     def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
-        assert envs.get_type().assert_discrete(self)
+        envs.get_type().assert_discrete(self)
         return self._dist_fn

     @staticmethod
-    def _dist_fn(p):
+    def _dist_fn(p: TDistParams) -> torch.distributions.Distribution:
         return torch.distributions.Categorical(logits=p)


 class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
     def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
-        assert envs.get_type().assert_continuous(self)
+        envs.get_type().assert_continuous(self)
         return self._dist_fn

     @staticmethod
-    def _dist_fn(*p):
+    def _dist_fn(*p: torch.Tensor) -> torch.distributions.Distribution:
         return torch.distributions.Independent(torch.distributions.Normal(*p), 1)
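Dropping the assert here fixes a real bug, not just a style issue: assert_discrete() returns None, and `assert None` fails unconditionally, so the factory raised even for valid discrete environments. A minimal reproduction of the pitfall:

def check_positive(x: int) -> None:
    """Raises on failure, returns None on success."""
    if x <= 0:
        raise AssertionError("x must be positive")


# Wrong: the check passes, but its None return value makes the assert fail.
try:
    assert check_positive(5)
except AssertionError:
    print("assert None always fails")

# Right: call the checking function for its side effect only.
check_positive(5)
print("ok")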
@@ -3,6 +3,7 @@ from dataclasses import asdict, dataclass
 from typing import Any, Literal, Protocol

 import torch
+from torch.optim.lr_scheduler import LRScheduler

 from tianshou.exploration import BaseNoise
 from tianshou.highlevel.env import Environments
@@ -71,7 +72,7 @@ class ParamTransformerChangeValue(ParamTransformer):
     def __init__(self, key: str):
         self.key = key

-    def transform(self, params: dict[str, Any], data: ParamTransformerData):
+    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
         params[self.key] = self.change_value(params[self.key], data)

     @abstractmethod
@@ -118,6 +119,7 @@ class ParamTransformerMultiLRScheduler(ParamTransformer):
             )
             if lr_scheduler_factory is not None:
                 lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
+        lr_scheduler: LRScheduler | MultipleLRSchedulers | None
         match len(lr_schedulers):
             case 0:
                 lr_scheduler = None
@@ -140,6 +142,7 @@ class ParamTransformerActorAndCriticLRScheduler(ParamTransformer):
         self.key_scheduler = key_scheduler

     def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
+        assert data.actor is not None and data.critic1 is not None
         transformer = ParamTransformerMultiLRScheduler(
             [
                 (data.actor.optim, self.key_factory_actor),
@@ -164,6 +167,7 @@ class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer):
         self.key_scheduler = key_scheduler

     def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
+        assert data.actor is not None and data.critic1 is not None and data.critic2 is not None
         transformer = ParamTransformerMultiLRScheduler(
             [
                 (data.actor.optim, self.key_factory_actor),
@@ -247,7 +251,7 @@ class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
     lr: float = 1e-3
     lr_scheduler_factory: LRSchedulerFactory | None = None

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         return [
             ParamTransformerDrop("lr"),
             ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"),
@@ -261,7 +265,7 @@ class ParamsMixinActorAndCritic(GetParamTransformersProtocol):
     actor_lr_scheduler_factory: LRSchedulerFactory | None = None
     critic_lr_scheduler_factory: LRSchedulerFactory | None = None

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         return [
             ParamTransformerDrop("actor_lr", "critic_lr"),
             ParamTransformerActorAndCriticLRScheduler(
@@ -325,7 +329,7 @@ class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol):
     critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
     critic2_lr_scheduler_factory: LRSchedulerFactory | None = None

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         return [
             ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
             ParamTransformerActorDualCriticsLRScheduler(
@@ -348,7 +352,7 @@ class SACParams(Params, ParamsMixinActorAndDualCritics):
     action_scaling: bool = True
     action_bound_method: Literal["clip"] | None = "clip"

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         transformers = super()._get_param_transformers()
         transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
         transformers.append(ParamTransformerAutoAlpha("alpha"))
@@ -365,7 +369,7 @@ class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
     is_double: bool = True
     clip_loss_grad: bool = False

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         transformers = super()._get_param_transformers()
         transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self))
         return transformers
@@ -380,7 +384,7 @@ class DDPGParams(Params, ParamsMixinActorAndCritic):
     action_scaling: bool = True
     action_bound_method: Literal["clip"] | None = "clip"

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         transformers = super()._get_param_transformers()
         transformers.extend(ParamsMixinActorAndCritic._get_param_transformers(self))
         transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
@@ -399,7 +403,7 @@ class TD3Params(Params, ParamsMixinActorAndDualCritics):
     action_scaling: bool = True
     action_bound_method: Literal["clip"] | None = "clip"

-    def _get_param_transformers(self):
+    def _get_param_transformers(self) -> list[ParamTransformer]:
         transformers = super()._get_param_transformers()
         transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
         transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
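The added `assert data.actor is not None ...` lines are the standard mypy idiom for narrowing `X | None` attributes before use: after the assert, the checker treats the value as non-optional, and at runtime a violated invariant fails loudly instead of as an opaque AttributeError. A compact sketch:

from dataclasses import dataclass


@dataclass
class Data:
    actor: str | None = None


def transform(data: Data) -> str:
    # Narrow Optional -> str for the type checker; fail fast at runtime
    # if the caller violated the invariant.
    assert data.actor is not None
    return data.actor.upper()  # no "item of Optional[str]" error here


print(transform(Data(actor="pi")))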
@@ -36,7 +36,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
         lr: float,
         lr_scale: float,
         reward_scale: float,
-        forward_loss_weight,
+        forward_loss_weight: float,
     ):
         self.feature_net_factory = feature_net_factory
         self.hidden_sizes = hidden_sizes
@@ -54,6 +54,8 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
     ) -> ICMPolicy:
         feature_net = self.feature_net_factory.create_module(envs, device)
         action_dim = envs.get_action_shape()
+        if type(action_dim) != int:
+            raise ValueError(f"Environment action shape must be an integer, got {action_dim}")
         feature_dim = feature_net.output_dim
         icm_net = IntrinsicCuriosityModule(
             feature_net.module,
@@ -18,7 +18,7 @@ class TrainingContext:
         self.logger = logger


-class TrainerEpochCallback(ToStringMixin, ABC):
+class TrainerEpochCallbackTrain(ToStringMixin, ABC):
     """Callback which is called at the beginning of each epoch."""

     @abstractmethod
@@ -26,7 +26,21 @@ class TrainerEpochCallback(ToStringMixin, ABC):
         pass

     def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int], None]:
-        def fn(epoch, env_step):
+        def fn(epoch: int, env_step: int) -> None:
+            return self.callback(epoch, env_step, context)
+
+        return fn
+
+
+class TrainerEpochCallbackTest(ToStringMixin, ABC):
+    """Callback which is called at the beginning of each epoch."""
+
+    @abstractmethod
+    def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
+        pass
+
+    def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int | None], None]:
+        def fn(epoch: int, env_step: int | None) -> None:
             return self.callback(epoch, env_step, context)

         return fn
@@ -42,7 +56,7 @@ class TrainerStopCallback(ToStringMixin, ABC):
     """

     def get_trainer_fn(self, context: TrainingContext) -> Callable[[float], bool]:
-        def fn(mean_rewards: float):
+        def fn(mean_rewards: float) -> bool:
             return self.should_stop(mean_rewards, context)

         return fn
@@ -50,6 +64,6 @@ class TrainerStopCallback(ToStringMixin, ABC):

 @dataclass
 class TrainerCallbacks:
-    epoch_callback_train: TrainerEpochCallback | None = None
-    epoch_callback_test: TrainerEpochCallback | None = None
+    epoch_callback_train: TrainerEpochCallbackTrain | None = None
+    epoch_callback_test: TrainerEpochCallbackTest | None = None
     stop_callback: TrainerStopCallback | None = None
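get_trainer_fn is a small adapter: it closes over the TrainingContext so the trainer keeps its plain (epoch, env_step) callback signature while user callbacks still receive the context. A sketch of the pattern in isolation (class names hypothetical):

from collections.abc import Callable


class Context:
    def __init__(self, name: str) -> None:
        self.name = name


class EpochCallback:
    def callback(self, epoch: int, env_step: int, context: Context) -> None:
        print(f"[{context.name}] epoch={epoch} step={env_step}")

    def get_trainer_fn(self, context: Context) -> Callable[[int, int], None]:
        # The returned closure captures `context`, adapting the rich
        # callback interface to the trainer's two-argument signature.
        def fn(epoch: int, env_step: int) -> None:
            self.callback(epoch, env_step, context)

        return fn


trainer_hook = EpochCallback().get_trainer_fn(Context("run-1"))
trainer_hook(1, 1000)  # what the trainer would invoke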
@@ -1,6 +1,6 @@
 import warnings
 from collections.abc import Callable
-from typing import Any, Literal, cast
+from typing import Any, Literal, TypeAlias, cast

 import gymnasium as gym
 import numpy as np
@@ -17,7 +17,7 @@ from tianshou.policy import BasePolicy
 from tianshou.policy.base import TLearningRateScheduler
 from tianshou.utils import RunningMeanStd

-TDistParams = torch.Tensor | tuple[torch.Tensor]
+TDistParams: TypeAlias = torch.Tensor | tuple[torch.Tensor, torch.Tensor]


 class PGPolicy(BasePolicy):
@@ -15,7 +15,7 @@ class MultipleLRSchedulers:
         policy = PPOPolicy(..., lr_scheduler=scheduler)
     """

-    def __init__(self, *args: torch.optim.lr_scheduler.LambdaLR):
+    def __init__(self, *args: torch.optim.lr_scheduler.LRScheduler):
         self.schedulers = args

     def step(self) -> None:
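Widening the parameter from the concrete LambdaLR to the LRScheduler base class (the public base in PyTorch 2.0+) lets any scheduler combination be wrapped, not just lambda-based ones. A brief sketch, assuming standard PyTorch APIs:

import torch
from torch.optim.lr_scheduler import ExponentialLR, LinearLR, LRScheduler


class MultipleLRSchedulers:
    """Steps several schedulers as one; accepts any LRScheduler subclass."""

    def __init__(self, *args: LRScheduler) -> None:
        self.schedulers = args

    def step(self) -> None:
        for scheduler in self.schedulers:
            scheduler.step()


params = [torch.nn.Parameter(torch.zeros(1))]
opt1 = torch.optim.Adam(params, lr=1e-3)
opt2 = torch.optim.Adam(params, lr=1e-3)
# Two unrelated scheduler classes -- both are LRScheduler subclasses,
# so both are accepted under the widened annotation.
combined = MultipleLRSchedulers(LinearLR(opt1), ExponentialLR(opt2, gamma=0.99))
combined.step()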
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Sequence
-from typing import Any, no_type_check
+from typing import Any, TypeAlias, no_type_check

 import numpy as np
 import torch
@@ -11,6 +11,7 @@ from tianshou.data.types import RecurrentStateBatch

 ModuleType = type[nn.Module]
 ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]]
+TActionShape: TypeAlias = Sequence[int] | int


 def miniblock(
@@ -6,7 +6,7 @@ import numpy as np
 import torch
 from torch import nn

-from tianshou.utils.net.common import MLP, BaseActor
+from tianshou.utils.net.common import MLP, BaseActor, TActionShape

 SIGMA_MIN = -20
 SIGMA_MAX = 2
@@ -40,7 +40,7 @@ class Actor(BaseActor):
     def __init__(
         self,
         preprocess_net: nn.Module,
-        action_shape: Sequence[int],
+        action_shape: TActionShape,
         hidden_sizes: Sequence[int] = (),
         max_action: float = 1.0,
         device: str | int | torch.device = "cpu",
@@ -180,7 +180,7 @@ class ActorProb(BaseActor):
     def __init__(
         self,
         preprocess_net: nn.Module,
-        action_shape: Sequence[int],
+        action_shape: TActionShape,
         hidden_sizes: Sequence[int] = (),
         max_action: float = 1.0,
         device: str | int | torch.device = "cpu",
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from torch import nn

 from tianshou.data import Batch, to_torch
-from tianshou.utils.net.common import MLP, BaseActor
+from tianshou.utils.net.common import MLP, BaseActor, TActionShape


 class Actor(BaseActor):
@@ -39,7 +39,7 @@ class Actor(BaseActor):
     def __init__(
         self,
         preprocess_net: nn.Module,
-        action_shape: Sequence[int],
+        action_shape: TActionShape,
         hidden_sizes: Sequence[int] = (),
         softmax_output: bool = True,
         preprocess_net_output_dim: int | None = None,