Add high-level experiment builder interface

Dominik Jain 2023-09-21 12:36:27 +02:00
parent 4d53d345d6
commit 37dc07e487
7 changed files with 369 additions and 91 deletions
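
In short: instead of constructing actor, critic, optimizer and agent factories by hand and passing them to RLExperiment, the example scripts now declare everything through a fluent builder. Below is a minimal usage sketch of the new interface, based on the updated PPO example in this commit; the task name and the concrete config values are illustrative assumptions, and the sketch assumes the remaining fields of RLExperimentConfig and RLSamplingConfig have defaults.

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.agent import PPOConfig
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import PPOExperimentBuilder, RLExperimentConfig

# Illustrative values; the example scripts expose these as CLI parameters.
experiment_config = RLExperimentConfig(seed=42)
sampling_config = RLSamplingConfig(num_epochs=100)
env_factory = MujocoEnvFactory("Ant-v4", experiment_config.seed, sampling_config)

experiment = (
    PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
    .with_ppo_params(PPOConfig(gamma=0.99, lr=3e-4))  # algorithm hyperparameters
    .with_actor_factory_default((64, 64))             # default MLP actor
    .with_critic_factory_default((64, 64))            # default MLP critic
    .build()
)
experiment.run("ppo-experiment")  # log name; the examples derive it from task/seed/timestamp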

View File

@@ -9,18 +9,13 @@ from jsonargparse import CLI
 from torch.distributions import Independent, Normal
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.agent import PPOAgentFactory, PPOConfig
+from tianshou.highlevel.agent import PPOConfig
 from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.experiment import (
-    RLExperiment,
+    PPOExperimentBuilder,
     RLExperimentConfig,
 )
-from tianshou.highlevel.logger import DefaultLoggerFactory
-from tianshou.highlevel.module import (
-    ContinuousActorProbFactory,
-    ContinuousNetCriticFactory,
-)
-from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory
+from tianshou.highlevel.optim import LinearLRSchedulerFactory


 def main(
@@ -52,7 +47,6 @@ def main(
 ):
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
-    logger_factory = DefaultLoggerFactory()

     sampling_config = RLSamplingConfig(
         num_epochs=epoch,
@@ -70,7 +64,10 @@ def main(
     def dist_fn(*logits):
         return Independent(Normal(*logits), 1)

-    ppo_config = PPOConfig(
+    experiment = (
+        PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
+        .with_ppo_params(
+            PPOConfig(
                 gamma=gamma,
                 gae_lambda=gae_lambda,
                 action_bound_method=bound_action_method,
@@ -83,24 +80,17 @@ def main(
                 eps_clip=eps_clip,
                 dual_clip=dual_clip,
                 recompute_adv=recompute_adv,
+                dist_fn=dist_fn,
+                lr=lr,
+                lr_scheduler_factory=LinearLRSchedulerFactory(sampling_config)
+                if lr_decay
+                else None,
+            ),
         )
-    actor_factory = ContinuousActorProbFactory(hidden_sizes)
-    critic_factory = ContinuousNetCriticFactory(hidden_sizes)
-    optim_factory = AdamOptimizerFactory()
-    lr_scheduler_factory = LinearLRSchedulerFactory(sampling_config) if lr_decay else None
-    agent_factory = PPOAgentFactory(
-        ppo_config,
-        sampling_config,
-        actor_factory,
-        critic_factory,
-        optim_factory,
-        dist_fn,
-        lr,
-        lr_scheduler_factory,
-    )
-    experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory)
+        .with_actor_factory_default(hidden_sizes)
+        .with_critic_factory_default(hidden_sizes)
+        .build()
+    )
     experiment.run(log_name)

View File

@@ -7,18 +7,12 @@ from collections.abc import Sequence
 from jsonargparse import CLI
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACAgentFactory, SACConfig
+from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACConfig
 from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.experiment import (
-    RLExperiment,
     RLExperimentConfig,
+    SACExperimentBuilder,
 )
-from tianshou.highlevel.logger import DefaultLoggerFactory
-from tianshou.highlevel.module import (
-    ContinuousActorProbFactory,
-    ContinuousNetCriticFactory,
-)
-from tianshou.highlevel.optim import AdamOptimizerFactory


 def main(
@@ -45,7 +39,6 @@ def main(
 ):
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "sac", str(experiment_config.seed), now)
-    logger_factory = DefaultLoggerFactory()

     sampling_config = RLSamplingConfig(
         num_epochs=epoch,
@@ -62,31 +55,25 @@ def main(
     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)

-    if auto_alpha:
-        alpha = DefaultAutoAlphaFactory(lr=alpha_lr)
-    sac_config = SACConfig(
+    experiment = (
+        SACExperimentBuilder(experiment_config, env_factory, sampling_config)
+        .with_sac_params(
+            SACConfig(
                 tau=tau,
                 gamma=gamma,
-                alpha=alpha,
+                alpha=DefaultAutoAlphaFactory(lr=alpha_lr) if auto_alpha else alpha,
                 estimation_step=n_step,
                 actor_lr=actor_lr,
                 critic1_lr=critic_lr,
                 critic2_lr=critic_lr,
+            ),
         )
-    actor_factory = ContinuousActorProbFactory(hidden_sizes, conditioned_sigma=True)
-    critic_factory = ContinuousNetCriticFactory(hidden_sizes)
-    optim_factory = AdamOptimizerFactory()
-    agent_factory = SACAgentFactory(
-        sac_config,
-        sampling_config,
-        actor_factory,
-        critic_factory,
-        critic_factory,
-        optim_factory,
-    )
-    experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory)
+        .with_actor_factory_default(
+            hidden_sizes, continuous_unbounded=True, continuous_conditioned_sigma=True,
+        )
+        .with_common_critic_factory_default(hidden_sizes)
+        .build()
+    )
     experiment.run(log_name)

View File

@@ -161,6 +161,9 @@ class PPOConfig(PGConfig):
     eps_clip: float = 0.2
     dual_clip: float | None = None
     recompute_adv: bool = True
+    dist_fn: Callable = None
+    lr: float = 1e-3
+    lr_scheduler_factory: LRSchedulerFactory | None = None


 class PPOAgentFactory(OnpolicyAgentFactory):
@@ -171,26 +174,20 @@ class PPOAgentFactory(OnpolicyAgentFactory):
         actor_factory: ActorFactory,
         critic_factory: CriticFactory,
         optimizer_factory: OptimizerFactory,
-        dist_fn,
-        lr: float,
-        lr_scheduler_factory: LRSchedulerFactory | None = None,
     ):
         super().__init__(sampling_config)
         self.optimizer_factory = optimizer_factory
         self.critic_factory = critic_factory
         self.actor_factory = actor_factory
         self.config = config
-        self.lr = lr
-        self.lr_scheduler_factory = lr_scheduler_factory
-        self.dist_fn = dist_fn

     def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
         actor = self.actor_factory.create_module(envs, device)
         critic = self.critic_factory.create_module(envs, device, use_action=False)
         actor_critic = ActorCritic(actor, critic)
-        optim = self.optimizer_factory.create_optimizer(actor_critic, self.lr)
-        if self.lr_scheduler_factory is not None:
-            lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim)
+        optim = self.optimizer_factory.create_optimizer(actor_critic, self.config.lr)
+        if self.config.lr_scheduler_factory is not None:
+            lr_scheduler = self.config.lr_scheduler_factory.create_scheduler(optim)
         else:
             lr_scheduler = None
         return PPOPolicy(
@@ -198,7 +195,7 @@ class PPOAgentFactory(OnpolicyAgentFactory):
             actor,
             critic,
             optim,
-            dist_fn=self.dist_fn,
+            dist_fn=self.config.dist_fn,
             lr_scheduler=lr_scheduler,
             # env-stuff
             action_space=envs.get_action_space(),

View File

@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
+from enum import Enum
 from typing import Any

 import gymnasium as gym
@@ -9,6 +10,17 @@ from tianshou.env import BaseVectorEnv
 TShape = int | Sequence[int]


+class EnvType(Enum):
+    CONTINUOUS = "continuous"
+    DISCRETE = "discrete"
+
+    def is_discrete(self):
+        return self == EnvType.DISCRETE
+
+    def is_continuous(self):
+        return self == EnvType.CONTINUOUS
+
+
 class Environments(ABC):
     def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
         self.env = env
@@ -29,6 +41,10 @@ class Environments(ABC):
     def get_action_space(self) -> gym.Space:
         return self.env.action_space

+    @abstractmethod
+    def get_type(self) -> EnvType:
+        pass
+

 class ContinuousEnvironments(Environments):
     def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
@@ -62,6 +78,9 @@ class ContinuousEnvironments(Environments):
     def get_state_shape(self) -> TShape:
         return self.state_shape

+    def get_type(self):
+        return EnvType.CONTINUOUS
+

 class EnvFactory(ABC):
     @abstractmethod

View File

@@ -1,3 +1,5 @@
+from abc import abstractmethod
+from collections.abc import Sequence
 from dataclasses import dataclass
 from pprint import pprint
 from typing import Generic, TypeVar
@@ -6,9 +8,17 @@ import numpy as np
 import torch

 from tianshou.data import Collector
-from tianshou.highlevel.agent import AgentFactory
+from tianshou.highlevel.agent import AgentFactory, PPOAgentFactory, PPOConfig, SACConfig
+from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.env import EnvFactory
-from tianshou.highlevel.logger import LoggerFactory
+from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
+from tianshou.highlevel.module import (
+    ActorFactory,
+    CriticFactory,
+    DefaultActorFactory,
+    DefaultCriticFactory,
+)
+from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory
 from tianshou.policy import BasePolicy
 from tianshou.trainer import BaseTrainer
@@ -38,13 +48,15 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
         self,
         config: RLExperimentConfig,
         env_factory: EnvFactory,
-        logger_factory: LoggerFactory,
         agent_factory: AgentFactory,
+        logger_factory: LoggerFactory | None = None,
     ):
+        if logger_factory is None:
+            logger_factory = DefaultLoggerFactory()
         self.config = config
         self.env_factory = env_factory
-        self.logger_factory = logger_factory
         self.agent_factory = agent_factory
+        self.logger_factory = logger_factory

     def _set_seed(self) -> None:
         seed = self.config.seed
@@ -109,3 +121,214 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
         test_collector.reset()
         result = test_collector.collect(n_episode=num_episodes, render=render)
         print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
+
+
+TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder")
+
+
+class RLExperimentBuilder:
+    def __init__(
+        self,
+        experiment_config: RLExperimentConfig,
+        env_factory: EnvFactory,
+        sampling_config: RLSamplingConfig,
+    ):
+        self._config = experiment_config
+        self._env_factory = env_factory
+        self._sampling_config = sampling_config
+        self._logger_factory: LoggerFactory | None = None
+        self._optim_factory: OptimizerFactory | None = None
+
+    def with_logger_factory(self: TBuilder, logger_factory: LoggerFactory) -> TBuilder:
+        self._logger_factory = logger_factory
+        return self
+
+    def with_optim_factory(self: TBuilder, optim_factory: OptimizerFactory) -> TBuilder:
+        self._optim_factory = optim_factory
+        return self
+
+    def with_optim_factory_default(
+        self: TBuilder, betas=(0.9, 0.999), eps=1e-08, weight_decay=0,
+    ) -> TBuilder:
+        """Configures the use of the default optimizer, Adam, with the given parameters.
+
+        :param betas: coefficients used for computing running averages of gradient and its square
+        :param eps: term added to the denominator to improve numerical stability
+        :param weight_decay: weight decay (L2 penalty)
+        :return: the builder
+        """
+        self._optim_factory = AdamOptimizerFactory(betas=betas, eps=eps, weight_decay=weight_decay)
+        return self
+
+    @abstractmethod
+    def _create_agent_factory(self) -> AgentFactory:
+        pass
+
+    def _get_optim_factory(self) -> OptimizerFactory:
+        if self._optim_factory is None:
+            return AdamOptimizerFactory()
+        else:
+            return self._optim_factory
+
+    def build(self) -> RLExperiment:
+        return RLExperiment(
+            self._config, self._env_factory, self._create_agent_factory(), self._logger_factory,
+        )
+
+
+class _BuilderMixinActorFactory:
+    def __init__(self):
+        self._actor_factory: ActorFactory | None = None
+
+    def with_actor_factory(self: TBuilder, actor_factory: ActorFactory) -> TBuilder:
+        self: TBuilder | _BuilderMixinActorFactory
+        self._actor_factory = actor_factory
+        return self
+
+    def with_actor_factory_default(
+        self: TBuilder,
+        hidden_sizes: Sequence[int],
+        continuous_unbounded=False,
+        continuous_conditioned_sigma=False,
+    ) -> TBuilder:
+        self: TBuilder | _BuilderMixinActorFactory
+        self._actor_factory = DefaultActorFactory(
+            hidden_sizes,
+            continuous_unbounded=continuous_unbounded,
+            continuous_conditioned_sigma=continuous_conditioned_sigma,
+        )
+        return self
+
+    def _get_actor_factory(self):
+        if self._actor_factory is None:
+            return DefaultActorFactory()
+        else:
+            return self._actor_factory
+
+
+class _BuilderMixinCriticsFactory:
+    def __init__(self, num_critics: int):
+        self._critic_factories: list[CriticFactory | None] = [None] * num_critics
+
+    def _with_critic_factory(self, idx: int, critic_factory: CriticFactory):
+        self._critic_factories[idx] = critic_factory
+        return self
+
+    def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]):
+        self._critic_factories[idx] = DefaultCriticFactory(hidden_sizes)
+        return self
+
+    def _get_critic_factory(self, idx: int):
+        factory = self._critic_factories[idx]
+        if factory is None:
+            return DefaultCriticFactory()
+        else:
+            return factory
+
+
+class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
+    def __init__(self):
+        super().__init__(1)
+
+    def with_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
+        self: TBuilder | "_BuilderMixinSingleCriticFactory"
+        self._with_critic_factory(0, critic_factory)
+        return self
+
+    def with_critic_factory_default(
+        self: TBuilder, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
+    ) -> TBuilder:
+        self: TBuilder | "_BuilderMixinSingleCriticFactory"
+        self._with_critic_factory_default(0, hidden_sizes)
+        return self
+
+
+class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
+    def __init__(self):
+        super().__init__(2)
+
+    def with_common_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
+        self: TBuilder | "_BuilderMixinDualCriticFactory"
+        for i in range(len(self._critic_factories)):
+            self._with_critic_factory(i, critic_factory)
+        return self
+
+    def with_common_critic_factory_default(
+        self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
+    ) -> TBuilder:
+        self: TBuilder | "_BuilderMixinDualCriticFactory"
+        for i in range(len(self._critic_factories)):
+            self._with_critic_factory_default(i, hidden_sizes)
+        return self
+
+    def with_critic1_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
+        self: TBuilder | "_BuilderMixinDualCriticFactory"
+        self._with_critic_factory(0, critic_factory)
+        return self
+
+    def with_critic1_factory_default(
+        self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
+    ) -> TBuilder:
+        self: TBuilder | "_BuilderMixinDualCriticFactory"
+        self._with_critic_factory_default(0, hidden_sizes)
+        return self
+
+    def with_critic2_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
+        self: TBuilder | "_BuilderMixinDualCriticFactory"
+        self._with_critic_factory(1, critic_factory)
+        return self
+
+    def with_critic2_factory_default(
+        self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
+    ) -> TBuilder:
+        self: TBuilder | "_BuilderMixinDualCriticFactory"
+        self._with_critic_factory_default(0, hidden_sizes)
+        return self
+
+
+class PPOExperimentBuilder(
+    RLExperimentBuilder, _BuilderMixinActorFactory, _BuilderMixinSingleCriticFactory,
+):
+    def __init__(
+        self,
+        experiment_config: RLExperimentConfig,
+        env_factory: EnvFactory,
+        sampling_config: RLSamplingConfig,
+    ):
+        super().__init__(experiment_config, env_factory, sampling_config)
+        _BuilderMixinActorFactory.__init__(self)
+        _BuilderMixinSingleCriticFactory.__init__(self)
+        self._params: PPOConfig = PPOConfig()
+
+    def with_ppo_params(self, params: PPOConfig) -> "PPOExperimentBuilder":
+        self._params = params
+        return self
+
+    @abstractmethod
+    def _create_agent_factory(self) -> AgentFactory:
+        return PPOAgentFactory(
+            self._params,
+            self._sampling_config,
+            self._get_actor_factory(),
+            self._get_critic_factory(0),
+            self._get_optim_factory(),
+        )
+
+
+class SACExperimentBuilder(
+    RLExperimentBuilder, _BuilderMixinActorFactory, _BuilderMixinDualCriticFactory,
+):
+    def __init__(
+        self,
+        experiment_config: RLExperimentConfig,
+        env_factory: EnvFactory,
+        sampling_config: RLSamplingConfig,
+    ):
+        super().__init__(experiment_config, env_factory, sampling_config)
+        _BuilderMixinActorFactory.__init__(self)
+        _BuilderMixinDualCriticFactory.__init__(self)
+        self._params: SACConfig = SACConfig()
+
+    def with_sac_params(self, params: SACConfig) -> "SACExperimentBuilder":
+        self._params = params
+        return self

View File

@@ -5,7 +5,7 @@ import numpy as np
 import torch
 from torch import nn

-from tianshou.highlevel.env import Environments
+from tianshou.highlevel.env import Environments, EnvType
 from tianshou.utils.net.common import Net
 from tianshou.utils.net.continuous import ActorProb
 from tianshou.utils.net.continuous import Critic as ContinuousCritic
@@ -46,8 +46,41 @@ class ActorFactory(ABC):
                 m.weight.data.copy_(0.01 * m.weight.data)


+class DefaultActorFactory(ActorFactory):
+    DEFAULT_HIDDEN_SIZES = (64, 64)
+
+    def __init__(
+        self,
+        hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
+        continuous_unbounded=False,
+        continuous_conditioned_sigma=False,
+    ):
+        self.continuous_unbounded = continuous_unbounded
+        self.continuous_conditioned_sigma = continuous_conditioned_sigma
+        self.hidden_sizes = hidden_sizes
+
+    """
+    An actor factory which, depending on the type of environment, creates a suitable MLP-based policy
+    """
+
+    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
+        env_type = envs.get_type()
+        if env_type == EnvType.CONTINUOUS:
+            factory = ContinuousActorProbFactory(
+                self.hidden_sizes,
+                unbounded=self.continuous_unbounded,
+                conditioned_sigma=self.continuous_conditioned_sigma,
+            )
+            return factory.create_module(envs, device)
+        elif env_type == EnvType.DISCRETE:
+            raise NotImplementedError
+        else:
+            raise ValueError(f"{env_type} not supported")
+
+
 class ContinuousActorFactory(ActorFactory, ABC):
-    pass
+    """Serves as a type bound for actor factories that are suitable for continuous action spaces."""


 class ContinuousActorProbFactory(ContinuousActorFactory):
@@ -85,13 +118,31 @@ class CriticFactory(ABC):
         pass


+class DefaultCriticFactory(CriticFactory):
+    """A critic factory which, depending on the type of environment, creates a suitable MLP-based critic."""
+
+    DEFAULT_HIDDEN_SIZES = (64, 64)
+
+    def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
+        self.hidden_sizes = hidden_sizes
+
+    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
+        env_type = envs.get_type()
+        if env_type == EnvType.CONTINUOUS:
+            factory = ContinuousNetCriticFactory(self.hidden_sizes)
+            return factory.create_module(envs, device, use_action)
+        elif env_type == EnvType.DISCRETE:
+            raise NotImplementedError
+        else:
+            raise ValueError(f"{env_type} not supported")
+
+
 class ContinuousCriticFactory(CriticFactory, ABC):
     pass


 class ContinuousNetCriticFactory(ContinuousCriticFactory):
-    def __init__(self, hidden_sizes: Sequence[int], action_shape=0):
-        self.action_shape = action_shape
+    def __init__(self, hidden_sizes: Sequence[int]):
         self.hidden_sizes = hidden_sizes

     def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:

View File

@@ -29,8 +29,19 @@ class TorchOptimizerFactory(OptimizerFactory):


 class AdamOptimizerFactory(OptimizerFactory):
+    def __init__(self, betas=(0.9, 0.999), eps=1e-08, weight_decay=0):
+        self.weight_decay = weight_decay
+        self.eps = eps
+        self.betas = betas
+
     def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
-        return Adam(module.parameters(), lr=lr)
+        return Adam(
+            module.parameters(),
+            lr=lr,
+            betas=self.betas,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+        )


 class LRSchedulerFactory(ABC):