Add high-level experiment builder interface
This commit is contained in:
parent
4d53d345d6
commit
37dc07e487
@ -9,18 +9,13 @@ from jsonargparse import CLI
|
||||
from torch.distributions import Independent, Normal
|
||||
|
||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||
from tianshou.highlevel.agent import PPOAgentFactory, PPOConfig
|
||||
from tianshou.highlevel.agent import PPOConfig
|
||||
from tianshou.highlevel.config import RLSamplingConfig
|
||||
from tianshou.highlevel.experiment import (
|
||||
RLExperiment,
|
||||
PPOExperimentBuilder,
|
||||
RLExperimentConfig,
|
||||
)
|
||||
from tianshou.highlevel.logger import DefaultLoggerFactory
|
||||
from tianshou.highlevel.module import (
|
||||
ContinuousActorProbFactory,
|
||||
ContinuousNetCriticFactory,
|
||||
)
|
||||
from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory
|
||||
from tianshou.highlevel.optim import LinearLRSchedulerFactory
|
||||
|
||||
|
||||
def main(
|
||||
@ -52,7 +47,6 @@ def main(
|
||||
):
|
||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
||||
logger_factory = DefaultLoggerFactory()
|
||||
|
||||
sampling_config = RLSamplingConfig(
|
||||
num_epochs=epoch,
|
||||
@ -70,37 +64,33 @@ def main(
|
||||
def dist_fn(*logits):
|
||||
return Independent(Normal(*logits), 1)
|
||||
|
||||
ppo_config = PPOConfig(
|
||||
gamma=gamma,
|
||||
gae_lambda=gae_lambda,
|
||||
action_bound_method=bound_action_method,
|
||||
rew_norm=rew_norm,
|
||||
ent_coef=ent_coef,
|
||||
vf_coef=vf_coef,
|
||||
max_grad_norm=max_grad_norm,
|
||||
value_clip=value_clip,
|
||||
norm_adv=norm_adv,
|
||||
eps_clip=eps_clip,
|
||||
dual_clip=dual_clip,
|
||||
recompute_adv=recompute_adv,
|
||||
experiment = (
|
||||
PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
|
||||
.with_ppo_params(
|
||||
PPOConfig(
|
||||
gamma=gamma,
|
||||
gae_lambda=gae_lambda,
|
||||
action_bound_method=bound_action_method,
|
||||
rew_norm=rew_norm,
|
||||
ent_coef=ent_coef,
|
||||
vf_coef=vf_coef,
|
||||
max_grad_norm=max_grad_norm,
|
||||
value_clip=value_clip,
|
||||
norm_adv=norm_adv,
|
||||
eps_clip=eps_clip,
|
||||
dual_clip=dual_clip,
|
||||
recompute_adv=recompute_adv,
|
||||
dist_fn=dist_fn,
|
||||
lr=lr,
|
||||
lr_scheduler_factory=LinearLRSchedulerFactory(sampling_config)
|
||||
if lr_decay
|
||||
else None,
|
||||
),
|
||||
)
|
||||
.with_actor_factory_default(hidden_sizes)
|
||||
.with_critic_factory_default(hidden_sizes)
|
||||
.build()
|
||||
)
|
||||
actor_factory = ContinuousActorProbFactory(hidden_sizes)
|
||||
critic_factory = ContinuousNetCriticFactory(hidden_sizes)
|
||||
optim_factory = AdamOptimizerFactory()
|
||||
lr_scheduler_factory = LinearLRSchedulerFactory(sampling_config) if lr_decay else None
|
||||
agent_factory = PPOAgentFactory(
|
||||
ppo_config,
|
||||
sampling_config,
|
||||
actor_factory,
|
||||
critic_factory,
|
||||
optim_factory,
|
||||
dist_fn,
|
||||
lr,
|
||||
lr_scheduler_factory,
|
||||
)
|
||||
|
||||
experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory)
|
||||
|
||||
experiment.run(log_name)
|
||||
|
||||
|
||||
|
@ -7,18 +7,12 @@ from collections.abc import Sequence
|
||||
from jsonargparse import CLI
|
||||
|
||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||
from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACAgentFactory, SACConfig
|
||||
from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACConfig
|
||||
from tianshou.highlevel.config import RLSamplingConfig
|
||||
from tianshou.highlevel.experiment import (
|
||||
RLExperiment,
|
||||
RLExperimentConfig,
|
||||
SACExperimentBuilder,
|
||||
)
|
||||
from tianshou.highlevel.logger import DefaultLoggerFactory
|
||||
from tianshou.highlevel.module import (
|
||||
ContinuousActorProbFactory,
|
||||
ContinuousNetCriticFactory,
|
||||
)
|
||||
from tianshou.highlevel.optim import AdamOptimizerFactory
|
||||
|
||||
|
||||
def main(
|
||||
@ -45,7 +39,6 @@ def main(
|
||||
):
|
||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
log_name = os.path.join(task, "sac", str(experiment_config.seed), now)
|
||||
logger_factory = DefaultLoggerFactory()
|
||||
|
||||
sampling_config = RLSamplingConfig(
|
||||
num_epochs=epoch,
|
||||
@ -62,31 +55,25 @@ def main(
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||
|
||||
if auto_alpha:
|
||||
alpha = DefaultAutoAlphaFactory(lr=alpha_lr)
|
||||
sac_config = SACConfig(
|
||||
tau=tau,
|
||||
gamma=gamma,
|
||||
alpha=alpha,
|
||||
estimation_step=n_step,
|
||||
actor_lr=actor_lr,
|
||||
critic1_lr=critic_lr,
|
||||
critic2_lr=critic_lr,
|
||||
experiment = (
|
||||
SACExperimentBuilder(experiment_config, env_factory, sampling_config)
|
||||
.with_sac_params(
|
||||
SACConfig(
|
||||
tau=tau,
|
||||
gamma=gamma,
|
||||
alpha=DefaultAutoAlphaFactory(lr=alpha_lr) if auto_alpha else alpha,
|
||||
estimation_step=n_step,
|
||||
actor_lr=actor_lr,
|
||||
critic1_lr=critic_lr,
|
||||
critic2_lr=critic_lr,
|
||||
),
|
||||
)
|
||||
.with_actor_factory_default(
|
||||
hidden_sizes, continuous_unbounded=True, continuous_conditioned_sigma=True,
|
||||
)
|
||||
.with_common_critic_factory_default(hidden_sizes)
|
||||
.build()
|
||||
)
|
||||
actor_factory = ContinuousActorProbFactory(hidden_sizes, conditioned_sigma=True)
|
||||
critic_factory = ContinuousNetCriticFactory(hidden_sizes)
|
||||
optim_factory = AdamOptimizerFactory()
|
||||
agent_factory = SACAgentFactory(
|
||||
sac_config,
|
||||
sampling_config,
|
||||
actor_factory,
|
||||
critic_factory,
|
||||
critic_factory,
|
||||
optim_factory,
|
||||
)
|
||||
|
||||
experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory)
|
||||
|
||||
experiment.run(log_name)
|
||||
|
||||
|
||||
|
@ -161,6 +161,9 @@ class PPOConfig(PGConfig):
|
||||
eps_clip: float = 0.2
|
||||
dual_clip: float | None = None
|
||||
recompute_adv: bool = True
|
||||
dist_fn: Callable = None
|
||||
lr: float = 1e-3
|
||||
lr_scheduler_factory: LRSchedulerFactory | None = None
|
||||
|
||||
|
||||
class PPOAgentFactory(OnpolicyAgentFactory):
|
||||
@ -171,26 +174,20 @@ class PPOAgentFactory(OnpolicyAgentFactory):
|
||||
actor_factory: ActorFactory,
|
||||
critic_factory: CriticFactory,
|
||||
optimizer_factory: OptimizerFactory,
|
||||
dist_fn,
|
||||
lr: float,
|
||||
lr_scheduler_factory: LRSchedulerFactory | None = None,
|
||||
):
|
||||
super().__init__(sampling_config)
|
||||
self.optimizer_factory = optimizer_factory
|
||||
self.critic_factory = critic_factory
|
||||
self.actor_factory = actor_factory
|
||||
self.config = config
|
||||
self.lr = lr
|
||||
self.lr_scheduler_factory = lr_scheduler_factory
|
||||
self.dist_fn = dist_fn
|
||||
|
||||
def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
|
||||
actor = self.actor_factory.create_module(envs, device)
|
||||
critic = self.critic_factory.create_module(envs, device, use_action=False)
|
||||
actor_critic = ActorCritic(actor, critic)
|
||||
optim = self.optimizer_factory.create_optimizer(actor_critic, self.lr)
|
||||
if self.lr_scheduler_factory is not None:
|
||||
lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim)
|
||||
optim = self.optimizer_factory.create_optimizer(actor_critic, self.config.lr)
|
||||
if self.config.lr_scheduler_factory is not None:
|
||||
lr_scheduler = self.config.lr_scheduler_factory.create_scheduler(optim)
|
||||
else:
|
||||
lr_scheduler = None
|
||||
return PPOPolicy(
|
||||
@ -198,7 +195,7 @@ class PPOAgentFactory(OnpolicyAgentFactory):
|
||||
actor,
|
||||
critic,
|
||||
optim,
|
||||
dist_fn=self.dist_fn,
|
||||
dist_fn=self.config.dist_fn,
|
||||
lr_scheduler=lr_scheduler,
|
||||
# env-stuff
|
||||
action_space=envs.get_action_space(),
|
||||
|
@ -1,5 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
@ -9,6 +10,17 @@ from tianshou.env import BaseVectorEnv
|
||||
TShape = int | Sequence[int]
|
||||
|
||||
|
||||
class EnvType(Enum):
|
||||
CONTINUOUS = "continuous"
|
||||
DISCRETE = "discrete"
|
||||
|
||||
def is_discrete(self):
|
||||
return self == EnvType.DISCRETE
|
||||
|
||||
def is_continuous(self):
|
||||
return self == EnvType.CONTINUOUS
|
||||
|
||||
|
||||
class Environments(ABC):
|
||||
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||
self.env = env
|
||||
@ -29,6 +41,10 @@ class Environments(ABC):
|
||||
def get_action_space(self) -> gym.Space:
|
||||
return self.env.action_space
|
||||
|
||||
@abstractmethod
|
||||
def get_type(self) -> EnvType:
|
||||
pass
|
||||
|
||||
|
||||
class ContinuousEnvironments(Environments):
|
||||
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||
@ -62,6 +78,9 @@ class ContinuousEnvironments(Environments):
|
||||
def get_state_shape(self) -> TShape:
|
||||
return self.state_shape
|
||||
|
||||
def get_type(self):
|
||||
return EnvType.CONTINUOUS
|
||||
|
||||
|
||||
class EnvFactory(ABC):
|
||||
@abstractmethod
|
||||
|
@ -1,3 +1,5 @@
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from pprint import pprint
|
||||
from typing import Generic, TypeVar
|
||||
@ -6,9 +8,17 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from tianshou.data import Collector
|
||||
from tianshou.highlevel.agent import AgentFactory
|
||||
from tianshou.highlevel.agent import AgentFactory, PPOAgentFactory, PPOConfig, SACConfig
|
||||
from tianshou.highlevel.config import RLSamplingConfig
|
||||
from tianshou.highlevel.env import EnvFactory
|
||||
from tianshou.highlevel.logger import LoggerFactory
|
||||
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
|
||||
from tianshou.highlevel.module import (
|
||||
ActorFactory,
|
||||
CriticFactory,
|
||||
DefaultActorFactory,
|
||||
DefaultCriticFactory,
|
||||
)
|
||||
from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.trainer import BaseTrainer
|
||||
|
||||
@ -38,13 +48,15 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
|
||||
self,
|
||||
config: RLExperimentConfig,
|
||||
env_factory: EnvFactory,
|
||||
logger_factory: LoggerFactory,
|
||||
agent_factory: AgentFactory,
|
||||
logger_factory: LoggerFactory | None = None,
|
||||
):
|
||||
if logger_factory is None:
|
||||
logger_factory = DefaultLoggerFactory()
|
||||
self.config = config
|
||||
self.env_factory = env_factory
|
||||
self.logger_factory = logger_factory
|
||||
self.agent_factory = agent_factory
|
||||
self.logger_factory = logger_factory
|
||||
|
||||
def _set_seed(self) -> None:
|
||||
seed = self.config.seed
|
||||
@ -109,3 +121,214 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=num_episodes, render=render)
|
||||
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
|
||||
|
||||
|
||||
TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder")
|
||||
|
||||
|
||||
class RLExperimentBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
experiment_config: RLExperimentConfig,
|
||||
env_factory: EnvFactory,
|
||||
sampling_config: RLSamplingConfig,
|
||||
):
|
||||
self._config = experiment_config
|
||||
self._env_factory = env_factory
|
||||
self._sampling_config = sampling_config
|
||||
self._logger_factory: LoggerFactory | None = None
|
||||
self._optim_factory: OptimizerFactory | None = None
|
||||
|
||||
def with_logger_factory(self: TBuilder, logger_factory: LoggerFactory) -> TBuilder:
|
||||
self._logger_factory = logger_factory
|
||||
return self
|
||||
|
||||
def with_optim_factory(self: TBuilder, optim_factory: OptimizerFactory) -> TBuilder:
|
||||
self._optim_factory = optim_factory
|
||||
return self
|
||||
|
||||
def with_optim_factory_default(
|
||||
self: TBuilder, betas=(0.9, 0.999), eps=1e-08, weight_decay=0,
|
||||
) -> TBuilder:
|
||||
"""Configures the use of the default optimizer, Adam, with the given parameters.
|
||||
|
||||
:param betas: coefficients used for computing running averages of gradient and its square
|
||||
:param eps: term added to the denominator to improve numerical stability
|
||||
:param weight_decay: weight decay (L2 penalty)
|
||||
:return: the builder
|
||||
"""
|
||||
self._optim_factory = AdamOptimizerFactory(betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
return self
|
||||
|
||||
@abstractmethod
|
||||
def _create_agent_factory(self) -> AgentFactory:
|
||||
pass
|
||||
|
||||
def _get_optim_factory(self) -> OptimizerFactory:
|
||||
if self._optim_factory is None:
|
||||
return AdamOptimizerFactory()
|
||||
else:
|
||||
return self._optim_factory
|
||||
|
||||
def build(self) -> RLExperiment:
|
||||
return RLExperiment(
|
||||
self._config, self._env_factory, self._create_agent_factory(), self._logger_factory,
|
||||
)
|
||||
|
||||
|
||||
class _BuilderMixinActorFactory:
|
||||
def __init__(self):
|
||||
self._actor_factory: ActorFactory | None = None
|
||||
|
||||
def with_actor_factory(self: TBuilder, actor_factory: ActorFactory) -> TBuilder:
|
||||
self: TBuilder | _BuilderMixinActorFactory
|
||||
self._actor_factory = actor_factory
|
||||
return self
|
||||
|
||||
def with_actor_factory_default(
|
||||
self: TBuilder,
|
||||
hidden_sizes: Sequence[int],
|
||||
continuous_unbounded=False,
|
||||
continuous_conditioned_sigma=False,
|
||||
) -> TBuilder:
|
||||
self: TBuilder | _BuilderMixinActorFactory
|
||||
self._actor_factory = DefaultActorFactory(
|
||||
hidden_sizes,
|
||||
continuous_unbounded=continuous_unbounded,
|
||||
continuous_conditioned_sigma=continuous_conditioned_sigma,
|
||||
)
|
||||
return self
|
||||
|
||||
def _get_actor_factory(self):
|
||||
if self._actor_factory is None:
|
||||
return DefaultActorFactory()
|
||||
else:
|
||||
return self._actor_factory
|
||||
|
||||
|
||||
class _BuilderMixinCriticsFactory:
|
||||
def __init__(self, num_critics: int):
|
||||
self._critic_factories: list[CriticFactory | None] = [None] * num_critics
|
||||
|
||||
def _with_critic_factory(self, idx: int, critic_factory: CriticFactory):
|
||||
self._critic_factories[idx] = critic_factory
|
||||
return self
|
||||
|
||||
def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]):
|
||||
self._critic_factories[idx] = DefaultCriticFactory(hidden_sizes)
|
||||
return self
|
||||
|
||||
def _get_critic_factory(self, idx: int):
|
||||
factory = self._critic_factories[idx]
|
||||
if factory is None:
|
||||
return DefaultCriticFactory()
|
||||
else:
|
||||
return factory
|
||||
|
||||
|
||||
class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
|
||||
def __init__(self):
|
||||
super().__init__(1)
|
||||
|
||||
def with_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
|
||||
self: TBuilder | "_BuilderMixinSingleCriticFactory"
|
||||
self._with_critic_factory(0, critic_factory)
|
||||
return self
|
||||
|
||||
def with_critic_factory_default(
|
||||
self: TBuilder, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
|
||||
) -> TBuilder:
|
||||
self: TBuilder | "_BuilderMixinSingleCriticFactory"
|
||||
self._with_critic_factory_default(0, hidden_sizes)
|
||||
return self
|
||||
|
||||
|
||||
class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
||||
def __init__(self):
|
||||
super().__init__(2)
|
||||
|
||||
def with_common_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
|
||||
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
||||
for i in range(len(self._critic_factories)):
|
||||
self._with_critic_factory(i, critic_factory)
|
||||
return self
|
||||
|
||||
def with_common_critic_factory_default(
|
||||
self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
|
||||
) -> TBuilder:
|
||||
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
||||
for i in range(len(self._critic_factories)):
|
||||
self._with_critic_factory_default(i, hidden_sizes)
|
||||
return self
|
||||
|
||||
def with_critic1_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
|
||||
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
||||
self._with_critic_factory(0, critic_factory)
|
||||
return self
|
||||
|
||||
def with_critic1_factory_default(
|
||||
self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
|
||||
) -> TBuilder:
|
||||
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
||||
self._with_critic_factory_default(0, hidden_sizes)
|
||||
return self
|
||||
|
||||
def with_critic2_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
|
||||
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
||||
self._with_critic_factory(1, critic_factory)
|
||||
return self
|
||||
|
||||
def with_critic2_factory_default(
|
||||
self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
|
||||
) -> TBuilder:
|
||||
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
||||
self._with_critic_factory_default(0, hidden_sizes)
|
||||
return self
|
||||
|
||||
|
||||
class PPOExperimentBuilder(
|
||||
RLExperimentBuilder, _BuilderMixinActorFactory, _BuilderMixinSingleCriticFactory,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
experiment_config: RLExperimentConfig,
|
||||
env_factory: EnvFactory,
|
||||
sampling_config: RLSamplingConfig,
|
||||
):
|
||||
super().__init__(experiment_config, env_factory, sampling_config)
|
||||
_BuilderMixinActorFactory.__init__(self)
|
||||
_BuilderMixinSingleCriticFactory.__init__(self)
|
||||
self._params: PPOConfig = PPOConfig()
|
||||
|
||||
def with_ppo_params(self, params: PPOConfig) -> "PPOExperimentBuilder":
|
||||
self._params = params
|
||||
return self
|
||||
|
||||
@abstractmethod
|
||||
def _create_agent_factory(self) -> AgentFactory:
|
||||
return PPOAgentFactory(
|
||||
self._params,
|
||||
self._sampling_config,
|
||||
self._get_actor_factory(),
|
||||
self._get_critic_factory(0),
|
||||
self._get_optim_factory(),
|
||||
)
|
||||
|
||||
|
||||
class SACExperimentBuilder(
|
||||
RLExperimentBuilder, _BuilderMixinActorFactory, _BuilderMixinDualCriticFactory,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
experiment_config: RLExperimentConfig,
|
||||
env_factory: EnvFactory,
|
||||
sampling_config: RLSamplingConfig,
|
||||
):
|
||||
super().__init__(experiment_config, env_factory, sampling_config)
|
||||
_BuilderMixinActorFactory.__init__(self)
|
||||
_BuilderMixinDualCriticFactory.__init__(self)
|
||||
self._params: SACConfig = SACConfig()
|
||||
|
||||
def with_sac_params(self, params: SACConfig) -> "SACExperimentBuilder":
|
||||
self._params = params
|
||||
return self
|
||||
|
@ -5,7 +5,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from tianshou.highlevel.env import Environments
|
||||
from tianshou.highlevel.env import Environments, EnvType
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.continuous import ActorProb
|
||||
from tianshou.utils.net.continuous import Critic as ContinuousCritic
|
||||
@ -46,8 +46,41 @@ class ActorFactory(ABC):
|
||||
m.weight.data.copy_(0.01 * m.weight.data)
|
||||
|
||||
|
||||
class DefaultActorFactory(ActorFactory):
|
||||
DEFAULT_HIDDEN_SIZES = (64, 64)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
|
||||
continuous_unbounded=False,
|
||||
continuous_conditioned_sigma=False,
|
||||
):
|
||||
self.continuous_unbounded = continuous_unbounded
|
||||
self.continuous_conditioned_sigma = continuous_conditioned_sigma
|
||||
self.hidden_sizes = hidden_sizes
|
||||
|
||||
"""
|
||||
An actor factory which, depending on the type of environment, creates a suitable MLP-based policy
|
||||
"""
|
||||
|
||||
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
|
||||
env_type = envs.get_type()
|
||||
if env_type == EnvType.CONTINUOUS:
|
||||
factory = ContinuousActorProbFactory(
|
||||
self.hidden_sizes,
|
||||
unbounded=self.continuous_unbounded,
|
||||
conditioned_sigma=self.continuous_conditioned_sigma,
|
||||
)
|
||||
return factory.create_module(envs, device)
|
||||
elif env_type == EnvType.DISCRETE:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise ValueError(f"{env_type} not supported")
|
||||
|
||||
|
||||
class ContinuousActorFactory(ActorFactory, ABC):
|
||||
pass
|
||||
"""Serves as a type bound for actor factories that are suitable for continuous action spaces."""
|
||||
|
||||
|
||||
|
||||
class ContinuousActorProbFactory(ContinuousActorFactory):
|
||||
@ -85,13 +118,31 @@ class CriticFactory(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class DefaultCriticFactory(CriticFactory):
|
||||
"""A critic factory which, depending on the type of environment, creates a suitable MLP-based critic."""
|
||||
|
||||
DEFAULT_HIDDEN_SIZES = (64, 64)
|
||||
|
||||
def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
|
||||
self.hidden_sizes = hidden_sizes
|
||||
|
||||
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
|
||||
env_type = envs.get_type()
|
||||
if env_type == EnvType.CONTINUOUS:
|
||||
factory = ContinuousNetCriticFactory(self.hidden_sizes)
|
||||
return factory.create_module(envs, device, use_action)
|
||||
elif env_type == EnvType.DISCRETE:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise ValueError(f"{env_type} not supported")
|
||||
|
||||
|
||||
class ContinuousCriticFactory(CriticFactory, ABC):
|
||||
pass
|
||||
|
||||
|
||||
class ContinuousNetCriticFactory(ContinuousCriticFactory):
|
||||
def __init__(self, hidden_sizes: Sequence[int], action_shape=0):
|
||||
self.action_shape = action_shape
|
||||
def __init__(self, hidden_sizes: Sequence[int]):
|
||||
self.hidden_sizes = hidden_sizes
|
||||
|
||||
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
|
||||
|
@ -29,8 +29,19 @@ class TorchOptimizerFactory(OptimizerFactory):
|
||||
|
||||
|
||||
class AdamOptimizerFactory(OptimizerFactory):
|
||||
def __init__(self, betas=(0.9, 0.999), eps=1e-08, weight_decay=0):
|
||||
self.weight_decay = weight_decay
|
||||
self.eps = eps
|
||||
self.betas = betas
|
||||
|
||||
def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
|
||||
return Adam(module.parameters(), lr=lr)
|
||||
return Adam(
|
||||
module.parameters(),
|
||||
lr=lr,
|
||||
betas=self.betas,
|
||||
eps=self.eps,
|
||||
weight_decay=self.weight_decay,
|
||||
)
|
||||
|
||||
|
||||
class LRSchedulerFactory(ABC):
|
||||
|
Loading…
x
Reference in New Issue
Block a user