Add high-level API support for TD3
* Created mixins for agent factories to reduce code duplication
* Further factorised params & mixins for experiment factories
* Additional parameter abstractions
* Implement high-level MuJoCo TD3 example
parent 6a739384ee
commit e993425aa1
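At a glance, the new high-level TD3 API added by this commit is used as in the sketch below. It is condensed from the new examples/mujoco/mujoco_td3_hl.py included in this diff; the builder, params, and factory names come from the commit itself, while the concrete values shown are merely the example's defaults:

    experiment = (
        TD3ExperimentBuilder(experiment_config, env_factory, sampling_config)
        .with_td3_params(
            TD3Params(
                tau=0.005,
                gamma=0.99,
                # environment-dependent values are given as factories and are
                # resolved against the envs when the policy is created
                policy_noise=MaxActionScaledFloatEnvParamFactory(0.2),
                noise_clip=MaxActionScaledFloatEnvParamFactory(0.5),
                exploration_noise=MaxActionScaledGaussianNoiseFactory(0.1),
            ),
        )
        .with_actor_factory_default(hidden_sizes)
        .with_common_critic_factory_default(hidden_sizes)
        .build()
    )
    experiment.run(log_name)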
@@ -66,7 +66,7 @@ def main(

    experiment = (
        PPOExperimentBuilder(experiment_config, env_factory, sampling_config, dist_fn)
        .with_ppo_params(
        .with_params(
            PPOParams(
                discount_factor=gamma,
                gae_lambda=gae_lambda,
@@ -7,13 +7,13 @@ from collections.abc import Sequence
from jsonargparse import CLI

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.params.policy_params import SACParams
from tianshou.highlevel.params.alpha import DefaultAutoAlphaFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
    RLExperimentConfig,
    SACExperimentBuilder,
)
from tianshou.highlevel.params.alpha import DefaultAutoAlphaFactory
from tianshou.highlevel.params.policy_params import SACParams


def main(
@@ -70,7 +70,9 @@ def main(
            ),
        )
        .with_actor_factory_default(
            hidden_sizes, continuous_unbounded=True, continuous_conditioned_sigma=True,
            hidden_sizes,
            continuous_unbounded=True,
            continuous_conditioned_sigma=True,
        )
        .with_common_critic_factory_default(hidden_sizes)
        .build()
examples/mujoco/mujoco_td3_hl.py (new file, 85 lines)
@@ -0,0 +1,85 @@
#!/usr/bin/env python3

import datetime
import os
from collections.abc import Sequence

from jsonargparse import CLI

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
    RLExperimentConfig,
    TD3ExperimentBuilder,
)
from tianshou.highlevel.params.env_param import MaxActionScaledFloatEnvParamFactory
from tianshou.highlevel.params.noise import MaxActionScaledGaussianNoiseFactory
from tianshou.highlevel.params.policy_params import TD3Params


def main(
    experiment_config: RLExperimentConfig,
    task: str = "Ant-v3",
    buffer_size: int = 1000000,
    hidden_sizes: Sequence[int] = (256, 256),
    actor_lr: float = 3e-4,
    critic_lr: float = 3e-4,
    gamma: float = 0.99,
    tau: float = 0.005,
    exploration_noise: float = 0.1,
    policy_noise: float = 0.2,
    noise_clip: float = 0.5,
    update_actor_freq: int = 2,
    start_timesteps: int = 25000,
    epoch: int = 200,
    step_per_epoch: int = 5000,
    step_per_collect: int = 1,
    update_per_step: int = 1,
    n_step: int = 1,
    batch_size: int = 256,
    training_num: int = 1,
    test_num: int = 10,
):
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    log_name = os.path.join(task, "td3", str(experiment_config.seed), now)

    sampling_config = RLSamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        batch_size=batch_size,
        step_per_collect=step_per_collect,
        update_per_step=update_per_step,
        start_timesteps=start_timesteps,
        start_timesteps_random=True,
    )

    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)

    experiment = (
        TD3ExperimentBuilder(experiment_config, env_factory, sampling_config)
        .with_td3_params(
            TD3Params(
                tau=tau,
                gamma=gamma,
                estimation_step=n_step,
                update_actor_freq=update_actor_freq,
                noise_clip=MaxActionScaledFloatEnvParamFactory(noise_clip),
                policy_noise=MaxActionScaledFloatEnvParamFactory(policy_noise),
                exploration_noise=MaxActionScaledGaussianNoiseFactory(exploration_noise),
                actor_lr=actor_lr,
                critic1_lr=critic_lr,
                critic2_lr=critic_lr,
            ),
        )
        .with_actor_factory_default(hidden_sizes)
        .with_common_critic_factory_default(hidden_sizes)
        .build()
    )
    experiment.run(log_name)


if __name__ == "__main__":
    CLI(main)
@@ -1,21 +1,35 @@
import os
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Dict, Any, List, Tuple
from typing import Any

import torch

from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.exploration import BaseNoise
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
from tianshou.highlevel.module import (
    ActorCriticModuleOpt,
    ActorFactory,
    ActorModuleOptFactory,
    CriticFactory,
    CriticModuleOptFactory,
    ModuleOpt,
    TDevice,
)
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.env_param import FloatEnvParamFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.policy_params import PPOParams, ParamTransformer, SACParams
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy
from tianshou.highlevel.params.noise import NoiseFactory
from tianshou.highlevel.params.policy_params import (
    ParamTransformer,
    PPOParams,
    SACParams,
    TD3Params,
)
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy, TD3Policy
from tianshou.policy.modelfree.pg import TDistParams
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils import MultipleLRSchedulers
@@ -135,7 +149,7 @@ class ParamTransformerDrop(ParamTransformer):
    def __init__(self, *keys: str):
        self.keys = keys

    def transform(self, kwargs: Dict[str, Any]) -> None:
    def transform(self, kwargs: dict[str, Any]) -> None:
        for k in self.keys:
            del kwargs[k]
@@ -144,12 +158,94 @@ class ParamTransformerLRScheduler(ParamTransformer):
    def __init__(self, optim: torch.optim.Optimizer):
        self.optim = optim

    def transform(self, kwargs: Dict[str, Any]) -> None:
    def transform(self, kwargs: dict[str, Any]) -> None:
        factory: LRSchedulerFactory | None = self.get(kwargs, "lr_scheduler_factory", drop=True)
        kwargs["lr_scheduler"] = factory.create_scheduler(self.optim) if factory is not None else None
        kwargs["lr_scheduler"] = (
            factory.create_scheduler(self.optim) if factory is not None else None
        )


class PPOAgentFactory(OnpolicyAgentFactory):
class _ActorMixin:
    def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
        self.actor_module_opt_factory = ActorModuleOptFactory(actor_factory, optim_factory)

    def create_actor_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        return self.actor_module_opt_factory.create_module_opt(envs, device, lr)


class _ActorCriticMixin:
    """Mixin for agents that use an ActorCritic module with a single optimizer."""

    def __init__(
        self,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
    ):
        self.actor_factory = actor_factory
        self.critic_factory = critic_factory
        self.optim_factory = optim_factory
        self.critic_use_action = critic_use_action

    def create_actor_critic_module_opt(
        self,
        envs: Environments,
        device: TDevice,
        lr: float,
    ) -> ActorCriticModuleOpt:
        actor = self.actor_factory.create_module(envs, device)
        critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action)
        actor_critic = ActorCritic(actor, critic)
        optim = self.optim_factory.create_optimizer(actor_critic, lr)
        return ActorCriticModuleOpt(actor_critic, optim)


class _ActorAndCriticMixin(_ActorMixin):
    def __init__(
        self,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
    ):
        super().__init__(actor_factory, optim_factory)
        self.critic_module_opt_factory = CriticModuleOptFactory(
            critic_factory,
            optim_factory,
            critic_use_action,
        )

    def create_critic_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        return self.critic_module_opt_factory.create_module_opt(envs, device, lr)


class _ActorAndDualCriticsMixin(_ActorAndCriticMixin):
    def __init__(
        self,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        critic2_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
    ):
        super().__init__(actor_factory, critic_factory, optim_factory, critic_use_action)
        self.critic2_module_opt_factory = CriticModuleOptFactory(
            critic2_factory,
            optim_factory,
            critic_use_action,
        )

    def create_critic2_module_opt(
        self,
        envs: Environments,
        device: TDevice,
        lr: float,
    ) -> ModuleOpt:
        return self.critic2_module_opt_factory.create_module_opt(envs, device, lr)


class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
    def __init__(
        self,
        params: PPOParams,
@@ -160,27 +256,29 @@ class PPOAgentFactory(OnpolicyAgentFactory):
        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
    ):
        super().__init__(sampling_config)
        self.optimizer_factory = optimizer_factory
        self.critic_factory = critic_factory
        self.actor_factory = actor_factory
        self.config = params
        _ActorCriticMixin.__init__(
            self,
            actor_factory,
            critic_factory,
            optimizer_factory,
            critic_use_action=False,
        )
        self.params = params
        self.dist_fn = dist_fn

    def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
        actor = self.actor_factory.create_module(envs, device)
        critic = self.critic_factory.create_module(envs, device, use_action=False)
        actor_critic = ActorCritic(actor, critic)
        optim = self.optimizer_factory.create_optimizer(actor_critic, self.config.lr)
        kwargs = self.config.create_kwargs(
        actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
        kwargs = self.params.create_kwargs(
            ParamTransformerDrop("lr"),
            ParamTransformerLRScheduler(optim))
            ParamTransformerLRScheduler(actor_critic.optim),
        )
        return PPOPolicy(
            actor=actor,
            critic=critic,
            optim=optim,
            actor=actor_critic.actor,
            critic=actor_critic.critic,
            optim=actor_critic.optim,
            dist_fn=self.dist_fn,
            action_space=envs.get_action_space(),
            **kwargs
            **kwargs,
        )
@@ -190,7 +288,7 @@ class ParamTransformerAlpha(ParamTransformer):
        self.optim_factory = optim_factory
        self.device = device

    def transform(self, kwargs: Dict[str, Any]) -> None:
    def transform(self, kwargs: dict[str, Any]) -> None:
        key = "alpha"
        alpha = self.get(kwargs, key)
        if isinstance(alpha, AutoAlphaFactory):
@@ -198,13 +296,17 @@ class ParamTransformerAlpha(ParamTransformer):


class ParamTransformerMultiLRScheduler(ParamTransformer):
    def __init__(self, optim_key_list: List[Tuple[torch.optim.Optimizer, str]]):
    def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]]):
        self.optim_key_list = optim_key_list

    def transform(self, kwargs: Dict[str, Any]) -> None:
    def transform(self, kwargs: dict[str, Any]) -> None:
        lr_schedulers = []
        for optim, lr_scheduler_factory_key in self.optim_key_list:
            lr_scheduler_factory: LRSchedulerFactory | None = self.get(kwargs, lr_scheduler_factory_key, drop=True)
            lr_scheduler_factory: LRSchedulerFactory | None = self.get(
                kwargs,
                lr_scheduler_factory_key,
                drop=True,
            )
            if lr_scheduler_factory is not None:
                lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
        match len(lr_schedulers):
@@ -217,7 +319,7 @@ class ParamTransformerMultiLRScheduler(ParamTransformer):
        kwargs["lr_scheduler"] = lr_scheduler


class SACAgentFactory(OffpolicyAgentFactory):
class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
    def __init__(
        self,
        params: SACParams,
@@ -228,35 +330,114 @@ class SACAgentFactory(OffpolicyAgentFactory):
        optim_factory: OptimizerFactory,
    ):
        super().__init__(sampling_config)
        self.critic2_factory = critic2_factory
        self.critic1_factory = critic1_factory
        self.actor_factory = actor_factory
        self.optim_factory = optim_factory
        _ActorAndDualCriticsMixin.__init__(
            self,
            actor_factory,
            critic1_factory,
            critic2_factory,
            optim_factory,
            critic_use_action=True,
        )
        self.params = params
        self.optim_factory = optim_factory

    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        actor = self.actor_factory.create_module(envs, device)
        critic1 = self.critic1_factory.create_module(envs, device, use_action=True)
        critic2 = self.critic2_factory.create_module(envs, device, use_action=True)
        actor_optim = self.optim_factory.create_optimizer(actor, lr=self.params.actor_lr)
        critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.params.critic1_lr)
        critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.params.critic2_lr)
        actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
        critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
        critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
        kwargs = self.params.create_kwargs(
            ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
            ParamTransformerMultiLRScheduler([
                (actor_optim, "actor_lr_scheduler_factory"),
                (critic1_optim, "critic1_lr_scheduler_factory"),
                (critic2_optim, "critic2_lr_scheduler_factory")]
            ParamTransformerMultiLRScheduler(
                [
                    (actor.optim, "actor_lr_scheduler_factory"),
                    (critic1.optim, "critic1_lr_scheduler_factory"),
                    (critic2.optim, "critic2_lr_scheduler_factory"),
                ],
            ),
            ParamTransformerAlpha(envs, optim_factory=self.optim_factory, device=device))
            ParamTransformerAlpha(envs, optim_factory=self.optim_factory, device=device),
        )
        return SACPolicy(
            actor=actor,
            actor_optim=actor_optim,
            critic=critic1,
            critic_optim=critic1_optim,
            critic2=critic2,
            critic2_optim=critic2_optim,
            actor=actor.module,
            actor_optim=actor.optim,
            critic=critic1.module,
            critic_optim=critic1.optim,
            critic2=critic2.module,
            critic2_optim=critic2.optim,
            action_space=envs.get_action_space(),
            observation_space=envs.get_observation_space(),
            **kwargs
            **kwargs,
        )


class ParamTransformerNoiseFactory(ParamTransformer):
    def __init__(self, key: str, envs: Environments):
        self.key = key
        self.envs = envs

    def transform(self, kwargs: dict[str, Any]) -> None:
        value = kwargs[self.key]
        if isinstance(value, NoiseFactory):
            kwargs[self.key] = value.create_noise(self.envs)


class ParamTransformerFloatEnvParamFactory(ParamTransformer):
    def __init__(self, key: str, envs: Environments):
        self.key = key
        self.envs = envs

    def transform(self, kwargs: dict[str, Any]) -> None:
        value = kwargs[self.key]
        if isinstance(value, FloatEnvParamFactory):
            kwargs[self.key] = value.create_param(self.envs)


class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
    def __init__(
        self,
        params: TD3Params,
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic1_factory: CriticFactory,
        critic2_factory: CriticFactory,
        optim_factory: OptimizerFactory,
    ):
        super().__init__(sampling_config)
        _ActorAndDualCriticsMixin.__init__(
            self,
            actor_factory,
            critic1_factory,
            critic2_factory,
            optim_factory,
            critic_use_action=True,
        )
        self.params = params
        self.optim_factory = optim_factory

    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
        critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
        critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
        kwargs = self.params.create_kwargs(
            ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
            ParamTransformerMultiLRScheduler(
                [
                    (actor.optim, "actor_lr_scheduler_factory"),
                    (critic1.optim, "critic1_lr_scheduler_factory"),
                    (critic2.optim, "critic2_lr_scheduler_factory"),
                ],
            ),
            ParamTransformerNoiseFactory("exploration_noise", envs),
            ParamTransformerFloatEnvParamFactory("policy_noise", envs),
            ParamTransformerFloatEnvParamFactory("noise_clip", envs),
        )
        return TD3Policy(
            actor=actor.module,
            actor_optim=actor.optim,
            critic=critic1.module,
            critic_optim=critic1.optim,
            critic2=critic2.module,
            critic2_optim=critic2.optim,
            action_space=envs.get_action_space(),
            observation_space=envs.get_observation_space(),
            **kwargs,
        )
@@ -5,6 +5,7 @@ from dataclasses import dataclass
class RLSamplingConfig:
    """Sampling, epochs, parallelization, buffers, collectors, and batching."""

    # TODO: What are reasonable defaults?
    num_epochs: int = 100
    step_per_epoch: int = 30000
    batch_size: int = 64
@@ -20,6 +20,14 @@ class EnvType(Enum):
    def is_continuous(self):
        return self == EnvType.CONTINUOUS

    def assert_continuous(self, requiring_entity: Any):
        if not self.is_continuous():
            raise AssertionError(f"{requiring_entity} requires continuous environments")

    def assert_discrete(self, requiring_entity: Any):
        if not self.is_discrete():
            raise AssertionError(f"{requiring_entity} requires discrete environments")


class Environments(ABC):
    def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
@@ -28,7 +36,10 @@ class Environments(ABC):
        self.test_envs = test_envs

    def info(self) -> dict[str, Any]:
        return {"action_shape": self.get_action_shape(), "state_shape": self.get_observation_shape()}
        return {
            "action_shape": self.get_action_shape(),
            "state_shape": self.get_observation_shape(),
        }

    @abstractmethod
    def get_action_shape(self) -> TShape:
@@ -81,7 +92,7 @@ class ContinuousEnvironments(Environments):
    def get_observation_shape(self) -> TShape:
        return self.state_shape

    def get_type(self):
    def get_type(self) -> EnvType:
        return EnvType.CONTINUOUS
@@ -1,25 +1,31 @@
from abc import abstractmethod
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from pprint import pprint
from typing import Generic, TypeVar, Callable
from typing import Generic, Self, TypeVar

import numpy as np
import torch

from tianshou.data import Collector
from tianshou.highlevel.agent import AgentFactory, PPOAgentFactory, SACAgentFactory
from tianshou.highlevel.agent import (
    AgentFactory,
    PPOAgentFactory,
    SACAgentFactory,
    TD3AgentFactory,
)
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
from tianshou.highlevel.module import (
    ActorFactory,
    ContinuousActorType,
    CriticFactory,
    DefaultActorFactory,
    DefaultCriticFactory,
)
from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory
from tianshou.highlevel.params.policy_params import PPOParams, SACParams
from tianshou.highlevel.params.policy_params import PPOParams, SACParams, TD3Params
from tianshou.policy import BasePolicy
from tianshou.policy.modelfree.pg import TDistParams
from tianshou.trainer import BaseTrainer
@@ -150,7 +156,10 @@ class RLExperimentBuilder:
        return self

    def with_optim_factory_default(
        self: TBuilder, betas=(0.9, 0.999), eps=1e-08, weight_decay=0,
        self: TBuilder,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=0,
    ) -> TBuilder:
        """Configures the use of the default optimizer, Adam, with the given parameters.

@@ -174,12 +183,16 @@ class RLExperimentBuilder:

    def build(self) -> RLExperiment:
        return RLExperiment(
            self._config, self._env_factory, self._create_agent_factory(), self._logger_factory,
            self._config,
            self._env_factory,
            self._create_agent_factory(),
            self._logger_factory,
        )


class _BuilderMixinActorFactory:
    def __init__(self):
    def __init__(self, continuous_actor_type: ContinuousActorType):
        self._continuous_actor_type = continuous_actor_type
        self._actor_factory: ActorFactory | None = None

    def with_actor_factory(self: TBuilder, actor_factory: ActorFactory) -> TBuilder:
@@ -187,7 +200,7 @@ class _BuilderMixinActorFactory:
        self._actor_factory = actor_factory
        return self

    def with_actor_factory_default(
    def _with_actor_factory_default(
        self: TBuilder,
        hidden_sizes: Sequence[int],
        continuous_unbounded=False,
@@ -195,6 +208,7 @@ class _BuilderMixinActorFactory:
    ) -> TBuilder:
        self: TBuilder | _BuilderMixinActorFactory
        self._actor_factory = DefaultActorFactory(
            self._continuous_actor_type,
            hidden_sizes,
            continuous_unbounded=continuous_unbounded,
            continuous_conditioned_sigma=continuous_conditioned_sigma,
@@ -203,11 +217,40 @@ class _BuilderMixinActorFactory:

    def _get_actor_factory(self):
        if self._actor_factory is None:
            return DefaultActorFactory()
            return DefaultActorFactory(self._continuous_actor_type)
        else:
            return self._actor_factory

class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
    """Specialization of the actor mixin where, in the continuous case, the actor uses a Gaussian (stochastic) policy."""

    def __init__(self):
        super().__init__(ContinuousActorType.GAUSSIAN)

    def with_actor_factory_default(
        self,
        hidden_sizes: Sequence[int],
        continuous_unbounded=False,
        continuous_conditioned_sigma=False,
    ) -> Self:
        return super()._with_actor_factory_default(
            hidden_sizes,
            continuous_unbounded=continuous_unbounded,
            continuous_conditioned_sigma=continuous_conditioned_sigma,
        )


class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactory):
    """Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""

    def __init__(self):
        super().__init__(ContinuousActorType.DETERMINISTIC)

    def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self:
        return super()._with_actor_factory_default(hidden_sizes)


class _BuilderMixinCriticsFactory:
    def __init__(self, num_critics: int):
        self._critic_factories: list[CriticFactory | None] = [None] * num_critics
@@ -238,7 +281,8 @@ class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
        return self

    def with_critic_factory_default(
        self: TBuilder, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
        self: TBuilder,
        hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
    ) -> TBuilder:
        self: TBuilder | "_BuilderMixinSingleCriticFactory"
        self._with_critic_factory_default(0, hidden_sizes)
@@ -256,7 +300,8 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
        return self

    def with_common_critic_factory_default(
        self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
        self,
        hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
    ) -> TBuilder:
        self: TBuilder | "_BuilderMixinDualCriticFactory"
        for i in range(len(self._critic_factories)):
@@ -269,7 +314,8 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
        return self

    def with_critic1_factory_default(
        self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
        self,
        hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
    ) -> TBuilder:
        self: TBuilder | "_BuilderMixinDualCriticFactory"
        self._with_critic_factory_default(0, hidden_sizes)
@@ -281,7 +327,8 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
        return self

    def with_critic2_factory_default(
        self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
        self,
        hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
    ) -> TBuilder:
        self: TBuilder | "_BuilderMixinDualCriticFactory"
        self._with_critic_factory_default(0, hidden_sizes)
@@ -289,7 +336,9 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):


class PPOExperimentBuilder(
    RLExperimentBuilder, _BuilderMixinActorFactory, _BuilderMixinSingleCriticFactory,
    RLExperimentBuilder,
    _BuilderMixinActorFactory_ContinuousGaussian,
    _BuilderMixinSingleCriticFactory,
):
    def __init__(
        self,
@@ -299,12 +348,12 @@ class PPOExperimentBuilder(
        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
    ):
        super().__init__(experiment_config, env_factory, sampling_config)
        _BuilderMixinActorFactory.__init__(self)
        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
        _BuilderMixinSingleCriticFactory.__init__(self)
        self._params: PPOParams = PPOParams()
        self._dist_fn = dist_fn

    def with_ppo_params(self, params: PPOParams) -> "PPOExperimentBuilder":
    def with_ppo_params(self, params: PPOParams) -> Self:
        self._params = params
        return self

@@ -316,12 +365,14 @@ class PPOExperimentBuilder(
            self._get_actor_factory(),
            self._get_critic_factory(0),
            self._get_optim_factory(),
            self._dist_fn
            self._dist_fn,
        )


class SACExperimentBuilder(
    RLExperimentBuilder, _BuilderMixinActorFactory, _BuilderMixinDualCriticFactory,
    RLExperimentBuilder,
    _BuilderMixinActorFactory_ContinuousGaussian,
    _BuilderMixinDualCriticFactory,
):
    def __init__(
        self,
@@ -330,14 +381,51 @@ class SACExperimentBuilder(
        sampling_config: RLSamplingConfig,
    ):
        super().__init__(experiment_config, env_factory, sampling_config)
        _BuilderMixinActorFactory.__init__(self)
        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
        _BuilderMixinDualCriticFactory.__init__(self)
        self._params: SACParams = SACParams()

    def with_sac_params(self, params: SACParams) -> "SACExperimentBuilder":
    def with_sac_params(self, params: SACParams) -> Self:
        self._params = params
        return self

    def _create_agent_factory(self) -> AgentFactory:
        return SACAgentFactory(self._params, self._sampling_config, self._get_actor_factory(),
            self._get_critic_factory(0), self._get_critic_factory(1), self._get_optim_factory())
        return SACAgentFactory(
            self._params,
            self._sampling_config,
            self._get_actor_factory(),
            self._get_critic_factory(0),
            self._get_critic_factory(1),
            self._get_optim_factory(),
        )


class TD3ExperimentBuilder(
    RLExperimentBuilder,
    _BuilderMixinActorFactory_ContinuousDeterministic,
    _BuilderMixinDualCriticFactory,
):
    def __init__(
        self,
        experiment_config: RLExperimentConfig,
        env_factory: EnvFactory,
        sampling_config: RLSamplingConfig,
    ):
        super().__init__(experiment_config, env_factory, sampling_config)
        _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
        _BuilderMixinDualCriticFactory.__init__(self)
        self._params: TD3Params = TD3Params()

    def with_td3_params(self, params: TD3Params) -> Self:
        self._params = params
        return self

    def _create_agent_factory(self) -> AgentFactory:
        return TD3AgentFactory(
            self._params,
            self._sampling_config,
            self._get_actor_factory(),
            self._get_critic_factory(0),
            self._get_critic_factory(1),
            self._get_optim_factory(),
        )
@@ -1,13 +1,13 @@
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Literal
from typing import Literal, TypeAlias

from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import TensorboardLogger, WandbLogger

TLogger = TensorboardLogger | WandbLogger
TLogger: TypeAlias = TensorboardLogger | WandbLogger


@dataclass
@@ -30,7 +30,7 @@ class DefaultLoggerFactory(LoggerFactory):
        wandb_project: str | None = None,
    ):
        if logger_type == "wandb" and wandb_project is None:
            raise ValueError("Must provide 'wand_project'")
            raise ValueError("Must provide 'wandb_project'")
        self.log_dir = log_dir
        self.logger_type = logger_type
        self.wandb_project = wandb_project
@@ -1,16 +1,18 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TypeAlias

import numpy as np
import torch
from torch import nn

from tianshou.highlevel.env import Environments, EnvType
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb
from tianshou.utils.net.continuous import Critic as ContinuousCritic
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net import continuous
from tianshou.utils.net.common import ActorCritic, Net

TDevice = str | int | torch.device
TDevice: TypeAlias = str | int | torch.device


def init_linear_orthogonal(module: torch.nn.Module):
@@ -24,6 +26,11 @@ def init_linear_orthogonal(module: torch.nn.Module):
            torch.nn.init.zeros_(m.bias)


class ContinuousActorType:
    GAUSSIAN = "gaussian"
    DETERMINISTIC = "deterministic"


class ActorFactory(ABC):
    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
@@ -47,30 +54,36 @@ class ActorFactory(ABC):


class DefaultActorFactory(ActorFactory):
    """An actor factory which, depending on the type of environment, creates a suitable MLP-based policy."""

    DEFAULT_HIDDEN_SIZES = (64, 64)

    def __init__(
        self,
        continuous_actor_type: ContinuousActorType,
        hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
        continuous_unbounded=False,
        continuous_conditioned_sigma=False,
    ):
        self.continuous_actor_type = continuous_actor_type
        self.continuous_unbounded = continuous_unbounded
        self.continuous_conditioned_sigma = continuous_conditioned_sigma
        self.hidden_sizes = hidden_sizes

    """
    An actor factory which, depending on the type of environment, creates a suitable MLP-based policy
    """

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        env_type = envs.get_type()
        if env_type == EnvType.CONTINUOUS:
            factory = ContinuousActorProbFactory(
                self.hidden_sizes,
                unbounded=self.continuous_unbounded,
                conditioned_sigma=self.continuous_conditioned_sigma,
            )
            match self.continuous_actor_type:
                case ContinuousActorType.GAUSSIAN:
                    factory = ContinuousActorFactoryGaussian(
                        self.hidden_sizes,
                        unbounded=self.continuous_unbounded,
                        conditioned_sigma=self.continuous_conditioned_sigma,
                    )
                case ContinuousActorType.DETERMINISTIC:
                    factory = ContinuousActorFactoryDeterministic(self.hidden_sizes)
                case _:
                    raise ValueError(self.continuous_actor_type)
            return factory.create_module(envs, device)
        elif env_type == EnvType.DISCRETE:
            raise NotImplementedError
@@ -82,8 +95,25 @@ class ContinuousActorFactory(ActorFactory, ABC):
    """Serves as a type bound for actor factories that are suitable for continuous action spaces."""


class ContinuousActorFactoryDeterministic(ContinuousActorFactory):
    def __init__(self, hidden_sizes: Sequence[int]):
        self.hidden_sizes = hidden_sizes

class ContinuousActorProbFactory(ContinuousActorFactory):
    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        net_a = Net(
            envs.get_observation_shape(),
            hidden_sizes=self.hidden_sizes,
            device=device,
        )
        return continuous.Actor(
            net_a,
            envs.get_action_shape(),
            hidden_sizes=(),
            device=device,
        ).to(device)


class ContinuousActorFactoryGaussian(ContinuousActorFactory):
    def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
        self.hidden_sizes = hidden_sizes
        self.unbounded = unbounded
@@ -96,7 +126,7 @@ class ContinuousActorProbFactory(ContinuousActorFactory):
            activation=nn.Tanh,
            device=device,
        )
        actor = ActorProb(
        actor = continuous.ActorProb(
            net_a,
            envs.get_action_shape(),
            unbounded=self.unbounded,
@@ -155,6 +185,54 @@ class ContinuousNetCriticFactory(ContinuousCriticFactory):
            activation=nn.Tanh,
            device=device,
        )
        critic = ContinuousCritic(net_c, device=device).to(device)
        critic = continuous.Critic(net_c, device=device).to(device)
        init_linear_orthogonal(critic)
        return critic


@dataclass
class ModuleOpt:
    module: torch.nn.Module
    optim: torch.optim.Optimizer


@dataclass
class ActorCriticModuleOpt:
    actor_critic_module: ActorCritic
    optim: torch.optim.Optimizer

    @property
    def actor(self):
        return self.actor_critic_module.actor

    @property
    def critic(self):
        return self.actor_critic_module.critic


class ActorModuleOptFactory:
    def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
        self.actor_factory = actor_factory
        self.optim_factory = optim_factory

    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        actor = self.actor_factory.create_module(envs, device)
        opt = self.optim_factory.create_optimizer(actor, lr)
        return ModuleOpt(actor, opt)


class CriticModuleOptFactory:
    def __init__(
        self,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        use_action: bool,
    ):
        self.critic_factory = critic_factory
        self.optim_factory = optim_factory
        self.use_action = use_action

    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        critic = self.critic_factory.create_module(envs, device, self.use_action)
        opt = self.optim_factory.create_optimizer(critic, lr)
        return ModuleOpt(critic, opt)
@@ -1,13 +1,9 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Any

import torch
from torch import Tensor
from torch.optim import Adam

TParams = Iterable[Tensor] | Iterable[dict[str, Any]]


class OptimizerFactory(ABC):
    @abstractmethod
@@ -38,5 +34,3 @@ class AdamOptimizerFactory(OptimizerFactory):
            eps=self.eps,
            weight_decay=self.weight_decay,
        )
tianshou/highlevel/params/env_param.py (new file, 24 lines)
@@ -0,0 +1,24 @@
"""Factories for the generation of environment-dependent parameters."""
from abc import ABC, abstractmethod
from typing import TypeVar

from tianshou.highlevel.env import ContinuousEnvironments, Environments

T = TypeVar("T")


class FloatEnvParamFactory(ABC):
    @abstractmethod
    def create_param(self, envs: Environments) -> float:
        pass


class MaxActionScaledFloatEnvParamFactory(FloatEnvParamFactory):
    def __init__(self, value: float):
        """:param value: value with which to scale the max action value"""
        self.value = value

    def create_param(self, envs: Environments) -> float:
        envs.get_type().assert_continuous(self)
        envs: ContinuousEnvironments
        return envs.max_action * self.value
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod

import numpy as np
import torch
from torch.optim.lr_scheduler import LRScheduler, LambdaLR
from torch.optim.lr_scheduler import LambdaLR, LRScheduler

from tianshou.highlevel.config import RLSamplingConfig
tianshou/highlevel/params/noise.py (new file, 25 lines)
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod

from tianshou.exploration import BaseNoise, GaussianNoise
from tianshou.highlevel.env import ContinuousEnvironments, Environments


class NoiseFactory(ABC):
    @abstractmethod
    def create_noise(self, envs: Environments) -> BaseNoise:
        pass


class MaxActionScaledGaussianNoiseFactory(NoiseFactory):
    """Factory for Gaussian noise where the standard deviation is a fraction of the maximum action value.

    This factory can only be applied to continuous action spaces.
    """

    def __init__(self, std_fraction: float):
        self.std_fraction = std_fraction

    def create_noise(self, envs: Environments) -> BaseNoise:
        envs.get_type().assert_continuous(self)
        envs: ContinuousEnvironments
        return GaussianNoise(sigma=envs.max_action * self.std_fraction)
@@ -1,21 +1,23 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict
from typing import Dict, Any, Literal
from dataclasses import asdict, dataclass
from typing import Any, Literal

import torch

from tianshou.exploration import BaseNoise
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.env_param import FloatEnvParamFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory


class ParamTransformer(ABC):
    @abstractmethod
    def transform(self, kwargs: Dict[str, Any]) -> None:
    def transform(self, kwargs: dict[str, Any]) -> None:
        pass

    @staticmethod
    def get(d: Dict[str, Any], key: str, drop: bool = False) -> Any:
    def get(d: dict[str, Any], key: str, drop: bool = False) -> Any:
        value = d[key]
        if drop:
            del d[key]
@@ -24,7 +26,7 @@ class ParamTransformer(ABC):

@dataclass
class Params:
    def create_kwargs(self, *transformers: ParamTransformer) -> Dict[str, Any]:
    def create_kwargs(self, *transformers: ParamTransformer) -> dict[str, Any]:
        d = asdict(self)
        for transformer in transformers:
            transformer.transform(d)
@@ -34,6 +36,7 @@ class Params:
@dataclass
class PGParams(Params):
    """Config of general policy-gradient algorithms."""

    discount_factor: float = 0.99
    reward_normalization: bool = False
    deterministic_eval: bool = False
@@ -53,6 +56,7 @@ class A2CParams(PGParams):
@dataclass
class PPOParams(A2CParams):
    """PPO specific config."""

    eps_clip: float = 0.2
    dual_clip: float | None = None
    value_clip: bool = False
@@ -63,7 +67,17 @@ class PPOParams(A2CParams):


@dataclass
class SACParams(Params):
class ActorAndDualCriticsParams(Params):
    actor_lr: float = 1e-3
    critic1_lr: float = 1e-3
    critic2_lr: float = 1e-3
    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic2_lr_scheduler_factory: LRSchedulerFactory | None = None


@dataclass
class SACParams(ActorAndDualCriticsParams):
    tau: float = 0.005
    gamma: float = 0.99
    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
@@ -72,9 +86,16 @@ class SACParams(Params):
    deterministic_eval: bool = True
    action_scaling: bool = True
    action_bound_method: Literal["clip"] | None = "clip"
    actor_lr: float = 1e-3
    critic1_lr: float = 1e-3
    critic2_lr: float = 1e-3
    actor_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
    critic2_lr_scheduler_factory: LRSchedulerFactory | None = None


@dataclass
class TD3Params(ActorAndDualCriticsParams):
    tau: float = 0.005
    gamma: float = 0.99
    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
    policy_noise: float | FloatEnvParamFactory = 0.2
    noise_clip: float | FloatEnvParamFactory = 0.5
    update_actor_freq: int = 2
    estimation_step: int = 1
    action_scaling: bool = True
    action_bound_method: Literal["clip"] | None = "clip"