Add high-level experiment builder interface
This commit is contained in:
parent
4d53d345d6
commit
37dc07e487
@ -9,18 +9,13 @@ from jsonargparse import CLI
|
|||||||
from torch.distributions import Independent, Normal
|
from torch.distributions import Independent, Normal
|
||||||
|
|
||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
from tianshou.highlevel.agent import PPOAgentFactory, PPOConfig
|
from tianshou.highlevel.agent import PPOConfig
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import RLSamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
RLExperiment,
|
PPOExperimentBuilder,
|
||||||
RLExperimentConfig,
|
RLExperimentConfig,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.logger import DefaultLoggerFactory
|
from tianshou.highlevel.optim import LinearLRSchedulerFactory
|
||||||
from tianshou.highlevel.module import (
|
|
||||||
ContinuousActorProbFactory,
|
|
||||||
ContinuousNetCriticFactory,
|
|
||||||
)
|
|
||||||
from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory
|
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
@ -52,7 +47,6 @@ def main(
|
|||||||
):
|
):
|
||||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||||
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
||||||
logger_factory = DefaultLoggerFactory()
|
|
||||||
|
|
||||||
sampling_config = RLSamplingConfig(
|
sampling_config = RLSamplingConfig(
|
||||||
num_epochs=epoch,
|
num_epochs=epoch,
|
||||||
@ -70,7 +64,10 @@ def main(
|
|||||||
def dist_fn(*logits):
|
def dist_fn(*logits):
|
||||||
return Independent(Normal(*logits), 1)
|
return Independent(Normal(*logits), 1)
|
||||||
|
|
||||||
ppo_config = PPOConfig(
|
experiment = (
|
||||||
|
PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
|
||||||
|
.with_ppo_params(
|
||||||
|
PPOConfig(
|
||||||
gamma=gamma,
|
gamma=gamma,
|
||||||
gae_lambda=gae_lambda,
|
gae_lambda=gae_lambda,
|
||||||
action_bound_method=bound_action_method,
|
action_bound_method=bound_action_method,
|
||||||
@ -83,24 +80,17 @@ def main(
|
|||||||
eps_clip=eps_clip,
|
eps_clip=eps_clip,
|
||||||
dual_clip=dual_clip,
|
dual_clip=dual_clip,
|
||||||
recompute_adv=recompute_adv,
|
recompute_adv=recompute_adv,
|
||||||
|
dist_fn=dist_fn,
|
||||||
|
lr=lr,
|
||||||
|
lr_scheduler_factory=LinearLRSchedulerFactory(sampling_config)
|
||||||
|
if lr_decay
|
||||||
|
else None,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
actor_factory = ContinuousActorProbFactory(hidden_sizes)
|
.with_actor_factory_default(hidden_sizes)
|
||||||
critic_factory = ContinuousNetCriticFactory(hidden_sizes)
|
.with_critic_factory_default(hidden_sizes)
|
||||||
optim_factory = AdamOptimizerFactory()
|
.build()
|
||||||
lr_scheduler_factory = LinearLRSchedulerFactory(sampling_config) if lr_decay else None
|
|
||||||
agent_factory = PPOAgentFactory(
|
|
||||||
ppo_config,
|
|
||||||
sampling_config,
|
|
||||||
actor_factory,
|
|
||||||
critic_factory,
|
|
||||||
optim_factory,
|
|
||||||
dist_fn,
|
|
||||||
lr,
|
|
||||||
lr_scheduler_factory,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory)
|
|
||||||
|
|
||||||
experiment.run(log_name)
|
experiment.run(log_name)
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,18 +7,12 @@ from collections.abc import Sequence
|
|||||||
from jsonargparse import CLI
|
from jsonargparse import CLI
|
||||||
|
|
||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACAgentFactory, SACConfig
|
from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACConfig
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import RLSamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
RLExperiment,
|
|
||||||
RLExperimentConfig,
|
RLExperimentConfig,
|
||||||
|
SACExperimentBuilder,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.logger import DefaultLoggerFactory
|
|
||||||
from tianshou.highlevel.module import (
|
|
||||||
ContinuousActorProbFactory,
|
|
||||||
ContinuousNetCriticFactory,
|
|
||||||
)
|
|
||||||
from tianshou.highlevel.optim import AdamOptimizerFactory
|
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
@ -45,7 +39,6 @@ def main(
|
|||||||
):
|
):
|
||||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||||
log_name = os.path.join(task, "sac", str(experiment_config.seed), now)
|
log_name = os.path.join(task, "sac", str(experiment_config.seed), now)
|
||||||
logger_factory = DefaultLoggerFactory()
|
|
||||||
|
|
||||||
sampling_config = RLSamplingConfig(
|
sampling_config = RLSamplingConfig(
|
||||||
num_epochs=epoch,
|
num_epochs=epoch,
|
||||||
@ -62,31 +55,25 @@ def main(
|
|||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||||
|
|
||||||
if auto_alpha:
|
experiment = (
|
||||||
alpha = DefaultAutoAlphaFactory(lr=alpha_lr)
|
SACExperimentBuilder(experiment_config, env_factory, sampling_config)
|
||||||
sac_config = SACConfig(
|
.with_sac_params(
|
||||||
|
SACConfig(
|
||||||
tau=tau,
|
tau=tau,
|
||||||
gamma=gamma,
|
gamma=gamma,
|
||||||
alpha=alpha,
|
alpha=DefaultAutoAlphaFactory(lr=alpha_lr) if auto_alpha else alpha,
|
||||||
estimation_step=n_step,
|
estimation_step=n_step,
|
||||||
actor_lr=actor_lr,
|
actor_lr=actor_lr,
|
||||||
critic1_lr=critic_lr,
|
critic1_lr=critic_lr,
|
||||||
critic2_lr=critic_lr,
|
critic2_lr=critic_lr,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
actor_factory = ContinuousActorProbFactory(hidden_sizes, conditioned_sigma=True)
|
.with_actor_factory_default(
|
||||||
critic_factory = ContinuousNetCriticFactory(hidden_sizes)
|
hidden_sizes, continuous_unbounded=True, continuous_conditioned_sigma=True,
|
||||||
optim_factory = AdamOptimizerFactory()
|
)
|
||||||
agent_factory = SACAgentFactory(
|
.with_common_critic_factory_default(hidden_sizes)
|
||||||
sac_config,
|
.build()
|
||||||
sampling_config,
|
|
||||||
actor_factory,
|
|
||||||
critic_factory,
|
|
||||||
critic_factory,
|
|
||||||
optim_factory,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory)
|
|
||||||
|
|
||||||
experiment.run(log_name)
|
experiment.run(log_name)
|
||||||
|
|
||||||
|
|
||||||
|
@ -161,6 +161,9 @@ class PPOConfig(PGConfig):
|
|||||||
eps_clip: float = 0.2
|
eps_clip: float = 0.2
|
||||||
dual_clip: float | None = None
|
dual_clip: float | None = None
|
||||||
recompute_adv: bool = True
|
recompute_adv: bool = True
|
||||||
|
dist_fn: Callable = None
|
||||||
|
lr: float = 1e-3
|
||||||
|
lr_scheduler_factory: LRSchedulerFactory | None = None
|
||||||
|
|
||||||
|
|
||||||
class PPOAgentFactory(OnpolicyAgentFactory):
|
class PPOAgentFactory(OnpolicyAgentFactory):
|
||||||
@ -171,26 +174,20 @@ class PPOAgentFactory(OnpolicyAgentFactory):
|
|||||||
actor_factory: ActorFactory,
|
actor_factory: ActorFactory,
|
||||||
critic_factory: CriticFactory,
|
critic_factory: CriticFactory,
|
||||||
optimizer_factory: OptimizerFactory,
|
optimizer_factory: OptimizerFactory,
|
||||||
dist_fn,
|
|
||||||
lr: float,
|
|
||||||
lr_scheduler_factory: LRSchedulerFactory | None = None,
|
|
||||||
):
|
):
|
||||||
super().__init__(sampling_config)
|
super().__init__(sampling_config)
|
||||||
self.optimizer_factory = optimizer_factory
|
self.optimizer_factory = optimizer_factory
|
||||||
self.critic_factory = critic_factory
|
self.critic_factory = critic_factory
|
||||||
self.actor_factory = actor_factory
|
self.actor_factory = actor_factory
|
||||||
self.config = config
|
self.config = config
|
||||||
self.lr = lr
|
|
||||||
self.lr_scheduler_factory = lr_scheduler_factory
|
|
||||||
self.dist_fn = dist_fn
|
|
||||||
|
|
||||||
def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
|
def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
|
||||||
actor = self.actor_factory.create_module(envs, device)
|
actor = self.actor_factory.create_module(envs, device)
|
||||||
critic = self.critic_factory.create_module(envs, device, use_action=False)
|
critic = self.critic_factory.create_module(envs, device, use_action=False)
|
||||||
actor_critic = ActorCritic(actor, critic)
|
actor_critic = ActorCritic(actor, critic)
|
||||||
optim = self.optimizer_factory.create_optimizer(actor_critic, self.lr)
|
optim = self.optimizer_factory.create_optimizer(actor_critic, self.config.lr)
|
||||||
if self.lr_scheduler_factory is not None:
|
if self.config.lr_scheduler_factory is not None:
|
||||||
lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim)
|
lr_scheduler = self.config.lr_scheduler_factory.create_scheduler(optim)
|
||||||
else:
|
else:
|
||||||
lr_scheduler = None
|
lr_scheduler = None
|
||||||
return PPOPolicy(
|
return PPOPolicy(
|
||||||
@ -198,7 +195,7 @@ class PPOAgentFactory(OnpolicyAgentFactory):
|
|||||||
actor,
|
actor,
|
||||||
critic,
|
critic,
|
||||||
optim,
|
optim,
|
||||||
dist_fn=self.dist_fn,
|
dist_fn=self.config.dist_fn,
|
||||||
lr_scheduler=lr_scheduler,
|
lr_scheduler=lr_scheduler,
|
||||||
# env-stuff
|
# env-stuff
|
||||||
action_space=envs.get_action_space(),
|
action_space=envs.get_action_space(),
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
@ -9,6 +10,17 @@ from tianshou.env import BaseVectorEnv
|
|||||||
TShape = int | Sequence[int]
|
TShape = int | Sequence[int]
|
||||||
|
|
||||||
|
|
||||||
|
class EnvType(Enum):
    """Enumerates the environment categories known to the high-level API."""

    CONTINUOUS = "continuous"
    DISCRETE = "discrete"

    def is_discrete(self):
        """Return True if this member is ``DISCRETE``."""
        return self is EnvType.DISCRETE

    def is_continuous(self):
        """Return True if this member is ``CONTINUOUS``."""
        return self is EnvType.CONTINUOUS
|
||||||
|
|
||||||
|
|
||||||
class Environments(ABC):
|
class Environments(ABC):
|
||||||
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||||
self.env = env
|
self.env = env
|
||||||
@ -29,6 +41,10 @@ class Environments(ABC):
|
|||||||
def get_action_space(self) -> gym.Space:
|
def get_action_space(self) -> gym.Space:
|
||||||
return self.env.action_space
|
return self.env.action_space
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_type(self) -> EnvType:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ContinuousEnvironments(Environments):
|
class ContinuousEnvironments(Environments):
|
||||||
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||||
@ -62,6 +78,9 @@ class ContinuousEnvironments(Environments):
|
|||||||
def get_state_shape(self) -> TShape:
|
def get_state_shape(self) -> TShape:
|
||||||
return self.state_shape
|
return self.state_shape
|
||||||
|
|
||||||
|
def get_type(self):
|
||||||
|
return EnvType.CONTINUOUS
|
||||||
|
|
||||||
|
|
||||||
class EnvFactory(ABC):
|
class EnvFactory(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from abc import abstractmethod
|
||||||
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
from typing import Generic, TypeVar
|
from typing import Generic, TypeVar
|
||||||
@ -6,9 +8,17 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tianshou.data import Collector
|
from tianshou.data import Collector
|
||||||
from tianshou.highlevel.agent import AgentFactory
|
from tianshou.highlevel.agent import AgentFactory, PPOAgentFactory, PPOConfig, SACConfig
|
||||||
|
from tianshou.highlevel.config import RLSamplingConfig
|
||||||
from tianshou.highlevel.env import EnvFactory
|
from tianshou.highlevel.env import EnvFactory
|
||||||
from tianshou.highlevel.logger import LoggerFactory
|
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
|
||||||
|
from tianshou.highlevel.module import (
|
||||||
|
ActorFactory,
|
||||||
|
CriticFactory,
|
||||||
|
DefaultActorFactory,
|
||||||
|
DefaultCriticFactory,
|
||||||
|
)
|
||||||
|
from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory
|
||||||
from tianshou.policy import BasePolicy
|
from tianshou.policy import BasePolicy
|
||||||
from tianshou.trainer import BaseTrainer
|
from tianshou.trainer import BaseTrainer
|
||||||
|
|
||||||
@ -38,13 +48,15 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
|
|||||||
self,
|
self,
|
||||||
config: RLExperimentConfig,
|
config: RLExperimentConfig,
|
||||||
env_factory: EnvFactory,
|
env_factory: EnvFactory,
|
||||||
logger_factory: LoggerFactory,
|
|
||||||
agent_factory: AgentFactory,
|
agent_factory: AgentFactory,
|
||||||
|
logger_factory: LoggerFactory | None = None,
|
||||||
):
|
):
|
||||||
|
if logger_factory is None:
|
||||||
|
logger_factory = DefaultLoggerFactory()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.env_factory = env_factory
|
self.env_factory = env_factory
|
||||||
self.logger_factory = logger_factory
|
|
||||||
self.agent_factory = agent_factory
|
self.agent_factory = agent_factory
|
||||||
|
self.logger_factory = logger_factory
|
||||||
|
|
||||||
def _set_seed(self) -> None:
|
def _set_seed(self) -> None:
|
||||||
seed = self.config.seed
|
seed = self.config.seed
|
||||||
@ -109,3 +121,214 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
|
|||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
result = test_collector.collect(n_episode=num_episodes, render=render)
|
result = test_collector.collect(n_episode=num_episodes, render=render)
|
||||||
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
|
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
|
||||||
|
|
||||||
|
|
||||||
|
TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder")
|
||||||
|
|
||||||
|
|
||||||
|
class RLExperimentBuilder:
    """Base class for fluent builders that assemble an :class:`RLExperiment`.

    The ``with_*`` methods return the builder itself so that calls can be chained;
    :meth:`build` finally creates the experiment. Subclasses supply the agent
    factory through :meth:`_create_agent_factory`.
    """

    def __init__(
        self,
        experiment_config: RLExperimentConfig,
        env_factory: EnvFactory,
        sampling_config: RLSamplingConfig,
    ):
        """Store the mandatory components.

        :param experiment_config: the experiment configuration passed on to the experiment
        :param env_factory: the environment factory passed on to the experiment
        :param sampling_config: the sampling configuration, kept for use by subclasses
        """
        self._config = experiment_config
        self._env_factory = env_factory
        self._sampling_config = sampling_config
        # Optional components: None means "not explicitly configured".
        self._logger_factory: LoggerFactory | None = None
        self._optim_factory: OptimizerFactory | None = None

    def with_logger_factory(self: TBuilder, logger_factory: LoggerFactory) -> TBuilder:
        """Set the logger factory to hand to the experiment.

        :param logger_factory: the logger factory
        :return: the builder
        """
        self._logger_factory = logger_factory
        return self

    def with_optim_factory(self: TBuilder, optim_factory: OptimizerFactory) -> TBuilder:
        """Set the optimizer factory returned by :meth:`_get_optim_factory`.

        :param optim_factory: the optimizer factory
        :return: the builder
        """
        self._optim_factory = optim_factory
        return self

    def with_optim_factory_default(
        self: TBuilder,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=0,
    ) -> TBuilder:
        """Configures the use of the default optimizer, Adam, with the given parameters.

        :param betas: coefficients used for computing running averages of gradient and its square
        :param eps: term added to the denominator to improve numerical stability
        :param weight_decay: weight decay (L2 penalty)
        :return: the builder
        """
        adam_factory = AdamOptimizerFactory(betas=betas, eps=eps, weight_decay=weight_decay)
        self._optim_factory = adam_factory
        return self

    @abstractmethod
    def _create_agent_factory(self) -> AgentFactory:
        """Create the agent factory for the experiment; to be provided by subclasses."""

    def _get_optim_factory(self) -> OptimizerFactory:
        """Return the configured optimizer factory, or a fresh default Adam factory if unset."""
        configured = self._optim_factory
        if configured is not None:
            return configured
        return AdamOptimizerFactory()

    def build(self) -> RLExperiment:
        """Create the experiment from the components collected so far.

        The logger factory is passed through as-is (possibly ``None``).

        :return: the experiment
        """
        return RLExperiment(
            self._config,
            self._env_factory,
            self._create_agent_factory(),
            self._logger_factory,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class _BuilderMixinActorFactory:
    """Builder mixin that adds configuration of the actor factory."""

    def __init__(self):
        # None means "not configured"; the default is resolved in _get_actor_factory.
        self._actor_factory: ActorFactory | None = None

    def with_actor_factory(self: TBuilder, actor_factory: ActorFactory) -> TBuilder:
        """Set the actor factory to use.

        :param actor_factory: the actor factory
        :return: the builder
        """
        # Bare annotation (a no-op at runtime) that tells type checkers `self` is also this mixin.
        self: TBuilder | _BuilderMixinActorFactory
        self._actor_factory = actor_factory
        return self

    def with_actor_factory_default(
        self: TBuilder,
        hidden_sizes: Sequence[int],
        continuous_unbounded=False,
        continuous_conditioned_sigma=False,
    ) -> TBuilder:
        """Use a :class:`DefaultActorFactory` built from the given settings.

        :param hidden_sizes: forwarded to ``DefaultActorFactory``
        :param continuous_unbounded: forwarded to ``DefaultActorFactory``
        :param continuous_conditioned_sigma: forwarded to ``DefaultActorFactory``
        :return: the builder
        """
        self: TBuilder | _BuilderMixinActorFactory
        default_factory = DefaultActorFactory(
            hidden_sizes,
            continuous_unbounded=continuous_unbounded,
            continuous_conditioned_sigma=continuous_conditioned_sigma,
        )
        self._actor_factory = default_factory
        return self

    def _get_actor_factory(self):
        """Return the configured actor factory, or a fresh ``DefaultActorFactory`` if unset."""
        configured = self._actor_factory
        if configured is not None:
            return configured
        return DefaultActorFactory()
|
||||||
|
|
||||||
|
|
||||||
|
class _BuilderMixinCriticsFactory:
    """Builder mixin that manages a fixed number of critic factory slots, addressed by index."""

    def __init__(self, num_critics: int):
        """Create ``num_critics`` slots, all initially unset.

        :param num_critics: the number of critic factory slots
        """
        self._critic_factories: list[CriticFactory | None] = [None] * num_critics

    def _with_critic_factory(self, idx: int, critic_factory: CriticFactory):
        """Store ``critic_factory`` in slot ``idx`` and return the builder."""
        self._critic_factories[idx] = critic_factory
        return self

    def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]):
        """Store a ``DefaultCriticFactory`` with the given hidden sizes in slot ``idx``.

        :return: the builder
        """
        default_factory = DefaultCriticFactory(hidden_sizes)
        self._critic_factories[idx] = default_factory
        return self

    def _get_critic_factory(self, idx: int):
        """Return the factory in slot ``idx``, or a fresh ``DefaultCriticFactory`` if unset."""
        configured = self._critic_factories[idx]
        if configured is not None:
            return configured
        return DefaultCriticFactory()
|
||||||
|
|
||||||
|
|
||||||
|
class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
    """Builder mixin for algorithms with exactly one critic (stored in slot 0)."""

    def __init__(self):
        super().__init__(1)

    def with_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
        """Set the critic factory to use.

        :param critic_factory: the critic factory
        :return: the builder
        """
        # Bare annotation (a no-op at runtime) that tells type checkers `self` is also this mixin.
        self: TBuilder | "_BuilderMixinSingleCriticFactory"
        # The slot setter returns the builder itself, so its result can be passed on.
        return self._with_critic_factory(0, critic_factory)

    def with_critic_factory_default(
        self: TBuilder,
        hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
    ) -> TBuilder:
        """Use a ``DefaultCriticFactory`` with the given hidden sizes as the critic factory.

        :param hidden_sizes: forwarded to ``DefaultCriticFactory``
        :return: the builder
        """
        self: TBuilder | "_BuilderMixinSingleCriticFactory"
        return self._with_critic_factory_default(0, hidden_sizes)
|
||||||
|
|
||||||
|
|
||||||
|
class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
    """Builder mixin for algorithms with two critics.

    Critic 1 is stored in slot 0 and critic 2 in slot 1 of the underlying slot list.
    """

    def __init__(self):
        super().__init__(2)

    def with_common_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
        """Use the same critic factory for both critics.

        :param critic_factory: the critic factory
        :return: the builder
        """
        # Bare annotation (a no-op at runtime) that tells type checkers `self` is also this mixin.
        self: TBuilder | "_BuilderMixinDualCriticFactory"
        for idx, _ in enumerate(self._critic_factories):
            self._with_critic_factory(idx, critic_factory)
        return self

    def with_common_critic_factory_default(
        self: TBuilder,
        hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
    ) -> TBuilder:
        """Use a ``DefaultCriticFactory`` with the given hidden sizes for both critics.

        :param hidden_sizes: forwarded to ``DefaultCriticFactory``
        :return: the builder
        """
        self: TBuilder | "_BuilderMixinDualCriticFactory"
        for idx, _ in enumerate(self._critic_factories):
            self._with_critic_factory_default(idx, hidden_sizes)
        return self

    def with_critic1_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
        """Set the factory for the first critic.

        :param critic_factory: the critic factory
        :return: the builder
        """
        self: TBuilder | "_BuilderMixinDualCriticFactory"
        self._with_critic_factory(0, critic_factory)
        return self

    def with_critic1_factory_default(
        self: TBuilder,
        hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
    ) -> TBuilder:
        """Use a ``DefaultCriticFactory`` with the given hidden sizes for the first critic.

        :param hidden_sizes: forwarded to ``DefaultCriticFactory``
        :return: the builder
        """
        self: TBuilder | "_BuilderMixinDualCriticFactory"
        self._with_critic_factory_default(0, hidden_sizes)
        return self

    def with_critic2_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
        """Set the factory for the second critic.

        :param critic_factory: the critic factory
        :return: the builder
        """
        self: TBuilder | "_BuilderMixinDualCriticFactory"
        self._with_critic_factory(1, critic_factory)
        return self

    def with_critic2_factory_default(
        self: TBuilder,
        hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
    ) -> TBuilder:
        """Use a ``DefaultCriticFactory`` with the given hidden sizes for the second critic.

        :param hidden_sizes: forwarded to ``DefaultCriticFactory``
        :return: the builder
        """
        self: TBuilder | "_BuilderMixinDualCriticFactory"
        # Slot 1 is the second critic (slot 0 would overwrite the first critic instead).
        self._with_critic_factory_default(1, hidden_sizes)
        return self
|
||||||
|
|
||||||
|
|
||||||
|
class PPOExperimentBuilder(
    RLExperimentBuilder,
    _BuilderMixinActorFactory,
    _BuilderMixinSingleCriticFactory,
):
    """Builder for experiments whose agent is created by :class:`PPOAgentFactory`.

    Adds configuration of the PPO parameters, the actor factory and a single critic factory.
    """

    def __init__(
        self,
        experiment_config: RLExperimentConfig,
        env_factory: EnvFactory,
        sampling_config: RLSamplingConfig,
    ):
        """Initialise the builder with default PPO parameters.

        :param experiment_config: the experiment configuration
        :param env_factory: the environment factory
        :param sampling_config: the sampling configuration
        """
        super().__init__(experiment_config, env_factory, sampling_config)
        # The mixin initialisers take different arguments than the base builder and are
        # therefore invoked explicitly rather than through the cooperative super() chain.
        _BuilderMixinActorFactory.__init__(self)
        _BuilderMixinSingleCriticFactory.__init__(self)
        self._params: PPOConfig = PPOConfig()

    def with_ppo_params(self, params: PPOConfig) -> "PPOExperimentBuilder":
        """Replace the PPO parameters.

        :param params: the PPO configuration
        :return: the builder
        """
        self._params = params
        return self

    # Concrete implementation: it must not be marked with @abstractmethod.
    def _create_agent_factory(self) -> AgentFactory:
        """Create the PPO agent factory from the parameters and factories collected so far."""
        return PPOAgentFactory(
            self._params,
            self._sampling_config,
            self._get_actor_factory(),
            self._get_critic_factory(0),
            self._get_optim_factory(),
        )
|
||||||
|
|
||||||
|
|
||||||
|
class SACExperimentBuilder(
    RLExperimentBuilder,
    _BuilderMixinActorFactory,
    _BuilderMixinDualCriticFactory,
):
    """Builder for SAC experiments.

    Adds configuration of the SAC parameters, the actor factory and two critic factories.
    """

    def __init__(
        self,
        experiment_config: RLExperimentConfig,
        env_factory: EnvFactory,
        sampling_config: RLSamplingConfig,
    ):
        """Initialise the builder with default SAC parameters.

        :param experiment_config: the experiment configuration
        :param env_factory: the environment factory
        :param sampling_config: the sampling configuration
        """
        super().__init__(experiment_config, env_factory, sampling_config)
        # The mixin initialisers take different arguments than the base builder and are
        # therefore invoked explicitly rather than through the cooperative super() chain.
        _BuilderMixinActorFactory.__init__(self)
        _BuilderMixinDualCriticFactory.__init__(self)
        self._params: SACConfig = SACConfig()

    def with_sac_params(self, params: SACConfig) -> "SACExperimentBuilder":
        """Replace the SAC parameters.

        :param params: the SAC configuration
        :return: the builder
        """
        self._params = params
        return self

    def _create_agent_factory(self) -> AgentFactory:
        """Create the SAC agent factory from the parameters and factories collected so far.

        Without this override, ``build()`` would fall back to the base implementation,
        which returns ``None`` instead of an agent factory.
        """
        # Imported locally because the module-level imports do not include SACAgentFactory.
        from tianshou.highlevel.agent import SACAgentFactory

        # NOTE(review): positional order assumed to be
        # (config, sampling_config, actor, critic1, critic2, optimizer) — confirm against
        # SACAgentFactory.__init__.
        return SACAgentFactory(
            self._params,
            self._sampling_config,
            self._get_actor_factory(),
            self._get_critic_factory(0),
            self._get_critic_factory(1),
            self._get_optim_factory(),
        )
|
||||||
|
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from tianshou.highlevel.env import Environments
|
from tianshou.highlevel.env import Environments, EnvType
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.utils.net.continuous import ActorProb
|
from tianshou.utils.net.continuous import ActorProb
|
||||||
from tianshou.utils.net.continuous import Critic as ContinuousCritic
|
from tianshou.utils.net.continuous import Critic as ContinuousCritic
|
||||||
@ -46,8 +46,41 @@ class ActorFactory(ABC):
|
|||||||
m.weight.data.copy_(0.01 * m.weight.data)
|
m.weight.data.copy_(0.01 * m.weight.data)
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultActorFactory(ActorFactory):
    """An actor factory which, depending on the type of environment, creates a suitable MLP-based policy."""

    DEFAULT_HIDDEN_SIZES = (64, 64)

    def __init__(
        self,
        hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
        continuous_unbounded=False,
        continuous_conditioned_sigma=False,
    ):
        """Store the settings used when the actor module is created.

        :param hidden_sizes: forwarded to the concrete actor factory
        :param continuous_unbounded: forwarded as ``unbounded`` to ``ContinuousActorProbFactory``
        :param continuous_conditioned_sigma: forwarded as ``conditioned_sigma`` to
            ``ContinuousActorProbFactory``
        """
        self.continuous_unbounded = continuous_unbounded
        self.continuous_conditioned_sigma = continuous_conditioned_sigma
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        """Create the actor module matching the type reported by ``envs.get_type()``.

        :param envs: the environments for which to create the actor
        :param device: passed on to the concrete factory
        :return: the module created by ``ContinuousActorProbFactory`` (continuous case)
        :raises NotImplementedError: for discrete environments, which are not supported yet
        :raises ValueError: for any other environment type
        """
        env_type = envs.get_type()
        if env_type == EnvType.CONTINUOUS:
            factory = ContinuousActorProbFactory(
                self.hidden_sizes,
                unbounded=self.continuous_unbounded,
                conditioned_sigma=self.continuous_conditioned_sigma,
            )
            return factory.create_module(envs, device)
        elif env_type == EnvType.DISCRETE:
            raise NotImplementedError
        else:
            raise ValueError(f"{env_type} not supported")
|
||||||
|
|
||||||
|
|
||||||
class ContinuousActorFactory(ActorFactory, ABC):
|
class ContinuousActorFactory(ActorFactory, ABC):
|
||||||
pass
|
"""Serves as a type bound for actor factories that are suitable for continuous action spaces."""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ContinuousActorProbFactory(ContinuousActorFactory):
|
class ContinuousActorProbFactory(ContinuousActorFactory):
|
||||||
@ -85,13 +118,31 @@ class CriticFactory(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultCriticFactory(CriticFactory):
    """A critic factory which, depending on the type of environment, creates a suitable MLP-based critic."""

    DEFAULT_HIDDEN_SIZES = (64, 64)

    def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
        """Store the hidden sizes forwarded to the concrete critic factory.

        :param hidden_sizes: the hidden sizes to use
        """
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        """Create the critic module matching the type reported by ``envs.get_type()``.

        :param envs: the environments for which to create the critic
        :param device: passed on to the concrete factory
        :param use_action: passed on to the concrete factory
        :return: the module created by ``ContinuousNetCriticFactory`` (continuous case)
        :raises NotImplementedError: for discrete environments
        :raises ValueError: for any other environment type
        """
        kind = envs.get_type()
        if kind == EnvType.CONTINUOUS:
            delegate = ContinuousNetCriticFactory(self.hidden_sizes)
            return delegate.create_module(envs, device, use_action)
        if kind == EnvType.DISCRETE:
            raise NotImplementedError
        raise ValueError(f"{kind} not supported")
|
||||||
|
|
||||||
|
|
||||||
class ContinuousCriticFactory(CriticFactory, ABC):
|
class ContinuousCriticFactory(CriticFactory, ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ContinuousNetCriticFactory(ContinuousCriticFactory):
|
class ContinuousNetCriticFactory(ContinuousCriticFactory):
|
||||||
def __init__(self, hidden_sizes: Sequence[int], action_shape=0):
|
def __init__(self, hidden_sizes: Sequence[int]):
|
||||||
self.action_shape = action_shape
|
|
||||||
self.hidden_sizes = hidden_sizes
|
self.hidden_sizes = hidden_sizes
|
||||||
|
|
||||||
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
|
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
|
||||||
|
@ -29,8 +29,19 @@ class TorchOptimizerFactory(OptimizerFactory):
|
|||||||
|
|
||||||
|
|
||||||
class AdamOptimizerFactory(OptimizerFactory):
|
class AdamOptimizerFactory(OptimizerFactory):
|
||||||
|
def __init__(self, betas=(0.9, 0.999), eps=1e-08, weight_decay=0):
|
||||||
|
self.weight_decay = weight_decay
|
||||||
|
self.eps = eps
|
||||||
|
self.betas = betas
|
||||||
|
|
||||||
def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
|
def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
|
||||||
return Adam(module.parameters(), lr=lr)
|
return Adam(
|
||||||
|
module.parameters(),
|
||||||
|
lr=lr,
|
||||||
|
betas=self.betas,
|
||||||
|
eps=self.eps,
|
||||||
|
weight_decay=self.weight_decay,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LRSchedulerFactory(ABC):
|
class LRSchedulerFactory(ABC):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user