Add high-level API support for TD3

* Created mixins for agent factories to reduce code duplication
* Further factorised params & mixins for experiment factories
* Additional parameter abstractions
* Implemented high-level MuJoCo TD3 example
Dominik Jain 2023-09-26 15:35:18 +02:00
parent 6a739384ee
commit e993425aa1
14 changed files with 626 additions and 116 deletions
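
In essence, the new high-level API reduces the TD3 setup to a short builder chain. A condensed sketch of the usage added by this commit (the full, runnable version is the new TD3 example file below; assuming here that RLExperimentConfig can be default-constructed):

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import RLExperimentConfig, TD3ExperimentBuilder
from tianshou.highlevel.params.env_param import MaxActionScaledFloatEnvParamFactory
from tianshou.highlevel.params.noise import MaxActionScaledGaussianNoiseFactory
from tianshou.highlevel.params.policy_params import TD3Params

experiment_config = RLExperimentConfig()  # assumed to provide defaults, incl. a seed
sampling_config = RLSamplingConfig(num_epochs=200, step_per_epoch=5000, start_timesteps=25000)
env_factory = MujocoEnvFactory("Ant-v3", experiment_config.seed, sampling_config)
experiment = (
    TD3ExperimentBuilder(experiment_config, env_factory, sampling_config)
    .with_td3_params(
        TD3Params(
            # environment-dependent values are given as factories and resolved at build time
            exploration_noise=MaxActionScaledGaussianNoiseFactory(0.1),
            policy_noise=MaxActionScaledFloatEnvParamFactory(0.2),
            noise_clip=MaxActionScaledFloatEnvParamFactory(0.5),
        ),
    )
    .with_actor_factory_default((256, 256))
    .with_common_critic_factory_default((256, 256))
    .build()
)
experiment.run("td3/Ant-v3")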

View File

@@ -66,7 +66,7 @@ def main(
experiment = (
PPOExperimentBuilder(experiment_config, env_factory, sampling_config, dist_fn)
.with_ppo_params(
.with_params(
PPOParams(
discount_factor=gamma,
gae_lambda=gae_lambda,

View File

@@ -7,13 +7,13 @@ from collections.abc import Sequence
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.params.policy_params import SACParams
from tianshou.highlevel.params.alpha import DefaultAutoAlphaFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
RLExperimentConfig,
SACExperimentBuilder,
)
from tianshou.highlevel.params.alpha import DefaultAutoAlphaFactory
from tianshou.highlevel.params.policy_params import SACParams
def main(
@@ -70,7 +70,9 @@ def main(
),
)
.with_actor_factory_default(
hidden_sizes, continuous_unbounded=True, continuous_conditioned_sigma=True,
hidden_sizes,
continuous_unbounded=True,
continuous_conditioned_sigma=True,
)
.with_common_critic_factory_default(hidden_sizes)
.build()

View File

@@ -0,0 +1,85 @@
#!/usr/bin/env python3
import datetime
import os
from collections.abc import Sequence
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
RLExperimentConfig,
TD3ExperimentBuilder,
)
from tianshou.highlevel.params.env_param import MaxActionScaledFloatEnvParamFactory
from tianshou.highlevel.params.noise import MaxActionScaledGaussianNoiseFactory
from tianshou.highlevel.params.policy_params import TD3Params
def main(
experiment_config: RLExperimentConfig,
task: str = "Ant-v3",
buffer_size: int = 1000000,
hidden_sizes: Sequence[int] = (256, 256),
actor_lr: float = 3e-4,
critic_lr: float = 3e-4,
gamma: float = 0.99,
tau: float = 0.005,
exploration_noise: float = 0.1,
policy_noise: float = 0.2,
noise_clip: float = 0.5,
update_actor_freq: int = 2,
start_timesteps: int = 25000,
epoch: int = 200,
step_per_epoch: int = 5000,
step_per_collect: int = 1,
update_per_step: int = 1,
n_step: int = 1,
batch_size: int = 256,
training_num: int = 1,
test_num: int = 10,
):
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "td3", str(experiment_config.seed), now)
sampling_config = RLSamplingConfig(
num_epochs=epoch,
step_per_epoch=step_per_epoch,
num_train_envs=training_num,
num_test_envs=test_num,
buffer_size=buffer_size,
batch_size=batch_size,
step_per_collect=step_per_collect,
update_per_step=update_per_step,
start_timesteps=start_timesteps,
start_timesteps_random=True,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
experiment = (
TD3ExperimentBuilder(experiment_config, env_factory, sampling_config)
.with_td3_params(
TD3Params(
tau=tau,
gamma=gamma,
estimation_step=n_step,
update_actor_freq=update_actor_freq,
noise_clip=MaxActionScaledFloatEnvParamFactory(noise_clip),
policy_noise=MaxActionScaledFloatEnvParamFactory(policy_noise),
exploration_noise=MaxActionScaledGaussianNoiseFactory(exploration_noise),
actor_lr=actor_lr,
critic1_lr=critic_lr,
critic2_lr=critic_lr,
),
)
.with_actor_factory_default(hidden_sizes)
.with_common_critic_factory_default(hidden_sizes)
.build()
)
experiment.run(log_name)
if __name__ == "__main__":
CLI(main)

View File

@@ -1,21 +1,35 @@
import os
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Dict, Any, List, Tuple
from typing import Any
import torch
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.exploration import BaseNoise
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
from tianshou.highlevel.module import (
ActorCriticModuleOpt,
ActorFactory,
ActorModuleOptFactory,
CriticFactory,
CriticModuleOptFactory,
ModuleOpt,
TDevice,
)
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.env_param import FloatEnvParamFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.policy_params import PPOParams, ParamTransformer, SACParams
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy
from tianshou.highlevel.params.noise import NoiseFactory
from tianshou.highlevel.params.policy_params import (
ParamTransformer,
PPOParams,
SACParams,
TD3Params,
)
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy, TD3Policy
from tianshou.policy.modelfree.pg import TDistParams
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils import MultipleLRSchedulers
@@ -135,7 +149,7 @@ class ParamTransformerDrop(ParamTransformer):
def __init__(self, *keys: str):
self.keys = keys
def transform(self, kwargs: Dict[str, Any]) -> None:
def transform(self, kwargs: dict[str, Any]) -> None:
for k in self.keys:
del kwargs[k]
@@ -144,12 +158,94 @@ class ParamTransformerLRScheduler(ParamTransformer):
def __init__(self, optim: torch.optim.Optimizer):
self.optim = optim
def transform(self, kwargs: Dict[str, Any]) -> None:
def transform(self, kwargs: dict[str, Any]) -> None:
factory: LRSchedulerFactory | None = self.get(kwargs, "lr_scheduler_factory", drop=True)
kwargs["lr_scheduler"] = factory.create_scheduler(self.optim) if factory is not None else None
kwargs["lr_scheduler"] = (
factory.create_scheduler(self.optim) if factory is not None else None
)
class PPOAgentFactory(OnpolicyAgentFactory):
class _ActorMixin:
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
self.actor_module_opt_factory = ActorModuleOptFactory(actor_factory, optim_factory)
def create_actor_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
return self.actor_module_opt_factory.create_module_opt(envs, device, lr)
class _ActorCriticMixin:
"""Mixin for agents that use an ActorCritic module with a single optimizer."""
def __init__(
self,
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optim_factory: OptimizerFactory,
critic_use_action: bool,
):
self.actor_factory = actor_factory
self.critic_factory = critic_factory
self.optim_factory = optim_factory
self.critic_use_action = critic_use_action
def create_actor_critic_module_opt(
self,
envs: Environments,
device: TDevice,
lr: float,
) -> ActorCriticModuleOpt:
actor = self.actor_factory.create_module(envs, device)
critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action)
actor_critic = ActorCritic(actor, critic)
optim = self.optim_factory.create_optimizer(actor_critic, lr)
return ActorCriticModuleOpt(actor_critic, optim)
class _ActorAndCriticMixin(_ActorMixin):
def __init__(
self,
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optim_factory: OptimizerFactory,
critic_use_action: bool,
):
super().__init__(actor_factory, optim_factory)
self.critic_module_opt_factory = CriticModuleOptFactory(
critic_factory,
optim_factory,
critic_use_action,
)
def create_critic_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
return self.critic_module_opt_factory.create_module_opt(envs, device, lr)
class _ActorAndDualCriticsMixin(_ActorAndCriticMixin):
def __init__(
self,
actor_factory: ActorFactory,
critic_factory: CriticFactory,
critic2_factory: CriticFactory,
optim_factory: OptimizerFactory,
critic_use_action: bool,
):
super().__init__(actor_factory, critic_factory, optim_factory, critic_use_action)
self.critic2_module_opt_factory = CriticModuleOptFactory(
critic2_factory,
optim_factory,
critic_use_action,
)
def create_critic2_module_opt(
self,
envs: Environments,
device: TDevice,
lr: float,
) -> ModuleOpt:
return self.critic2_module_opt_factory.create_module_opt(envs, device, lr)
class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
def __init__(
self,
params: PPOParams,
@@ -160,27 +256,29 @@ class PPOAgentFactory(OnpolicyAgentFactory):
dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
):
super().__init__(sampling_config)
self.optimizer_factory = optimizer_factory
self.critic_factory = critic_factory
self.actor_factory = actor_factory
self.config = params
_ActorCriticMixin.__init__(
self,
actor_factory,
critic_factory,
optimizer_factory,
critic_use_action=False,
)
self.params = params
self.dist_fn = dist_fn
def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
actor = self.actor_factory.create_module(envs, device)
critic = self.critic_factory.create_module(envs, device, use_action=False)
actor_critic = ActorCritic(actor, critic)
optim = self.optimizer_factory.create_optimizer(actor_critic, self.config.lr)
kwargs = self.config.create_kwargs(
actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
kwargs = self.params.create_kwargs(
ParamTransformerDrop("lr"),
ParamTransformerLRScheduler(optim))
ParamTransformerLRScheduler(actor_critic.optim),
)
return PPOPolicy(
actor=actor,
critic=critic,
optim=optim,
actor=actor_critic.actor,
critic=actor_critic.critic,
optim=actor_critic.optim,
dist_fn=self.dist_fn,
action_space=envs.get_action_space(),
**kwargs
**kwargs,
)
@@ -190,7 +288,7 @@ class ParamTransformerAlpha(ParamTransformer):
self.optim_factory = optim_factory
self.device = device
def transform(self, kwargs: Dict[str, Any]) -> None:
def transform(self, kwargs: dict[str, Any]) -> None:
key = "alpha"
alpha = self.get(kwargs, key)
if isinstance(alpha, AutoAlphaFactory):
@@ -198,13 +296,17 @@ class ParamTransformerAlpha(ParamTransformer):
class ParamTransformerMultiLRScheduler(ParamTransformer):
def __init__(self, optim_key_list: List[Tuple[torch.optim.Optimizer, str]]):
def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]]):
self.optim_key_list = optim_key_list
def transform(self, kwargs: Dict[str, Any]) -> None:
def transform(self, kwargs: dict[str, Any]) -> None:
lr_schedulers = []
for optim, lr_scheduler_factory_key in self.optim_key_list:
lr_scheduler_factory: LRSchedulerFactory | None = self.get(kwargs, lr_scheduler_factory_key, drop=True)
lr_scheduler_factory: LRSchedulerFactory | None = self.get(
kwargs,
lr_scheduler_factory_key,
drop=True,
)
if lr_scheduler_factory is not None:
lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
match len(lr_schedulers):
@@ -217,7 +319,7 @@ class ParamTransformerMultiLRScheduler(ParamTransformer):
kwargs["lr_scheduler"] = lr_scheduler
class SACAgentFactory(OffpolicyAgentFactory):
class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
def __init__(
self,
params: SACParams,
@@ -228,35 +330,114 @@ class SACAgentFactory(OffpolicyAgentFactory):
optim_factory: OptimizerFactory,
):
super().__init__(sampling_config)
self.critic2_factory = critic2_factory
self.critic1_factory = critic1_factory
self.actor_factory = actor_factory
self.optim_factory = optim_factory
_ActorAndDualCriticsMixin.__init__(
self,
actor_factory,
critic1_factory,
critic2_factory,
optim_factory,
critic_use_action=True,
)
self.params = params
self.optim_factory = optim_factory
def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
actor = self.actor_factory.create_module(envs, device)
critic1 = self.critic1_factory.create_module(envs, device, use_action=True)
critic2 = self.critic2_factory.create_module(envs, device, use_action=True)
actor_optim = self.optim_factory.create_optimizer(actor, lr=self.params.actor_lr)
critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.params.critic1_lr)
critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.params.critic2_lr)
actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
kwargs = self.params.create_kwargs(
ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
ParamTransformerMultiLRScheduler([
(actor_optim, "actor_lr_scheduler_factory"),
(critic1_optim, "critic1_lr_scheduler_factory"),
(critic2_optim, "critic2_lr_scheduler_factory")]
ParamTransformerMultiLRScheduler(
[
(actor.optim, "actor_lr_scheduler_factory"),
(critic1.optim, "critic1_lr_scheduler_factory"),
(critic2.optim, "critic2_lr_scheduler_factory"),
],
),
ParamTransformerAlpha(envs, optim_factory=self.optim_factory, device=device))
ParamTransformerAlpha(envs, optim_factory=self.optim_factory, device=device),
)
return SACPolicy(
actor=actor,
actor_optim=actor_optim,
critic=critic1,
critic_optim=critic1_optim,
critic2=critic2,
critic2_optim=critic2_optim,
actor=actor.module,
actor_optim=actor.optim,
critic=critic1.module,
critic_optim=critic1.optim,
critic2=critic2.module,
critic2_optim=critic2.optim,
action_space=envs.get_action_space(),
observation_space=envs.get_observation_space(),
**kwargs
**kwargs,
)
class ParamTransformerNoiseFactory(ParamTransformer):
def __init__(self, key: str, envs: Environments):
self.key = key
self.envs = envs
def transform(self, kwargs: dict[str, Any]) -> None:
value = kwargs[self.key]
if isinstance(value, NoiseFactory):
kwargs[self.key] = value.create_noise(self.envs)
class ParamTransformerFloatEnvParamFactory(ParamTransformer):
def __init__(self, key: str, envs: Environments):
self.key = key
self.envs = envs
def transform(self, kwargs: dict[str, Any]) -> None:
value = kwargs[self.key]
if isinstance(value, FloatEnvParamFactory):
kwargs[self.key] = value.create_param(self.envs)
class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
def __init__(
self,
params: TD3Params,
sampling_config: RLSamplingConfig,
actor_factory: ActorFactory,
critic1_factory: CriticFactory,
critic2_factory: CriticFactory,
optim_factory: OptimizerFactory,
):
super().__init__(sampling_config)
_ActorAndDualCriticsMixin.__init__(
self,
actor_factory,
critic1_factory,
critic2_factory,
optim_factory,
critic_use_action=True,
)
self.params = params
self.optim_factory = optim_factory
def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
kwargs = self.params.create_kwargs(
ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
ParamTransformerMultiLRScheduler(
[
(actor.optim, "actor_lr_scheduler_factory"),
(critic1.optim, "critic1_lr_scheduler_factory"),
(critic2.optim, "critic2_lr_scheduler_factory"),
],
),
ParamTransformerNoiseFactory("exploration_noise", envs),
ParamTransformerFloatEnvParamFactory("policy_noise", envs),
ParamTransformerFloatEnvParamFactory("noise_clip", envs),
)
return TD3Policy(
actor=actor.module,
actor_optim=actor.optim,
critic=critic1.module,
critic_optim=critic1.optim,
critic2=critic2.module,
critic2_optim=critic2.optim,
action_space=envs.get_action_space(),
observation_space=envs.get_observation_space(),
**kwargs,
)
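
To illustrate the duplication these mixins remove: a further off-policy algorithm would only need to compose them, roughly as below. This is a hypothetical sketch that would live in the same module (DDPG is not part of this commit; DDPGParams, DDPGAgentFactory and the keyword signature of tianshou's DDPGPolicy as used here are illustrative assumptions).

class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
    def __init__(
        self,
        params,  # hypothetical DDPGParams with actor_lr/critic_lr and scheduler factory fields
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
    ):
        super().__init__(sampling_config)
        _ActorAndCriticMixin.__init__(
            self,
            actor_factory,
            critic_factory,
            optim_factory,
            critic_use_action=True,
        )
        self.params = params
        self.optim_factory = optim_factory

    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        # the mixins pair each network with its optimizer in one call
        actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
        critic = self.create_critic_module_opt(envs, device, self.params.critic_lr)
        kwargs = self.params.create_kwargs(
            ParamTransformerDrop("actor_lr", "critic_lr"),
            ParamTransformerMultiLRScheduler(
                [
                    (actor.optim, "actor_lr_scheduler_factory"),
                    (critic.optim, "critic_lr_scheduler_factory"),
                ],
            ),
            ParamTransformerNoiseFactory("exploration_noise", envs),
        )
        return DDPGPolicy(  # assumed keyword signature, mirroring the TD3Policy call above
            actor=actor.module,
            actor_optim=actor.optim,
            critic=critic.module,
            critic_optim=critic.optim,
            action_space=envs.get_action_space(),
            **kwargs,
        )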

View File

@@ -5,6 +5,7 @@ from dataclasses import dataclass
class RLSamplingConfig:
"""Sampling, epochs, parallelization, buffers, collectors, and batching."""
# TODO: What are reasonable defaults?
num_epochs: int = 100
step_per_epoch: int = 30000
batch_size: int = 64

View File

@@ -20,6 +20,14 @@ class EnvType(Enum):
def is_continuous(self):
return self == EnvType.CONTINUOUS
def assert_continuous(self, requiring_entity: Any):
if not self.is_continuous():
raise AssertionError(f"{requiring_entity} requires continuous environments")
def assert_discrete(self, requiring_entity: Any):
if not self.is_discrete():
raise AssertionError(f"{requiring_entity} requires discrete environments")
class Environments(ABC):
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
@@ -28,7 +36,10 @@ class Environments(ABC):
self.test_envs = test_envs
def info(self) -> dict[str, Any]:
return {"action_shape": self.get_action_shape(), "state_shape": self.get_observation_shape()}
return {
"action_shape": self.get_action_shape(),
"state_shape": self.get_observation_shape(),
}
@abstractmethod
def get_action_shape(self) -> TShape:
@@ -81,7 +92,7 @@ class ContinuousEnvironments(Environments):
def get_observation_shape(self) -> TShape:
return self.state_shape
def get_type(self):
def get_type(self) -> EnvType:
return EnvType.CONTINUOUS

View File

@@ -1,25 +1,31 @@
from abc import abstractmethod
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from pprint import pprint
from typing import Generic, TypeVar, Callable
from typing import Generic, Self, TypeVar
import numpy as np
import torch
from tianshou.data import Collector
from tianshou.highlevel.agent import AgentFactory, PPOAgentFactory, SACAgentFactory
from tianshou.highlevel.agent import (
AgentFactory,
PPOAgentFactory,
SACAgentFactory,
TD3AgentFactory,
)
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
from tianshou.highlevel.module import (
ActorFactory,
ContinuousActorType,
CriticFactory,
DefaultActorFactory,
DefaultCriticFactory,
)
from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory
from tianshou.highlevel.params.policy_params import PPOParams, SACParams
from tianshou.highlevel.params.policy_params import PPOParams, SACParams, TD3Params
from tianshou.policy import BasePolicy
from tianshou.policy.modelfree.pg import TDistParams
from tianshou.trainer import BaseTrainer
@@ -150,7 +156,10 @@ class RLExperimentBuilder:
return self
def with_optim_factory_default(
self: TBuilder, betas=(0.9, 0.999), eps=1e-08, weight_decay=0,
self: TBuilder,
betas=(0.9, 0.999),
eps=1e-08,
weight_decay=0,
) -> TBuilder:
"""Configures the use of the default optimizer, Adam, with the given parameters.
@@ -174,12 +183,16 @@ class RLExperimentBuilder:
def build(self) -> RLExperiment:
return RLExperiment(
self._config, self._env_factory, self._create_agent_factory(), self._logger_factory,
self._config,
self._env_factory,
self._create_agent_factory(),
self._logger_factory,
)
class _BuilderMixinActorFactory:
def __init__(self):
def __init__(self, continuous_actor_type: ContinuousActorType):
self._continuous_actor_type = continuous_actor_type
self._actor_factory: ActorFactory | None = None
def with_actor_factory(self: TBuilder, actor_factory: ActorFactory) -> TBuilder:
@@ -187,7 +200,7 @@ class _BuilderMixinActorFactory:
self._actor_factory = actor_factory
return self
def with_actor_factory_default(
def _with_actor_factory_default(
self: TBuilder,
hidden_sizes: Sequence[int],
continuous_unbounded=False,
@@ -195,6 +208,7 @@ class _BuilderMixinActorFactory:
) -> TBuilder:
self: TBuilder | _BuilderMixinActorFactory
self._actor_factory = DefaultActorFactory(
self._continuous_actor_type,
hidden_sizes,
continuous_unbounded=continuous_unbounded,
continuous_conditioned_sigma=continuous_conditioned_sigma,
@@ -203,11 +217,40 @@ class _BuilderMixinActorFactory:
def _get_actor_factory(self):
if self._actor_factory is None:
return DefaultActorFactory()
return DefaultActorFactory(self._continuous_actor_type)
else:
return self._actor_factory
class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
"""Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""
def __init__(self):
super().__init__(ContinuousActorType.GAUSSIAN)
def with_actor_factory_default(
self,
hidden_sizes: Sequence[int],
continuous_unbounded=False,
continuous_conditioned_sigma=False,
) -> Self:
return super()._with_actor_factory_default(
hidden_sizes,
continuous_unbounded=continuous_unbounded,
continuous_conditioned_sigma=continuous_conditioned_sigma,
)
class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactory):
"""Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""
def __init__(self):
super().__init__(ContinuousActorType.DETERMINISTIC)
def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self:
return super()._with_actor_factory_default(hidden_sizes)
class _BuilderMixinCriticsFactory:
def __init__(self, num_critics: int):
self._critic_factories: list[CriticFactory | None] = [None] * num_critics
@@ -238,7 +281,8 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
return self
def with_critic_factory_default(
self: TBuilder, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
self: TBuilder,
hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
) -> TBuilder:
self: TBuilder | "_BuilderMixinSingleCriticFactory"
self._with_critic_factory_default(0, hidden_sizes)
@@ -256,7 +300,8 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
return self
def with_common_critic_factory_default(
self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
self,
hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory"
for i in range(len(self._critic_factories)):
@@ -269,7 +314,8 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
return self
def with_critic1_factory_default(
self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
self,
hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory"
self._with_critic_factory_default(0, hidden_sizes)
@@ -281,7 +327,8 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
return self
def with_critic2_factory_default(
self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
self,
hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory"
self._with_critic_factory_default(0, hidden_sizes)
@@ -289,7 +336,9 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
class PPOExperimentBuilder(
RLExperimentBuilder, _BuilderMixinActorFactory, _BuilderMixinSingleCriticFactory,
RLExperimentBuilder,
_BuilderMixinActorFactory_ContinuousGaussian,
_BuilderMixinSingleCriticFactory,
):
def __init__(
self,
@@ -299,12 +348,12 @@ class PPOExperimentBuilder(
dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
):
super().__init__(experiment_config, env_factory, sampling_config)
_BuilderMixinActorFactory.__init__(self)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
_BuilderMixinSingleCriticFactory.__init__(self)
self._params: PPOParams = PPOParams()
self._dist_fn = dist_fn
def with_ppo_params(self, params: PPOParams) -> "PPOExperimentBuilder":
def with_ppo_params(self, params: PPOParams) -> Self:
self._params = params
return self
@@ -316,12 +365,14 @@ class PPOExperimentBuilder(
self._get_actor_factory(),
self._get_critic_factory(0),
self._get_optim_factory(),
self._dist_fn
self._dist_fn,
)
class SACExperimentBuilder(
RLExperimentBuilder, _BuilderMixinActorFactory, _BuilderMixinDualCriticFactory,
RLExperimentBuilder,
_BuilderMixinActorFactory_ContinuousGaussian,
_BuilderMixinDualCriticFactory,
):
def __init__(
self,
@@ -330,14 +381,51 @@ class SACExperimentBuilder(
sampling_config: RLSamplingConfig,
):
super().__init__(experiment_config, env_factory, sampling_config)
_BuilderMixinActorFactory.__init__(self)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
_BuilderMixinDualCriticFactory.__init__(self)
self._params: SACParams = SACParams()
def with_sac_params(self, params: SACParams) -> "SACExperimentBuilder":
def with_sac_params(self, params: SACParams) -> Self:
self._params = params
return self
def _create_agent_factory(self) -> AgentFactory:
return SACAgentFactory(self._params, self._sampling_config, self._get_actor_factory(),
self._get_critic_factory(0), self._get_critic_factory(1), self._get_optim_factory())
return SACAgentFactory(
self._params,
self._sampling_config,
self._get_actor_factory(),
self._get_critic_factory(0),
self._get_critic_factory(1),
self._get_optim_factory(),
)
class TD3ExperimentBuilder(
RLExperimentBuilder,
_BuilderMixinActorFactory_ContinuousDeterministic,
_BuilderMixinDualCriticFactory,
):
def __init__(
self,
experiment_config: RLExperimentConfig,
env_factory: EnvFactory,
sampling_config: RLSamplingConfig,
):
super().__init__(experiment_config, env_factory, sampling_config)
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
_BuilderMixinDualCriticFactory.__init__(self)
self._params: TD3Params = TD3Params()
def with_td3_params(self, params: TD3Params) -> Self:
self._params = params
return self
def _create_agent_factory(self) -> AgentFactory:
return TD3AgentFactory(
self._params,
self._sampling_config,
self._get_actor_factory(),
self._get_critic_factory(0),
self._get_critic_factory(1),
self._get_optim_factory(),
)
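
The default factories can also be bypassed: any ActorFactory can be injected via with_actor_factory. A minimal sketch, assuming experiment_config, env_factory and sampling_config are set up as in the TD3 example above (MyTD3Actor is a hypothetical user-defined module, not part of this commit):

from torch import nn

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module import ActorFactory, TDevice


class MyTD3ActorFactory(ActorFactory):
    """Hypothetical user-defined factory producing a custom deterministic actor."""

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        # MyTD3Actor is assumed to map observations of the given shape to actions
        return MyTD3Actor(envs.get_observation_shape(), envs.get_action_shape()).to(device)


experiment = (
    TD3ExperimentBuilder(experiment_config, env_factory, sampling_config)
    .with_actor_factory(MyTD3ActorFactory())
    .with_common_critic_factory_default((256, 256))
    .build()
)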

View File

@@ -1,13 +1,13 @@
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Literal
from typing import Literal, TypeAlias
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger, WandbLogger
TLogger = TensorboardLogger | WandbLogger
TLogger: TypeAlias = TensorboardLogger | WandbLogger
@dataclass
@@ -30,7 +30,7 @@ class DefaultLoggerFactory(LoggerFactory):
wandb_project: str | None = None,
):
if logger_type == "wandb" and wandb_project is None:
raise ValueError("Must provide 'wand_project'")
raise ValueError("Must provide 'wandb_project'")
self.log_dir = log_dir
self.logger_type = logger_type
self.wandb_project = wandb_project

View File

@@ -1,16 +1,18 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TypeAlias
import numpy as np
import torch
from torch import nn
from tianshou.highlevel.env import Environments, EnvType
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb
from tianshou.utils.net.continuous import Critic as ContinuousCritic
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net import continuous
from tianshou.utils.net.common import ActorCritic, Net
TDevice = str | int | torch.device
TDevice: TypeAlias = str | int | torch.device
def init_linear_orthogonal(module: torch.nn.Module):
@@ -24,6 +26,11 @@ def init_linear_orthogonal(module: torch.nn.Module):
torch.nn.init.zeros_(m.bias)
class ContinuousActorType:
GAUSSIAN = "gaussian"
DETERMINISTIC = "deterministic"
class ActorFactory(ABC):
@abstractmethod
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
@@ -47,30 +54,36 @@ class ActorFactory(ABC):
class DefaultActorFactory(ActorFactory):
"""An actor factory which, depending on the type of environment, creates a suitable MLP-based policy."""
DEFAULT_HIDDEN_SIZES = (64, 64)
def __init__(
self,
continuous_actor_type: ContinuousActorType,
hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
continuous_unbounded=False,
continuous_conditioned_sigma=False,
):
self.continuous_actor_type = continuous_actor_type
self.continuous_unbounded = continuous_unbounded
self.continuous_conditioned_sigma = continuous_conditioned_sigma
self.hidden_sizes = hidden_sizes
"""
An actor factory which, depending on the type of environment, creates a suitable MLP-based policy
"""
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
env_type = envs.get_type()
if env_type == EnvType.CONTINUOUS:
factory = ContinuousActorProbFactory(
self.hidden_sizes,
unbounded=self.continuous_unbounded,
conditioned_sigma=self.continuous_conditioned_sigma,
)
match self.continuous_actor_type:
case ContinuousActorType.GAUSSIAN:
factory = ContinuousActorFactoryGaussian(
self.hidden_sizes,
unbounded=self.continuous_unbounded,
conditioned_sigma=self.continuous_conditioned_sigma,
)
case ContinuousActorType.DETERMINISTIC:
factory = ContinuousActorFactoryDeterministic(self.hidden_sizes)
case _:
raise ValueError(self.continuous_actor_type)
return factory.create_module(envs, device)
elif env_type == EnvType.DISCRETE:
raise NotImplementedError
@@ -82,8 +95,25 @@ class ContinuousActorFactory(ActorFactory, ABC):
"""Serves as a type bound for actor factories that are suitable for continuous action spaces."""
class ContinuousActorFactoryDeterministic(ContinuousActorFactory):
def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes
class ContinuousActorProbFactory(ContinuousActorFactory):
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
net_a = Net(
envs.get_observation_shape(),
hidden_sizes=self.hidden_sizes,
device=device,
)
return continuous.Actor(
net_a,
envs.get_action_shape(),
hidden_sizes=(),
device=device,
).to(device)
class ContinuousActorFactoryGaussian(ContinuousActorFactory):
def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
self.hidden_sizes = hidden_sizes
self.unbounded = unbounded
@@ -96,7 +126,7 @@ class ContinuousActorProbFactory(ContinuousActorFactory):
activation=nn.Tanh,
device=device,
)
actor = ActorProb(
actor = continuous.ActorProb(
net_a,
envs.get_action_shape(),
unbounded=self.unbounded,
@@ -155,6 +185,54 @@ class ContinuousNetCriticFactory(ContinuousCriticFactory):
activation=nn.Tanh,
device=device,
)
critic = ContinuousCritic(net_c, device=device).to(device)
critic = continuous.Critic(net_c, device=device).to(device)
init_linear_orthogonal(critic)
return critic
@dataclass
class ModuleOpt:
module: torch.nn.Module
optim: torch.optim.Optimizer
@dataclass
class ActorCriticModuleOpt:
actor_critic_module: ActorCritic
optim: torch.optim.Optimizer
@property
def actor(self):
return self.actor_critic_module.actor
@property
def critic(self):
return self.actor_critic_module.critic
class ActorModuleOptFactory:
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
self.actor_factory = actor_factory
self.optim_factory = optim_factory
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
actor = self.actor_factory.create_module(envs, device)
opt = self.optim_factory.create_optimizer(actor, lr)
return ModuleOpt(actor, opt)
class CriticModuleOptFactory:
def __init__(
self,
critic_factory: CriticFactory,
optim_factory: OptimizerFactory,
use_action: bool,
):
self.critic_factory = critic_factory
self.optim_factory = optim_factory
self.use_action = use_action
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
critic = self.critic_factory.create_module(envs, device, self.use_action)
opt = self.optim_factory.create_optimizer(critic, lr)
return ModuleOpt(critic, opt)
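
Taken together, these pieces pair a module with its optimizer in a single step. A minimal sketch, assuming envs is an already-created continuous Environments instance and that AdamOptimizerFactory's constructor arguments all have defaults:

from tianshou.highlevel.module import (
    ActorModuleOptFactory,
    ContinuousActorType,
    DefaultActorFactory,
)
from tianshou.highlevel.optim import AdamOptimizerFactory

actor_factory = DefaultActorFactory(ContinuousActorType.DETERMINISTIC, hidden_sizes=(256, 256))
actor = ActorModuleOptFactory(actor_factory, AdamOptimizerFactory()).create_module_opt(
    envs,
    device="cpu",
    lr=3e-4,
)
# actor.module is the torch.nn.Module, actor.optim the Adam optimizer bound to its parameters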

View File

@@ -1,13 +1,9 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Any
import torch
from torch import Tensor
from torch.optim import Adam
TParams = Iterable[Tensor] | Iterable[dict[str, Any]]
class OptimizerFactory(ABC):
@abstractmethod
@@ -38,5 +34,3 @@ class AdamOptimizerFactory(OptimizerFactory):
eps=self.eps,
weight_decay=self.weight_decay,
)
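
Other optimizers can be supported by subclassing OptimizerFactory; only create_optimizer needs to be provided. A hypothetical sketch (RMSpropOptimizerFactory is not part of this commit; the create_optimizer(module, lr) signature is inferred from its call sites in agent.py):

import torch
from torch.optim import RMSprop

from tianshou.highlevel.optim import OptimizerFactory


class RMSpropOptimizerFactory(OptimizerFactory):
    def __init__(self, alpha: float = 0.99, eps: float = 1e-08, weight_decay: float = 0):
        self.alpha = alpha
        self.eps = eps
        self.weight_decay = weight_decay

    def create_optimizer(self, module: torch.nn.Module, lr: float) -> RMSprop:
        # build an RMSprop optimizer over the module's parameters with the configured settings
        return RMSprop(
            module.parameters(),
            lr=lr,
            alpha=self.alpha,
            eps=self.eps,
            weight_decay=self.weight_decay,
        )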

View File

@@ -0,0 +1,24 @@
"""Factories for the generation of environment-dependent parameters."""
from abc import ABC, abstractmethod
from typing import TypeVar
from tianshou.highlevel.env import ContinuousEnvironments, Environments
T = TypeVar("T")
class FloatEnvParamFactory(ABC):
@abstractmethod
def create_param(self, envs: Environments) -> float:
pass
class MaxActionScaledFloatEnvParamFactory(FloatEnvParamFactory):
def __init__(self, value: float):
""":param value: value with which to scale the max action value"""
self.value = value
def create_param(self, envs: Environments) -> float:
envs.get_type().assert_continuous(self)
envs: ContinuousEnvironments
return envs.max_action * self.value
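
For instance, with a continuous task the factory resolves to a plain float only once the environments are known; a minimal sketch, assuming envs is a ContinuousEnvironments instance:

from tianshou.highlevel.params.env_param import MaxActionScaledFloatEnvParamFactory

noise_clip = MaxActionScaledFloatEnvParamFactory(0.5)
# inside the agent factory, ParamTransformerFloatEnvParamFactory calls create_param:
assert noise_clip.create_param(envs) == 0.5 * envs.max_action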

View File

@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
import numpy as np
import torch
from torch.optim.lr_scheduler import LRScheduler, LambdaLR
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from tianshou.highlevel.config import RLSamplingConfig
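
The scheduler factories consumed by ParamTransformerLRScheduler and ParamTransformerMultiLRScheduler only need to implement create_scheduler. A hypothetical linear-decay sketch (the class name and decay rule are illustrative; only the LRSchedulerFactory.create_scheduler interface is taken from this commit):

import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR, LRScheduler

from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory


class LinearDecayLRSchedulerFactory(LRSchedulerFactory):
    """Hypothetical factory: decays the learning rate linearly to zero over all updates."""

    def __init__(self, sampling_config: RLSamplingConfig):
        self.sampling_config = sampling_config

    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        c = self.sampling_config
        # total number of gradient-update rounds over the whole run
        max_update_num = np.ceil(c.step_per_epoch / c.step_per_collect) * c.num_epochs
        return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)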

View File

@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
from tianshou.exploration import BaseNoise, GaussianNoise
from tianshou.highlevel.env import ContinuousEnvironments, Environments
class NoiseFactory(ABC):
@abstractmethod
def create_noise(self, envs: Environments) -> BaseNoise:
pass
class MaxActionScaledGaussianNoiseFactory(NoiseFactory):
"""Factory for Gaussian noise where the standard deviation is a fraction of the maximum action value.
This factory can only be applied to continuous action spaces.
"""
def __init__(self, std_fraction: float):
self.std_fraction = std_fraction
def create_noise(self, envs: Environments) -> BaseNoise:
envs.get_type().assert_continuous(self)
envs: ContinuousEnvironments
return GaussianNoise(sigma=envs.max_action * self.std_fraction)
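
So for TD3, exploration_noise=MaxActionScaledGaussianNoiseFactory(0.1) becomes Gaussian noise scaled by the environment's max action once the (continuous) environments are known, while a discrete task would fail the assert_continuous check. A minimal sketch, assuming envs is a ContinuousEnvironments instance:

from tianshou.highlevel.params.noise import MaxActionScaledGaussianNoiseFactory

noise = MaxActionScaledGaussianNoiseFactory(0.1).create_noise(envs)
# equivalent to GaussianNoise(sigma=0.1 * envs.max_action); raises AssertionError for discrete envs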

View File

@@ -1,21 +1,23 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict
from typing import Dict, Any, Literal
from dataclasses import asdict, dataclass
from typing import Any, Literal
import torch
from tianshou.exploration import BaseNoise
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.env_param import FloatEnvParamFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory
class ParamTransformer(ABC):
@abstractmethod
def transform(self, kwargs: Dict[str, Any]) -> None:
def transform(self, kwargs: dict[str, Any]) -> None:
pass
@staticmethod
def get(d: Dict[str, Any], key: str, drop: bool = False) -> Any:
def get(d: dict[str, Any], key: str, drop: bool = False) -> Any:
value = d[key]
if drop:
del d[key]
@@ -24,7 +26,7 @@ class ParamTransformer(ABC):
@dataclass
class Params:
def create_kwargs(self, *transformers: ParamTransformer) -> Dict[str, Any]:
def create_kwargs(self, *transformers: ParamTransformer) -> dict[str, Any]:
d = asdict(self)
for transformer in transformers:
transformer.transform(d)
@@ -34,6 +36,7 @@ class Params:
@dataclass
class PGParams(Params):
"""Config of general policy-gradient algorithms."""
discount_factor: float = 0.99
reward_normalization: bool = False
deterministic_eval: bool = False
@@ -53,6 +56,7 @@ class A2CParams(PGParams):
@dataclass
class PPOParams(A2CParams):
"""PPO specific config."""
eps_clip: float = 0.2
dual_clip: float | None = None
value_clip: bool = False
@@ -63,7 +67,17 @@ class PPOParams(A2CParams):
@dataclass
class SACParams(Params):
class ActorAndDualCriticsParams(Params):
actor_lr: float = 1e-3
critic1_lr: float = 1e-3
critic2_lr: float = 1e-3
actor_lr_scheduler_factory: LRSchedulerFactory | None = None
critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
@dataclass
class SACParams(ActorAndDualCriticsParams):
tau: float = 0.005
gamma: float = 0.99
alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
@@ -72,9 +86,16 @@ class SACParams(Params):
deterministic_eval: bool = True
action_scaling: bool = True
action_bound_method: Literal["clip"] | None = "clip"
actor_lr: float = 1e-3
critic1_lr: float = 1e-3
critic2_lr: float = 1e-3
actor_lr_scheduler_factory: LRSchedulerFactory | None = None
critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
@dataclass
class TD3Params(ActorAndDualCriticsParams):
tau: float = 0.005
gamma: float = 0.99
exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
policy_noise: float | FloatEnvParamFactory = 0.2
noise_clip: float | FloatEnvParamFactory = 0.5
update_actor_freq: int = 2
estimation_step: int = 1
action_scaling: bool = True
action_bound_method: Literal["clip"] | None = "clip"
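
The transformer pipeline is also the extension point for new parameter kinds: a dataclass field can hold a high-level object as long as some ParamTransformer rewrites it into what the policy constructor expects. A hypothetical sketch (ParamTransformerRename is illustrative, not part of this commit; only the ParamTransformer interface is taken from it):

from typing import Any

from tianshou.highlevel.params.policy_params import ParamTransformer


class ParamTransformerRename(ParamTransformer):
    """Hypothetical transformer renaming a high-level field to the policy's kwarg name."""

    def __init__(self, old_key: str, new_key: str):
        self.old_key = old_key
        self.new_key = new_key

    def transform(self, kwargs: dict[str, Any]) -> None:
        # remove the old entry and re-insert its value under the new name
        kwargs[self.new_key] = self.get(kwargs, self.old_key, drop=True)


# applied inside an agent factory, e.g. (illustrative):
# kwargs = self.params.create_kwargs(ParamTransformerRename("gamma", "discount_factor"), ...)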