Add high-level API support for TD3
* Created mixins for agent factories to reduce code duplication
* Further factorised params & mixins for experiment factories
* Additional parameter abstractions
* Implement high-level MuJoCo TD3 example
This commit is contained in:
parent
6a739384ee
commit
e993425aa1
@ -66,7 +66,7 @@ def main(
|
|||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
PPOExperimentBuilder(experiment_config, env_factory, sampling_config, dist_fn)
|
PPOExperimentBuilder(experiment_config, env_factory, sampling_config, dist_fn)
|
||||||
.with_ppo_params(
|
.with_params(
|
||||||
PPOParams(
|
PPOParams(
|
||||||
discount_factor=gamma,
|
discount_factor=gamma,
|
||||||
gae_lambda=gae_lambda,
|
gae_lambda=gae_lambda,
|
||||||
|
|||||||
@ -7,13 +7,13 @@ from collections.abc import Sequence
|
|||||||
from jsonargparse import CLI
|
from jsonargparse import CLI
|
||||||
|
|
||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
from tianshou.highlevel.params.policy_params import SACParams
|
|
||||||
from tianshou.highlevel.params.alpha import DefaultAutoAlphaFactory
|
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import RLSamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
RLExperimentConfig,
|
RLExperimentConfig,
|
||||||
SACExperimentBuilder,
|
SACExperimentBuilder,
|
||||||
)
|
)
|
||||||
|
from tianshou.highlevel.params.alpha import DefaultAutoAlphaFactory
|
||||||
|
from tianshou.highlevel.params.policy_params import SACParams
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
@ -70,7 +70,9 @@ def main(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
.with_actor_factory_default(
|
.with_actor_factory_default(
|
||||||
hidden_sizes, continuous_unbounded=True, continuous_conditioned_sigma=True,
|
hidden_sizes,
|
||||||
|
continuous_unbounded=True,
|
||||||
|
continuous_conditioned_sigma=True,
|
||||||
)
|
)
|
||||||
.with_common_critic_factory_default(hidden_sizes)
|
.with_common_critic_factory_default(hidden_sizes)
|
||||||
.build()
|
.build()
|
||||||
|
|||||||
85
examples/mujoco/mujoco_td3_hl.py
Normal file
85
examples/mujoco/mujoco_td3_hl.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from jsonargparse import CLI
|
||||||
|
|
||||||
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
|
from tianshou.highlevel.config import RLSamplingConfig
|
||||||
|
from tianshou.highlevel.experiment import (
|
||||||
|
RLExperimentConfig,
|
||||||
|
TD3ExperimentBuilder,
|
||||||
|
)
|
||||||
|
from tianshou.highlevel.params.env_param import MaxActionScaledFloatEnvParamFactory
|
||||||
|
from tianshou.highlevel.params.noise import MaxActionScaledGaussianNoiseFactory
|
||||||
|
from tianshou.highlevel.params.policy_params import TD3Params
|
||||||
|
|
||||||
|
|
||||||
|
def main(
    experiment_config: RLExperimentConfig,
    task: str = "Ant-v3",
    buffer_size: int = 1000000,
    hidden_sizes: Sequence[int] = (256, 256),
    actor_lr: float = 3e-4,
    critic_lr: float = 3e-4,
    gamma: float = 0.99,
    tau: float = 0.005,
    exploration_noise: float = 0.1,
    policy_noise: float = 0.2,
    noise_clip: float = 0.5,
    update_actor_freq: int = 2,
    start_timesteps: int = 25000,
    epoch: int = 200,
    step_per_epoch: int = 5000,
    step_per_collect: int = 1,
    update_per_step: int = 1,
    n_step: int = 1,
    batch_size: int = 256,
    training_num: int = 1,
    test_num: int = 10,
):
    """Build and run a TD3 experiment on a MuJoCo task via the high-level API.

    The sampling-related arguments are collected in an ``RLSamplingConfig``, the
    algorithm-related ones in ``TD3Params``; the same ``hidden_sizes`` are used for
    the default actor and for both default critics, and ``critic_lr`` is applied to
    both critics.

    :param experiment_config: the experiment configuration; its ``seed`` is used for
        the environment factory and becomes part of the log name.
    :param task: the task name handed to ``MujocoEnvFactory``; also the first
        component of the log name.
    """
    # Log name layout: <task>/td3/<seed>/<timestamp>
    timestamp = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    seed = experiment_config.seed
    log_name = os.path.join(task, "td3", str(seed), timestamp)

    sampling_config = RLSamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        batch_size=batch_size,
        step_per_collect=step_per_collect,
        update_per_step=update_per_step,
        start_timesteps=start_timesteps,
        start_timesteps_random=True,
    )

    env_factory = MujocoEnvFactory(task, seed, sampling_config)

    builder = TD3ExperimentBuilder(experiment_config, env_factory, sampling_config)
    td3_params = TD3Params(
        tau=tau,
        gamma=gamma,
        estimation_step=n_step,
        update_actor_freq=update_actor_freq,
        # The noise settings are wrapped in factories rather than passed as plain floats.
        noise_clip=MaxActionScaledFloatEnvParamFactory(noise_clip),
        policy_noise=MaxActionScaledFloatEnvParamFactory(policy_noise),
        exploration_noise=MaxActionScaledGaussianNoiseFactory(exploration_noise),
        actor_lr=actor_lr,
        critic1_lr=critic_lr,
        critic2_lr=critic_lr,
    )
    builder = builder.with_td3_params(td3_params)
    builder = builder.with_actor_factory_default(hidden_sizes)
    builder = builder.with_common_critic_factory_default(hidden_sizes)

    experiment = builder.build()
    experiment.run(log_name)


if __name__ == "__main__":
    CLI(main)
|
||||||
@ -1,21 +1,35 @@
|
|||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Dict, Any, List, Tuple
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
||||||
from tianshou.exploration import BaseNoise
|
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import RLSamplingConfig
|
||||||
from tianshou.highlevel.env import Environments
|
from tianshou.highlevel.env import Environments
|
||||||
from tianshou.highlevel.logger import Logger
|
from tianshou.highlevel.logger import Logger
|
||||||
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
|
from tianshou.highlevel.module import (
|
||||||
|
ActorCriticModuleOpt,
|
||||||
|
ActorFactory,
|
||||||
|
ActorModuleOptFactory,
|
||||||
|
CriticFactory,
|
||||||
|
CriticModuleOptFactory,
|
||||||
|
ModuleOpt,
|
||||||
|
TDevice,
|
||||||
|
)
|
||||||
from tianshou.highlevel.optim import OptimizerFactory
|
from tianshou.highlevel.optim import OptimizerFactory
|
||||||
from tianshou.highlevel.params.alpha import AutoAlphaFactory
|
from tianshou.highlevel.params.alpha import AutoAlphaFactory
|
||||||
|
from tianshou.highlevel.params.env_param import FloatEnvParamFactory
|
||||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
|
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
|
||||||
from tianshou.highlevel.params.policy_params import PPOParams, ParamTransformer, SACParams
|
from tianshou.highlevel.params.noise import NoiseFactory
|
||||||
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy
|
from tianshou.highlevel.params.policy_params import (
|
||||||
|
ParamTransformer,
|
||||||
|
PPOParams,
|
||||||
|
SACParams,
|
||||||
|
TD3Params,
|
||||||
|
)
|
||||||
|
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy, TD3Policy
|
||||||
from tianshou.policy.modelfree.pg import TDistParams
|
from tianshou.policy.modelfree.pg import TDistParams
|
||||||
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
|
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
|
||||||
from tianshou.utils import MultipleLRSchedulers
|
from tianshou.utils import MultipleLRSchedulers
|
||||||
@ -135,7 +149,7 @@ class ParamTransformerDrop(ParamTransformer):
|
|||||||
def __init__(self, *keys: str):
|
def __init__(self, *keys: str):
|
||||||
self.keys = keys
|
self.keys = keys
|
||||||
|
|
||||||
def transform(self, kwargs: Dict[str, Any]) -> None:
|
def transform(self, kwargs: dict[str, Any]) -> None:
|
||||||
for k in self.keys:
|
for k in self.keys:
|
||||||
del kwargs[k]
|
del kwargs[k]
|
||||||
|
|
||||||
@ -144,12 +158,94 @@ class ParamTransformerLRScheduler(ParamTransformer):
|
|||||||
def __init__(self, optim: torch.optim.Optimizer):
|
def __init__(self, optim: torch.optim.Optimizer):
|
||||||
self.optim = optim
|
self.optim = optim
|
||||||
|
|
||||||
def transform(self, kwargs: Dict[str, Any]) -> None:
|
def transform(self, kwargs: dict[str, Any]) -> None:
|
||||||
factory: LRSchedulerFactory | None = self.get(kwargs, "lr_scheduler_factory", drop=True)
|
factory: LRSchedulerFactory | None = self.get(kwargs, "lr_scheduler_factory", drop=True)
|
||||||
kwargs["lr_scheduler"] = factory.create_scheduler(self.optim) if factory is not None else None
|
kwargs["lr_scheduler"] = (
|
||||||
|
factory.create_scheduler(self.optim) if factory is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PPOAgentFactory(OnpolicyAgentFactory):
|
class _ActorMixin:
    """Mixin for agent factories that need an actor module together with its optimizer."""

    def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
        """Combine the actor factory and the optimizer factory into a single module/optimizer factory.

        :param actor_factory: the factory for the actor module
        :param optim_factory: the factory for the actor's optimizer
        """
        module_opt_factory = ActorModuleOptFactory(actor_factory, optim_factory)
        self.actor_module_opt_factory = module_opt_factory

    def create_actor_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        """Create the actor's ``ModuleOpt`` by delegating to the stored factory.

        :param envs: the environments, passed through to the factory
        :param device: the device, passed through to the factory
        :param lr: the learning rate, passed through to the factory
        :return: the result of the factory's ``create_module_opt``
        """
        factory = self.actor_module_opt_factory
        return factory.create_module_opt(envs, device, lr)
|
||||||
|
|
||||||
|
|
||||||
|
class _ActorCriticMixin:
    """Mixin for agent factories whose actor and critic are joined in one ActorCritic module sharing a single optimizer."""

    def __init__(
        self,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
    ):
        """Store the factories and the critic configuration for later module creation.

        :param actor_factory: the factory for the actor module
        :param critic_factory: the factory for the critic module
        :param optim_factory: the factory for the optimizer of the combined module
        :param critic_use_action: forwarded as ``use_action`` when the critic module is created
        """
        self.critic_use_action = critic_use_action
        self.optim_factory = optim_factory
        self.critic_factory = critic_factory
        self.actor_factory = actor_factory

    def create_actor_critic_module_opt(
        self,
        envs: Environments,
        device: TDevice,
        lr: float,
    ) -> ActorCriticModuleOpt:
        """Create actor and critic, wrap them in an ``ActorCritic`` and create one optimizer for the pair.

        :param envs: the environments, passed to both module factories
        :param device: the device, passed to both module factories
        :param lr: the learning rate passed to the optimizer factory
        :return: the combined module together with its optimizer
        """
        # The actor is created before the critic, as both arguments are evaluated left to right.
        combined = ActorCritic(
            self.actor_factory.create_module(envs, device),
            self.critic_factory.create_module(envs, device, use_action=self.critic_use_action),
        )
        return ActorCriticModuleOpt(combined, self.optim_factory.create_optimizer(combined, lr))
|
||||||
|
|
||||||
|
|
||||||
|
class _ActorAndCriticMixin(_ActorMixin):
    """Extension of the actor mixin that adds a critic module with an optimizer of its own."""

    def __init__(
        self,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
    ):
        """Set up the actor part via the parent and add a module/optimizer factory for the critic.

        :param actor_factory: the factory for the actor module
        :param critic_factory: the factory for the critic module
        :param optim_factory: the optimizer factory, used for the actor as well as the critic
        :param critic_use_action: passed on to ``CriticModuleOptFactory``
        """
        super().__init__(actor_factory, optim_factory)
        critic_module_opt_factory = CriticModuleOptFactory(
            critic_factory,
            optim_factory,
            critic_use_action,
        )
        self.critic_module_opt_factory = critic_module_opt_factory

    def create_critic_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        """Create the critic's ``ModuleOpt`` by delegating to the stored critic factory.

        :param envs: the environments, passed through to the factory
        :param device: the device, passed through to the factory
        :param lr: the learning rate, passed through to the factory
        :return: the result of the factory's ``create_module_opt``
        """
        factory = self.critic_module_opt_factory
        return factory.create_module_opt(envs, device, lr)
|
||||||
|
|
||||||
|
|
||||||
|
class _ActorAndDualCriticsMixin(_ActorAndCriticMixin):
    """Extension of the actor-and-critic mixin that adds a second critic with an optimizer of its own."""

    def __init__(
        self,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        critic2_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
    ):
        """Set up actor and first critic via the parent and add a module/optimizer factory for the second critic.

        :param actor_factory: the factory for the actor module
        :param critic_factory: the factory for the first critic module
        :param critic2_factory: the factory for the second critic module
        :param optim_factory: the optimizer factory, shared by all three modules
        :param critic_use_action: passed on to the ``CriticModuleOptFactory`` of both critics
        """
        super().__init__(actor_factory, critic_factory, optim_factory, critic_use_action)
        critic2_module_opt_factory = CriticModuleOptFactory(
            critic2_factory,
            optim_factory,
            critic_use_action,
        )
        self.critic2_module_opt_factory = critic2_module_opt_factory

    def create_critic2_module_opt(
        self,
        envs: Environments,
        device: TDevice,
        lr: float,
    ) -> ModuleOpt:
        """Create the second critic's ``ModuleOpt`` by delegating to the stored factory.

        :param envs: the environments, passed through to the factory
        :param device: the device, passed through to the factory
        :param lr: the learning rate, passed through to the factory
        :return: the result of the factory's ``create_module_opt``
        """
        factory = self.critic2_module_opt_factory
        return factory.create_module_opt(envs, device, lr)
|
||||||
|
|
||||||
|
|
||||||
|
class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: PPOParams,
|
params: PPOParams,
|
||||||
@ -160,27 +256,29 @@ class PPOAgentFactory(OnpolicyAgentFactory):
|
|||||||
dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
|
dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
|
||||||
):
|
):
|
||||||
super().__init__(sampling_config)
|
super().__init__(sampling_config)
|
||||||
self.optimizer_factory = optimizer_factory
|
_ActorCriticMixin.__init__(
|
||||||
self.critic_factory = critic_factory
|
self,
|
||||||
self.actor_factory = actor_factory
|
actor_factory,
|
||||||
self.config = params
|
critic_factory,
|
||||||
|
optimizer_factory,
|
||||||
|
critic_use_action=False,
|
||||||
|
)
|
||||||
|
self.params = params
|
||||||
self.dist_fn = dist_fn
|
self.dist_fn = dist_fn
|
||||||
|
|
||||||
def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
|
def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
|
||||||
actor = self.actor_factory.create_module(envs, device)
|
actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
||||||
critic = self.critic_factory.create_module(envs, device, use_action=False)
|
kwargs = self.params.create_kwargs(
|
||||||
actor_critic = ActorCritic(actor, critic)
|
|
||||||
optim = self.optimizer_factory.create_optimizer(actor_critic, self.config.lr)
|
|
||||||
kwargs = self.config.create_kwargs(
|
|
||||||
ParamTransformerDrop("lr"),
|
ParamTransformerDrop("lr"),
|
||||||
ParamTransformerLRScheduler(optim))
|
ParamTransformerLRScheduler(actor_critic.optim),
|
||||||
|
)
|
||||||
return PPOPolicy(
|
return PPOPolicy(
|
||||||
actor=actor,
|
actor=actor_critic.actor,
|
||||||
critic=critic,
|
critic=actor_critic.critic,
|
||||||
optim=optim,
|
optim=actor_critic.optim,
|
||||||
dist_fn=self.dist_fn,
|
dist_fn=self.dist_fn,
|
||||||
action_space=envs.get_action_space(),
|
action_space=envs.get_action_space(),
|
||||||
**kwargs
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -190,7 +288,7 @@ class ParamTransformerAlpha(ParamTransformer):
|
|||||||
self.optim_factory = optim_factory
|
self.optim_factory = optim_factory
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
def transform(self, kwargs: Dict[str, Any]) -> None:
|
def transform(self, kwargs: dict[str, Any]) -> None:
|
||||||
key = "alpha"
|
key = "alpha"
|
||||||
alpha = self.get(kwargs, key)
|
alpha = self.get(kwargs, key)
|
||||||
if isinstance(alpha, AutoAlphaFactory):
|
if isinstance(alpha, AutoAlphaFactory):
|
||||||
@ -198,13 +296,17 @@ class ParamTransformerAlpha(ParamTransformer):
|
|||||||
|
|
||||||
|
|
||||||
class ParamTransformerMultiLRScheduler(ParamTransformer):
|
class ParamTransformerMultiLRScheduler(ParamTransformer):
|
||||||
def __init__(self, optim_key_list: List[Tuple[torch.optim.Optimizer, str]]):
|
def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]]):
|
||||||
self.optim_key_list = optim_key_list
|
self.optim_key_list = optim_key_list
|
||||||
|
|
||||||
def transform(self, kwargs: Dict[str, Any]) -> None:
|
def transform(self, kwargs: dict[str, Any]) -> None:
|
||||||
lr_schedulers = []
|
lr_schedulers = []
|
||||||
for optim, lr_scheduler_factory_key in self.optim_key_list:
|
for optim, lr_scheduler_factory_key in self.optim_key_list:
|
||||||
lr_scheduler_factory: LRSchedulerFactory | None = self.get(kwargs, lr_scheduler_factory_key, drop=True)
|
lr_scheduler_factory: LRSchedulerFactory | None = self.get(
|
||||||
|
kwargs,
|
||||||
|
lr_scheduler_factory_key,
|
||||||
|
drop=True,
|
||||||
|
)
|
||||||
if lr_scheduler_factory is not None:
|
if lr_scheduler_factory is not None:
|
||||||
lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
|
lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
|
||||||
match len(lr_schedulers):
|
match len(lr_schedulers):
|
||||||
@ -217,7 +319,7 @@ class ParamTransformerMultiLRScheduler(ParamTransformer):
|
|||||||
kwargs["lr_scheduler"] = lr_scheduler
|
kwargs["lr_scheduler"] = lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
class SACAgentFactory(OffpolicyAgentFactory):
|
class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: SACParams,
|
params: SACParams,
|
||||||
@ -228,35 +330,114 @@ class SACAgentFactory(OffpolicyAgentFactory):
|
|||||||
optim_factory: OptimizerFactory,
|
optim_factory: OptimizerFactory,
|
||||||
):
|
):
|
||||||
super().__init__(sampling_config)
|
super().__init__(sampling_config)
|
||||||
self.critic2_factory = critic2_factory
|
_ActorAndDualCriticsMixin.__init__(
|
||||||
self.critic1_factory = critic1_factory
|
self,
|
||||||
self.actor_factory = actor_factory
|
actor_factory,
|
||||||
self.optim_factory = optim_factory
|
critic1_factory,
|
||||||
|
critic2_factory,
|
||||||
|
optim_factory,
|
||||||
|
critic_use_action=True,
|
||||||
|
)
|
||||||
self.params = params
|
self.params = params
|
||||||
|
self.optim_factory = optim_factory
|
||||||
|
|
||||||
def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
||||||
actor = self.actor_factory.create_module(envs, device)
|
actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
|
||||||
critic1 = self.critic1_factory.create_module(envs, device, use_action=True)
|
critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
|
||||||
critic2 = self.critic2_factory.create_module(envs, device, use_action=True)
|
critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
|
||||||
actor_optim = self.optim_factory.create_optimizer(actor, lr=self.params.actor_lr)
|
|
||||||
critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.params.critic1_lr)
|
|
||||||
critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.params.critic2_lr)
|
|
||||||
kwargs = self.params.create_kwargs(
|
kwargs = self.params.create_kwargs(
|
||||||
ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
|
ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
|
||||||
ParamTransformerMultiLRScheduler([
|
ParamTransformerMultiLRScheduler(
|
||||||
(actor_optim, "actor_lr_scheduler_factory"),
|
[
|
||||||
(critic1_optim, "critic1_lr_scheduler_factory"),
|
(actor.optim, "actor_lr_scheduler_factory"),
|
||||||
(critic2_optim, "critic2_lr_scheduler_factory")]
|
(critic1.optim, "critic1_lr_scheduler_factory"),
|
||||||
|
(critic2.optim, "critic2_lr_scheduler_factory"),
|
||||||
|
],
|
||||||
),
|
),
|
||||||
ParamTransformerAlpha(envs, optim_factory=self.optim_factory, device=device))
|
ParamTransformerAlpha(envs, optim_factory=self.optim_factory, device=device),
|
||||||
|
)
|
||||||
return SACPolicy(
|
return SACPolicy(
|
||||||
actor=actor,
|
actor=actor.module,
|
||||||
actor_optim=actor_optim,
|
actor_optim=actor.optim,
|
||||||
critic=critic1,
|
critic=critic1.module,
|
||||||
critic_optim=critic1_optim,
|
critic_optim=critic1.optim,
|
||||||
critic2=critic2,
|
critic2=critic2.module,
|
||||||
critic2_optim=critic2_optim,
|
critic2_optim=critic2.optim,
|
||||||
action_space=envs.get_action_space(),
|
action_space=envs.get_action_space(),
|
||||||
observation_space=envs.get_observation_space(),
|
observation_space=envs.get_observation_space(),
|
||||||
**kwargs
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ParamTransformerNoiseFactory(ParamTransformer):
    """Replaces a ``NoiseFactory`` stored under a given key by the noise object it creates."""

    def __init__(self, key: str, envs: Environments):
        """
        :param key: the kwargs key whose value may be a ``NoiseFactory``
        :param envs: the environments handed to ``create_noise``
        """
        self.envs = envs
        self.key = key

    def transform(self, kwargs: dict[str, Any]) -> None:
        """Resolve the entry under ``self.key`` in place; values that are not a ``NoiseFactory`` are left untouched.

        :param kwargs: the dictionary to modify; it must contain ``self.key``
        """
        candidate = kwargs[self.key]
        if not isinstance(candidate, NoiseFactory):
            return
        kwargs[self.key] = candidate.create_noise(self.envs)
|
||||||
|
|
||||||
|
|
||||||
|
class ParamTransformerFloatEnvParamFactory(ParamTransformer):
    """Replaces a ``FloatEnvParamFactory`` stored under a given key by the parameter value it creates."""

    def __init__(self, key: str, envs: Environments):
        """
        :param key: the kwargs key whose value may be a ``FloatEnvParamFactory``
        :param envs: the environments handed to ``create_param``
        """
        self.envs = envs
        self.key = key

    def transform(self, kwargs: dict[str, Any]) -> None:
        """Resolve the entry under ``self.key`` in place; values that are not a ``FloatEnvParamFactory`` are left untouched.

        :param kwargs: the dictionary to modify; it must contain ``self.key``
        """
        candidate = kwargs[self.key]
        if not isinstance(candidate, FloatEnvParamFactory):
            return
        kwargs[self.key] = candidate.create_param(self.envs)
|
||||||
|
|
||||||
|
|
||||||
|
class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
    """Agent factory that creates a ``TD3Policy`` from one actor and two critics, each with its own optimizer."""

    def __init__(
        self,
        params: TD3Params,
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic1_factory: CriticFactory,
        critic2_factory: CriticFactory,
        optim_factory: OptimizerFactory,
    ):
        """
        :param params: the TD3 parameters from which learning rates and policy kwargs are taken
        :param sampling_config: the sampling configuration, passed to the off-policy base class
        :param actor_factory: the factory for the actor module
        :param critic1_factory: the factory for the first critic module
        :param critic2_factory: the factory for the second critic module
        :param optim_factory: the optimizer factory used for all modules
        """
        super().__init__(sampling_config)
        # The mixin is initialised explicitly; both critics are created with critic_use_action=True.
        _ActorAndDualCriticsMixin.__init__(
            self,
            actor_factory,
            critic1_factory,
            critic2_factory,
            optim_factory,
            critic_use_action=True,
        )
        self.params = params
        self.optim_factory = optim_factory

    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        """Create the modules and optimizers, resolve the parameters and build the ``TD3Policy``.

        :param envs: the environments; used for module creation, for resolving
            environment-dependent parameters and for the action/observation spaces
        :param device: the device on which the modules are created
        :return: the policy
        """
        params = self.params
        actor = self.create_actor_module_opt(envs, device, params.actor_lr)
        critic1 = self.create_critic_module_opt(envs, device, params.critic1_lr)
        critic2 = self.create_critic2_module_opt(envs, device, params.critic2_lr)

        # Each optimizer is paired with the params key of its LR scheduler factory.
        optim_key_list = [
            (actor.optim, "actor_lr_scheduler_factory"),
            (critic1.optim, "critic1_lr_scheduler_factory"),
            (critic2.optim, "critic2_lr_scheduler_factory"),
        ]
        policy_kwargs = params.create_kwargs(
            # The learning rates were already consumed above when creating the optimizers.
            ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
            ParamTransformerMultiLRScheduler(optim_key_list),
            ParamTransformerNoiseFactory("exploration_noise", envs),
            ParamTransformerFloatEnvParamFactory("policy_noise", envs),
            ParamTransformerFloatEnvParamFactory("noise_clip", envs),
        )
        return TD3Policy(
            actor=actor.module,
            actor_optim=actor.optim,
            critic=critic1.module,
            critic_optim=critic1.optim,
            critic2=critic2.module,
            critic2_optim=critic2.optim,
            action_space=envs.get_action_space(),
            observation_space=envs.get_observation_space(),
            **policy_kwargs,
        )
|
||||||
|
|||||||
@ -5,6 +5,7 @@ from dataclasses import dataclass
|
|||||||
class RLSamplingConfig:
|
class RLSamplingConfig:
|
||||||
"""Sampling, epochs, parallelization, buffers, collectors, and batching."""
|
"""Sampling, epochs, parallelization, buffers, collectors, and batching."""
|
||||||
|
|
||||||
|
# TODO: What are reasonable defaults?
|
||||||
num_epochs: int = 100
|
num_epochs: int = 100
|
||||||
step_per_epoch: int = 30000
|
step_per_epoch: int = 30000
|
||||||
batch_size: int = 64
|
batch_size: int = 64
|
||||||
|
|||||||
@ -20,6 +20,14 @@ class EnvType(Enum):
|
|||||||
def is_continuous(self):
|
def is_continuous(self):
|
||||||
return self == EnvType.CONTINUOUS
|
return self == EnvType.CONTINUOUS
|
||||||
|
|
||||||
|
def assert_continuous(self, requiring_entity: Any):
    """Raise an ``AssertionError`` unless this environment type is continuous.

    :param requiring_entity: the object imposing the requirement; it is interpolated
        into the error message
    """
    if self.is_continuous():
        return
    raise AssertionError(f"{requiring_entity} requires continuous environments")
|
||||||
|
|
||||||
|
def assert_discrete(self, requiring_entity: Any):
    """Raise an ``AssertionError`` unless this environment type is discrete.

    :param requiring_entity: the object imposing the requirement; it is interpolated
        into the error message
    """
    if self.is_discrete():
        return
    raise AssertionError(f"{requiring_entity} requires discrete environments")
|
||||||
|
|
||||||
|
|
||||||
class Environments(ABC):
|
class Environments(ABC):
|
||||||
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||||
@ -28,7 +36,10 @@ class Environments(ABC):
|
|||||||
self.test_envs = test_envs
|
self.test_envs = test_envs
|
||||||
|
|
||||||
def info(self) -> dict[str, Any]:
|
def info(self) -> dict[str, Any]:
|
||||||
return {"action_shape": self.get_action_shape(), "state_shape": self.get_observation_shape()}
|
return {
|
||||||
|
"action_shape": self.get_action_shape(),
|
||||||
|
"state_shape": self.get_observation_shape(),
|
||||||
|
}
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_action_shape(self) -> TShape:
|
def get_action_shape(self) -> TShape:
|
||||||
@ -81,7 +92,7 @@ class ContinuousEnvironments(Environments):
|
|||||||
def get_observation_shape(self) -> TShape:
|
def get_observation_shape(self) -> TShape:
|
||||||
return self.state_shape
|
return self.state_shape
|
||||||
|
|
||||||
def get_type(self):
|
def get_type(self) -> EnvType:
|
||||||
return EnvType.CONTINUOUS
|
return EnvType.CONTINUOUS
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,25 +1,31 @@
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
from typing import Generic, TypeVar, Callable
|
from typing import Generic, Self, TypeVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tianshou.data import Collector
|
from tianshou.data import Collector
|
||||||
from tianshou.highlevel.agent import AgentFactory, PPOAgentFactory, SACAgentFactory
|
from tianshou.highlevel.agent import (
|
||||||
|
AgentFactory,
|
||||||
|
PPOAgentFactory,
|
||||||
|
SACAgentFactory,
|
||||||
|
TD3AgentFactory,
|
||||||
|
)
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import RLSamplingConfig
|
||||||
from tianshou.highlevel.env import EnvFactory
|
from tianshou.highlevel.env import EnvFactory
|
||||||
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
|
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
|
||||||
from tianshou.highlevel.module import (
|
from tianshou.highlevel.module import (
|
||||||
ActorFactory,
|
ActorFactory,
|
||||||
|
ContinuousActorType,
|
||||||
CriticFactory,
|
CriticFactory,
|
||||||
DefaultActorFactory,
|
DefaultActorFactory,
|
||||||
DefaultCriticFactory,
|
DefaultCriticFactory,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory
|
from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory
|
||||||
from tianshou.highlevel.params.policy_params import PPOParams, SACParams
|
from tianshou.highlevel.params.policy_params import PPOParams, SACParams, TD3Params
|
||||||
from tianshou.policy import BasePolicy
|
from tianshou.policy import BasePolicy
|
||||||
from tianshou.policy.modelfree.pg import TDistParams
|
from tianshou.policy.modelfree.pg import TDistParams
|
||||||
from tianshou.trainer import BaseTrainer
|
from tianshou.trainer import BaseTrainer
|
||||||
@ -150,7 +156,10 @@ class RLExperimentBuilder:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def with_optim_factory_default(
|
def with_optim_factory_default(
|
||||||
self: TBuilder, betas=(0.9, 0.999), eps=1e-08, weight_decay=0,
|
self: TBuilder,
|
||||||
|
betas=(0.9, 0.999),
|
||||||
|
eps=1e-08,
|
||||||
|
weight_decay=0,
|
||||||
) -> TBuilder:
|
) -> TBuilder:
|
||||||
"""Configures the use of the default optimizer, Adam, with the given parameters.
|
"""Configures the use of the default optimizer, Adam, with the given parameters.
|
||||||
|
|
||||||
@ -174,12 +183,16 @@ class RLExperimentBuilder:
|
|||||||
|
|
||||||
def build(self) -> RLExperiment:
|
def build(self) -> RLExperiment:
|
||||||
return RLExperiment(
|
return RLExperiment(
|
||||||
self._config, self._env_factory, self._create_agent_factory(), self._logger_factory,
|
self._config,
|
||||||
|
self._env_factory,
|
||||||
|
self._create_agent_factory(),
|
||||||
|
self._logger_factory,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class _BuilderMixinActorFactory:
|
class _BuilderMixinActorFactory:
|
||||||
def __init__(self):
|
def __init__(self, continuous_actor_type: ContinuousActorType):
|
||||||
|
self._continuous_actor_type = continuous_actor_type
|
||||||
self._actor_factory: ActorFactory | None = None
|
self._actor_factory: ActorFactory | None = None
|
||||||
|
|
||||||
def with_actor_factory(self: TBuilder, actor_factory: ActorFactory) -> TBuilder:
|
def with_actor_factory(self: TBuilder, actor_factory: ActorFactory) -> TBuilder:
|
||||||
@ -187,7 +200,7 @@ class _BuilderMixinActorFactory:
|
|||||||
self._actor_factory = actor_factory
|
self._actor_factory = actor_factory
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_actor_factory_default(
|
def _with_actor_factory_default(
|
||||||
self: TBuilder,
|
self: TBuilder,
|
||||||
hidden_sizes: Sequence[int],
|
hidden_sizes: Sequence[int],
|
||||||
continuous_unbounded=False,
|
continuous_unbounded=False,
|
||||||
@ -195,6 +208,7 @@ class _BuilderMixinActorFactory:
|
|||||||
) -> TBuilder:
|
) -> TBuilder:
|
||||||
self: TBuilder | _BuilderMixinActorFactory
|
self: TBuilder | _BuilderMixinActorFactory
|
||||||
self._actor_factory = DefaultActorFactory(
|
self._actor_factory = DefaultActorFactory(
|
||||||
|
self._continuous_actor_type,
|
||||||
hidden_sizes,
|
hidden_sizes,
|
||||||
continuous_unbounded=continuous_unbounded,
|
continuous_unbounded=continuous_unbounded,
|
||||||
continuous_conditioned_sigma=continuous_conditioned_sigma,
|
continuous_conditioned_sigma=continuous_conditioned_sigma,
|
||||||
@ -203,11 +217,40 @@ class _BuilderMixinActorFactory:
|
|||||||
|
|
||||||
def _get_actor_factory(self):
|
def _get_actor_factory(self):
|
||||||
if self._actor_factory is None:
|
if self._actor_factory is None:
|
||||||
return DefaultActorFactory()
|
return DefaultActorFactory(self._continuous_actor_type)
|
||||||
else:
|
else:
|
||||||
return self._actor_factory
|
return self._actor_factory
|
||||||
|
|
||||||
|
|
||||||
|
class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
    """Specialization of the actor mixin where, in the continuous case, the actor uses a Gaussian policy."""

    def __init__(self):
        """Fix the continuous actor type to ``ContinuousActorType.GAUSSIAN``."""
        super().__init__(ContinuousActorType.GAUSSIAN)

    def with_actor_factory_default(
        self,
        hidden_sizes: Sequence[int],
        continuous_unbounded=False,
        continuous_conditioned_sigma=False,
    ) -> Self:
        """Configure the default actor factory; all arguments are forwarded to the parent's implementation.

        :param hidden_sizes: the hidden sizes passed on to the default actor factory
        :param continuous_unbounded: passed on unchanged to the default actor factory
        :param continuous_conditioned_sigma: passed on unchanged to the default actor factory
        :return: the value returned by the parent's ``_with_actor_factory_default``
        """
        return super()._with_actor_factory_default(
            hidden_sizes,
            continuous_unbounded=continuous_unbounded,
            continuous_conditioned_sigma=continuous_conditioned_sigma,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactory):
    """Variant of the actor mixin in which continuous actors are deterministic."""

    def __init__(self):
        """Fix the continuous actor type to ``ContinuousActorType.DETERMINISTIC``."""
        actor_type = ContinuousActorType.DETERMINISTIC
        super().__init__(actor_type)

    def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self:
        """Configure the default actor factory, exposing only the hidden sizes.

        All other options of the parent's ``_with_actor_factory_default`` keep their defaults.

        :param hidden_sizes: the hidden sizes passed on to the default actor factory
        :return: the value returned by the parent's ``_with_actor_factory_default``
        """
        result = super()._with_actor_factory_default(hidden_sizes)
        return result
|
||||||
|
|
||||||
|
|
||||||
class _BuilderMixinCriticsFactory:
|
class _BuilderMixinCriticsFactory:
|
||||||
def __init__(self, num_critics: int):
|
def __init__(self, num_critics: int):
|
||||||
self._critic_factories: list[CriticFactory | None] = [None] * num_critics
|
self._critic_factories: list[CriticFactory | None] = [None] * num_critics
|
||||||
@ -238,7 +281,8 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def with_critic_factory_default(
|
def with_critic_factory_default(
|
||||||
self: TBuilder, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
|
self: TBuilder,
|
||||||
|
hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
|
||||||
) -> TBuilder:
|
) -> TBuilder:
|
||||||
self: TBuilder | "_BuilderMixinSingleCriticFactory"
|
self: TBuilder | "_BuilderMixinSingleCriticFactory"
|
||||||
self._with_critic_factory_default(0, hidden_sizes)
|
self._with_critic_factory_default(0, hidden_sizes)
|
||||||
@ -256,7 +300,8 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def with_common_critic_factory_default(
|
def with_common_critic_factory_default(
|
||||||
self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
|
self,
|
||||||
|
hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
|
||||||
) -> TBuilder:
|
) -> TBuilder:
|
||||||
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
||||||
for i in range(len(self._critic_factories)):
|
for i in range(len(self._critic_factories)):
|
||||||
@ -269,7 +314,8 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def with_critic1_factory_default(
|
def with_critic1_factory_default(
|
||||||
self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
|
self,
|
||||||
|
hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
|
||||||
) -> TBuilder:
|
) -> TBuilder:
|
||||||
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
||||||
self._with_critic_factory_default(0, hidden_sizes)
|
self._with_critic_factory_default(0, hidden_sizes)
|
||||||
@ -281,7 +327,8 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def with_critic2_factory_default(
|
def with_critic2_factory_default(
|
||||||
self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
|
self,
|
||||||
|
hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
|
||||||
) -> TBuilder:
|
) -> TBuilder:
|
||||||
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
||||||
self._with_critic_factory_default(0, hidden_sizes)
|
self._with_critic_factory_default(0, hidden_sizes)
|
||||||
@ -289,7 +336,9 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
|
|
||||||
|
|
||||||
class PPOExperimentBuilder(
|
class PPOExperimentBuilder(
|
||||||
RLExperimentBuilder, _BuilderMixinActorFactory, _BuilderMixinSingleCriticFactory,
|
RLExperimentBuilder,
|
||||||
|
_BuilderMixinActorFactory_ContinuousGaussian,
|
||||||
|
_BuilderMixinSingleCriticFactory,
|
||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -299,12 +348,12 @@ class PPOExperimentBuilder(
|
|||||||
dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
|
dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
|
||||||
):
|
):
|
||||||
super().__init__(experiment_config, env_factory, sampling_config)
|
super().__init__(experiment_config, env_factory, sampling_config)
|
||||||
_BuilderMixinActorFactory.__init__(self)
|
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||||
_BuilderMixinSingleCriticFactory.__init__(self)
|
_BuilderMixinSingleCriticFactory.__init__(self)
|
||||||
self._params: PPOParams = PPOParams()
|
self._params: PPOParams = PPOParams()
|
||||||
self._dist_fn = dist_fn
|
self._dist_fn = dist_fn
|
||||||
|
|
||||||
def with_ppo_params(self, params: PPOParams) -> "PPOExperimentBuilder":
|
def with_ppo_params(self, params: PPOParams) -> Self:
|
||||||
self._params = params
|
self._params = params
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@ -316,12 +365,14 @@ class PPOExperimentBuilder(
|
|||||||
self._get_actor_factory(),
|
self._get_actor_factory(),
|
||||||
self._get_critic_factory(0),
|
self._get_critic_factory(0),
|
||||||
self._get_optim_factory(),
|
self._get_optim_factory(),
|
||||||
self._dist_fn
|
self._dist_fn,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SACExperimentBuilder(
|
class SACExperimentBuilder(
|
||||||
RLExperimentBuilder, _BuilderMixinActorFactory, _BuilderMixinDualCriticFactory,
|
RLExperimentBuilder,
|
||||||
|
_BuilderMixinActorFactory_ContinuousGaussian,
|
||||||
|
_BuilderMixinDualCriticFactory,
|
||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -330,14 +381,51 @@ class SACExperimentBuilder(
|
|||||||
sampling_config: RLSamplingConfig,
|
sampling_config: RLSamplingConfig,
|
||||||
):
|
):
|
||||||
super().__init__(experiment_config, env_factory, sampling_config)
|
super().__init__(experiment_config, env_factory, sampling_config)
|
||||||
_BuilderMixinActorFactory.__init__(self)
|
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||||
_BuilderMixinDualCriticFactory.__init__(self)
|
_BuilderMixinDualCriticFactory.__init__(self)
|
||||||
self._params: SACParams = SACParams()
|
self._params: SACParams = SACParams()
|
||||||
|
|
||||||
def with_sac_params(self, params: SACParams) -> "SACExperimentBuilder":
|
def with_sac_params(self, params: SACParams) -> Self:
|
||||||
self._params = params
|
self._params = params
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _create_agent_factory(self) -> AgentFactory:
|
def _create_agent_factory(self) -> AgentFactory:
|
||||||
return SACAgentFactory(self._params, self._sampling_config, self._get_actor_factory(),
|
return SACAgentFactory(
|
||||||
self._get_critic_factory(0), self._get_critic_factory(1), self._get_optim_factory())
|
self._params,
|
||||||
|
self._sampling_config,
|
||||||
|
self._get_actor_factory(),
|
||||||
|
self._get_critic_factory(0),
|
||||||
|
self._get_critic_factory(1),
|
||||||
|
self._get_optim_factory(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TD3ExperimentBuilder(
|
||||||
|
RLExperimentBuilder,
|
||||||
|
_BuilderMixinActorFactory_ContinuousDeterministic,
|
||||||
|
_BuilderMixinDualCriticFactory,
|
||||||
|
):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
experiment_config: RLExperimentConfig,
|
||||||
|
env_factory: EnvFactory,
|
||||||
|
sampling_config: RLSamplingConfig,
|
||||||
|
):
|
||||||
|
super().__init__(experiment_config, env_factory, sampling_config)
|
||||||
|
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
|
||||||
|
_BuilderMixinDualCriticFactory.__init__(self)
|
||||||
|
self._params: TD3Params = TD3Params()
|
||||||
|
|
||||||
|
def with_td3_params(self, params: TD3Params) -> Self:
|
||||||
|
self._params = params
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _create_agent_factory(self) -> AgentFactory:
|
||||||
|
return TD3AgentFactory(
|
||||||
|
self._params,
|
||||||
|
self._sampling_config,
|
||||||
|
self._get_actor_factory(),
|
||||||
|
self._get_critic_factory(0),
|
||||||
|
self._get_critic_factory(1),
|
||||||
|
self._get_optim_factory(),
|
||||||
|
)
|
||||||
|
|||||||
@ -1,13 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Literal
|
from typing import Literal, TypeAlias
|
||||||
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.utils import TensorboardLogger, WandbLogger
|
from tianshou.utils import TensorboardLogger, WandbLogger
|
||||||
|
|
||||||
TLogger = TensorboardLogger | WandbLogger
|
TLogger: TypeAlias = TensorboardLogger | WandbLogger
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -30,7 +30,7 @@ class DefaultLoggerFactory(LoggerFactory):
|
|||||||
wandb_project: str | None = None,
|
wandb_project: str | None = None,
|
||||||
):
|
):
|
||||||
if logger_type == "wandb" and wandb_project is None:
|
if logger_type == "wandb" and wandb_project is None:
|
||||||
raise ValueError("Must provide 'wand_project'")
|
raise ValueError("Must provide 'wandb_project'")
|
||||||
self.log_dir = log_dir
|
self.log_dir = log_dir
|
||||||
self.logger_type = logger_type
|
self.logger_type = logger_type
|
||||||
self.wandb_project = wandb_project
|
self.wandb_project = wandb_project
|
||||||
|
|||||||
@ -1,16 +1,18 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TypeAlias
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from tianshou.highlevel.env import Environments, EnvType
|
from tianshou.highlevel.env import Environments, EnvType
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.highlevel.optim import OptimizerFactory
|
||||||
from tianshou.utils.net.continuous import ActorProb
|
from tianshou.utils.net import continuous
|
||||||
from tianshou.utils.net.continuous import Critic as ContinuousCritic
|
from tianshou.utils.net.common import ActorCritic, Net
|
||||||
|
|
||||||
TDevice = str | int | torch.device
|
TDevice: TypeAlias = str | int | torch.device
|
||||||
|
|
||||||
|
|
||||||
def init_linear_orthogonal(module: torch.nn.Module):
|
def init_linear_orthogonal(module: torch.nn.Module):
|
||||||
@ -24,6 +26,11 @@ def init_linear_orthogonal(module: torch.nn.Module):
|
|||||||
torch.nn.init.zeros_(m.bias)
|
torch.nn.init.zeros_(m.bias)
|
||||||
|
|
||||||
|
|
||||||
|
class ContinuousActorType:
|
||||||
|
GAUSSIAN = "gaussian"
|
||||||
|
DETERMINISTIC = "deterministic"
|
||||||
|
|
||||||
|
|
||||||
class ActorFactory(ABC):
|
class ActorFactory(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
|
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
|
||||||
@ -47,30 +54,36 @@ class ActorFactory(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class DefaultActorFactory(ActorFactory):
|
class DefaultActorFactory(ActorFactory):
|
||||||
|
"""An actor factory which, depending on the type of environment, creates a suitable MLP-based policy."""
|
||||||
|
|
||||||
DEFAULT_HIDDEN_SIZES = (64, 64)
|
DEFAULT_HIDDEN_SIZES = (64, 64)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
continuous_actor_type: ContinuousActorType,
|
||||||
hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
|
hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
|
||||||
continuous_unbounded=False,
|
continuous_unbounded=False,
|
||||||
continuous_conditioned_sigma=False,
|
continuous_conditioned_sigma=False,
|
||||||
):
|
):
|
||||||
|
self.continuous_actor_type = continuous_actor_type
|
||||||
self.continuous_unbounded = continuous_unbounded
|
self.continuous_unbounded = continuous_unbounded
|
||||||
self.continuous_conditioned_sigma = continuous_conditioned_sigma
|
self.continuous_conditioned_sigma = continuous_conditioned_sigma
|
||||||
self.hidden_sizes = hidden_sizes
|
self.hidden_sizes = hidden_sizes
|
||||||
|
|
||||||
"""
|
|
||||||
An actor factory which, depending on the type of environment, creates a suitable MLP-based policy
|
|
||||||
"""
|
|
||||||
|
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
|
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
|
||||||
env_type = envs.get_type()
|
env_type = envs.get_type()
|
||||||
if env_type == EnvType.CONTINUOUS:
|
if env_type == EnvType.CONTINUOUS:
|
||||||
factory = ContinuousActorProbFactory(
|
match self.continuous_actor_type:
|
||||||
self.hidden_sizes,
|
case ContinuousActorType.GAUSSIAN:
|
||||||
unbounded=self.continuous_unbounded,
|
factory = ContinuousActorFactoryGaussian(
|
||||||
conditioned_sigma=self.continuous_conditioned_sigma,
|
self.hidden_sizes,
|
||||||
)
|
unbounded=self.continuous_unbounded,
|
||||||
|
conditioned_sigma=self.continuous_conditioned_sigma,
|
||||||
|
)
|
||||||
|
case ContinuousActorType.DETERMINISTIC:
|
||||||
|
factory = ContinuousActorFactoryDeterministic(self.hidden_sizes)
|
||||||
|
case _:
|
||||||
|
raise ValueError(self.continuous_actor_type)
|
||||||
return factory.create_module(envs, device)
|
return factory.create_module(envs, device)
|
||||||
elif env_type == EnvType.DISCRETE:
|
elif env_type == EnvType.DISCRETE:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -82,8 +95,25 @@ class ContinuousActorFactory(ActorFactory, ABC):
|
|||||||
"""Serves as a type bound for actor factories that are suitable for continuous action spaces."""
|
"""Serves as a type bound for actor factories that are suitable for continuous action spaces."""
|
||||||
|
|
||||||
|
|
||||||
|
class ContinuousActorFactoryDeterministic(ContinuousActorFactory):
|
||||||
|
def __init__(self, hidden_sizes: Sequence[int]):
|
||||||
|
self.hidden_sizes = hidden_sizes
|
||||||
|
|
||||||
class ContinuousActorProbFactory(ContinuousActorFactory):
|
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
|
||||||
|
net_a = Net(
|
||||||
|
envs.get_observation_shape(),
|
||||||
|
hidden_sizes=self.hidden_sizes,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
return continuous.Actor(
|
||||||
|
net_a,
|
||||||
|
envs.get_action_shape(),
|
||||||
|
hidden_sizes=(),
|
||||||
|
device=device,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
class ContinuousActorFactoryGaussian(ContinuousActorFactory):
|
||||||
def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
|
def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
|
||||||
self.hidden_sizes = hidden_sizes
|
self.hidden_sizes = hidden_sizes
|
||||||
self.unbounded = unbounded
|
self.unbounded = unbounded
|
||||||
@ -96,7 +126,7 @@ class ContinuousActorProbFactory(ContinuousActorFactory):
|
|||||||
activation=nn.Tanh,
|
activation=nn.Tanh,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
actor = ActorProb(
|
actor = continuous.ActorProb(
|
||||||
net_a,
|
net_a,
|
||||||
envs.get_action_shape(),
|
envs.get_action_shape(),
|
||||||
unbounded=self.unbounded,
|
unbounded=self.unbounded,
|
||||||
@ -155,6 +185,54 @@ class ContinuousNetCriticFactory(ContinuousCriticFactory):
|
|||||||
activation=nn.Tanh,
|
activation=nn.Tanh,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
critic = ContinuousCritic(net_c, device=device).to(device)
|
critic = continuous.Critic(net_c, device=device).to(device)
|
||||||
init_linear_orthogonal(critic)
|
init_linear_orthogonal(critic)
|
||||||
return critic
|
return critic
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModuleOpt:
|
||||||
|
module: torch.nn.Module
|
||||||
|
optim: torch.optim.Optimizer
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ActorCriticModuleOpt:
|
||||||
|
actor_critic_module: ActorCritic
|
||||||
|
optim: torch.optim.Optimizer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def actor(self):
|
||||||
|
return self.actor_critic_module.actor
|
||||||
|
|
||||||
|
@property
|
||||||
|
def critic(self):
|
||||||
|
return self.actor_critic_module.critic
|
||||||
|
|
||||||
|
|
||||||
|
class ActorModuleOptFactory:
|
||||||
|
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
|
||||||
|
self.actor_factory = actor_factory
|
||||||
|
self.optim_factory = optim_factory
|
||||||
|
|
||||||
|
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
|
||||||
|
actor = self.actor_factory.create_module(envs, device)
|
||||||
|
opt = self.optim_factory.create_optimizer(actor, lr)
|
||||||
|
return ModuleOpt(actor, opt)
|
||||||
|
|
||||||
|
|
||||||
|
class CriticModuleOptFactory:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
critic_factory: CriticFactory,
|
||||||
|
optim_factory: OptimizerFactory,
|
||||||
|
use_action: bool,
|
||||||
|
):
|
||||||
|
self.critic_factory = critic_factory
|
||||||
|
self.optim_factory = optim_factory
|
||||||
|
self.use_action = use_action
|
||||||
|
|
||||||
|
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
|
||||||
|
critic = self.critic_factory.create_module(envs, device, self.use_action)
|
||||||
|
opt = self.optim_factory.create_optimizer(critic, lr)
|
||||||
|
return ModuleOpt(critic, opt)
|
||||||
|
|||||||
@ -1,13 +1,9 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Iterable
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
|
|
||||||
TParams = Iterable[Tensor] | Iterable[dict[str, Any]]
|
|
||||||
|
|
||||||
|
|
||||||
class OptimizerFactory(ABC):
|
class OptimizerFactory(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -38,5 +34,3 @@ class AdamOptimizerFactory(OptimizerFactory):
|
|||||||
eps=self.eps,
|
eps=self.eps,
|
||||||
weight_decay=self.weight_decay,
|
weight_decay=self.weight_decay,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
24
tianshou/highlevel/params/env_param.py
Normal file
24
tianshou/highlevel/params/env_param.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
"""Factories for the generation of environment-dependent parameters."""
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
from tianshou.highlevel.env import ContinuousEnvironments, Environments
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class FloatEnvParamFactory(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def create_param(self, envs: Environments) -> float:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MaxActionScaledFloatEnvParamFactory(FloatEnvParamFactory):
|
||||||
|
def __init__(self, value: float):
|
||||||
|
""":param value: value with which to scale the max action value"""
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def create_param(self, envs: Environments) -> float:
|
||||||
|
envs.get_type().assert_continuous(self)
|
||||||
|
envs: ContinuousEnvironments
|
||||||
|
return envs.max_action * self.value
|
||||||
@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.optim.lr_scheduler import LRScheduler, LambdaLR
|
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||||
|
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import RLSamplingConfig
|
||||||
|
|
||||||
|
|||||||
25
tianshou/highlevel/params/noise.py
Normal file
25
tianshou/highlevel/params/noise.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from tianshou.exploration import BaseNoise, GaussianNoise
|
||||||
|
from tianshou.highlevel.env import ContinuousEnvironments, Environments
|
||||||
|
|
||||||
|
|
||||||
|
class NoiseFactory(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def create_noise(self, envs: Environments) -> BaseNoise:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MaxActionScaledGaussianNoiseFactory(NoiseFactory):
|
||||||
|
"""Factory for Gaussian noise where the standard deviation is a fraction of the maximum action value.
|
||||||
|
|
||||||
|
This factory can only be applied to continuous action spaces.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, std_fraction: float):
|
||||||
|
self.std_fraction = std_fraction
|
||||||
|
|
||||||
|
def create_noise(self, envs: Environments) -> BaseNoise:
|
||||||
|
envs.get_type().assert_continuous(self)
|
||||||
|
envs: ContinuousEnvironments
|
||||||
|
return GaussianNoise(sigma=envs.max_action * self.std_fraction)
|
||||||
@ -1,21 +1,23 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, asdict
|
from dataclasses import asdict, dataclass
|
||||||
from typing import Dict, Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tianshou.exploration import BaseNoise
|
from tianshou.exploration import BaseNoise
|
||||||
from tianshou.highlevel.params.alpha import AutoAlphaFactory
|
from tianshou.highlevel.params.alpha import AutoAlphaFactory
|
||||||
|
from tianshou.highlevel.params.env_param import FloatEnvParamFactory
|
||||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
|
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
|
||||||
|
from tianshou.highlevel.params.noise import NoiseFactory
|
||||||
|
|
||||||
|
|
||||||
class ParamTransformer(ABC):
|
class ParamTransformer(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def transform(self, kwargs: Dict[str, Any]) -> None:
|
def transform(self, kwargs: dict[str, Any]) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get(d: Dict[str, Any], key: str, drop: bool = False) -> Any:
|
def get(d: dict[str, Any], key: str, drop: bool = False) -> Any:
|
||||||
value = d[key]
|
value = d[key]
|
||||||
if drop:
|
if drop:
|
||||||
del d[key]
|
del d[key]
|
||||||
@ -24,7 +26,7 @@ class ParamTransformer(ABC):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Params:
|
class Params:
|
||||||
def create_kwargs(self, *transformers: ParamTransformer) -> Dict[str, Any]:
|
def create_kwargs(self, *transformers: ParamTransformer) -> dict[str, Any]:
|
||||||
d = asdict(self)
|
d = asdict(self)
|
||||||
for transformer in transformers:
|
for transformer in transformers:
|
||||||
transformer.transform(d)
|
transformer.transform(d)
|
||||||
@ -34,6 +36,7 @@ class Params:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class PGParams(Params):
|
class PGParams(Params):
|
||||||
"""Config of general policy-gradient algorithms."""
|
"""Config of general policy-gradient algorithms."""
|
||||||
|
|
||||||
discount_factor: float = 0.99
|
discount_factor: float = 0.99
|
||||||
reward_normalization: bool = False
|
reward_normalization: bool = False
|
||||||
deterministic_eval: bool = False
|
deterministic_eval: bool = False
|
||||||
@ -53,6 +56,7 @@ class A2CParams(PGParams):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class PPOParams(A2CParams):
|
class PPOParams(A2CParams):
|
||||||
"""PPO specific config."""
|
"""PPO specific config."""
|
||||||
|
|
||||||
eps_clip: float = 0.2
|
eps_clip: float = 0.2
|
||||||
dual_clip: float | None = None
|
dual_clip: float | None = None
|
||||||
value_clip: bool = False
|
value_clip: bool = False
|
||||||
@ -63,7 +67,17 @@ class PPOParams(A2CParams):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SACParams(Params):
|
class ActorAndDualCriticsParams(Params):
|
||||||
|
actor_lr: float = 1e-3
|
||||||
|
critic1_lr: float = 1e-3
|
||||||
|
critic2_lr: float = 1e-3
|
||||||
|
actor_lr_scheduler_factory: LRSchedulerFactory | None = None
|
||||||
|
critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
|
||||||
|
critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SACParams(ActorAndDualCriticsParams):
|
||||||
tau: float = 0.005
|
tau: float = 0.005
|
||||||
gamma: float = 0.99
|
gamma: float = 0.99
|
||||||
alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
|
alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
|
||||||
@ -72,9 +86,16 @@ class SACParams(Params):
|
|||||||
deterministic_eval: bool = True
|
deterministic_eval: bool = True
|
||||||
action_scaling: bool = True
|
action_scaling: bool = True
|
||||||
action_bound_method: Literal["clip"] | None = "clip"
|
action_bound_method: Literal["clip"] | None = "clip"
|
||||||
actor_lr: float = 1e-3
|
|
||||||
critic1_lr: float = 1e-3
|
|
||||||
critic2_lr: float = 1e-3
|
@dataclass
|
||||||
actor_lr_scheduler_factory: LRSchedulerFactory | None = None
|
class TD3Params(ActorAndDualCriticsParams):
|
||||||
critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
|
tau: float = 0.005
|
||||||
critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
|
gamma: float = 0.99
|
||||||
|
exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
|
||||||
|
policy_noise: float | FloatEnvParamFactory = 0.2
|
||||||
|
noise_clip: float | FloatEnvParamFactory = 0.5
|
||||||
|
update_actor_freq: int = 2
|
||||||
|
estimation_step: int = 1
|
||||||
|
action_scaling: bool = True
|
||||||
|
action_bound_method: Literal["clip"] | None = "clip"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user