Add high-level experiment builder interface

Dominik Jain 2023-09-21 12:36:27 +02:00
parent 4d53d345d6
commit 37dc07e487
7 changed files with 369 additions and 91 deletions
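In essence, this commit replaces the manual wiring of agent, module and logger factories in the example scripts with a fluent experiment builder. A minimal sketch of the intended usage, condensed from the PPO example changed below (configuration values are illustrative placeholders, and dist_fn is a user-supplied distribution function as in that example):

experiment = (
    PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
    .with_ppo_params(PPOConfig(gamma=0.99, gae_lambda=0.95, lr=3e-4, dist_fn=dist_fn))
    .with_actor_factory_default((64, 64))
    .with_critic_factory_default((64, 64))
    .with_optim_factory_default(betas=(0.9, 0.999))
    .with_logger_factory(DefaultLoggerFactory())
    .build()
)
experiment.run(log_name)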

View File

@@ -9,18 +9,13 @@ from jsonargparse import CLI
from torch.distributions import Independent, Normal
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.agent import PPOAgentFactory, PPOConfig
from tianshou.highlevel.agent import PPOConfig
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
RLExperiment,
PPOExperimentBuilder,
RLExperimentConfig,
)
from tianshou.highlevel.logger import DefaultLoggerFactory
from tianshou.highlevel.module import (
ContinuousActorProbFactory,
ContinuousNetCriticFactory,
)
from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory
from tianshou.highlevel.optim import LinearLRSchedulerFactory
def main(
@@ -52,7 +47,6 @@ def main(
):
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
logger_factory = DefaultLoggerFactory()
sampling_config = RLSamplingConfig(
num_epochs=epoch,
@@ -70,37 +64,33 @@ def main(
def dist_fn(*logits):
return Independent(Normal(*logits), 1)
ppo_config = PPOConfig(
gamma=gamma,
gae_lambda=gae_lambda,
action_bound_method=bound_action_method,
rew_norm=rew_norm,
ent_coef=ent_coef,
vf_coef=vf_coef,
max_grad_norm=max_grad_norm,
value_clip=value_clip,
norm_adv=norm_adv,
eps_clip=eps_clip,
dual_clip=dual_clip,
recompute_adv=recompute_adv,
experiment = (
PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
.with_ppo_params(
PPOConfig(
gamma=gamma,
gae_lambda=gae_lambda,
action_bound_method=bound_action_method,
rew_norm=rew_norm,
ent_coef=ent_coef,
vf_coef=vf_coef,
max_grad_norm=max_grad_norm,
value_clip=value_clip,
norm_adv=norm_adv,
eps_clip=eps_clip,
dual_clip=dual_clip,
recompute_adv=recompute_adv,
dist_fn=dist_fn,
lr=lr,
lr_scheduler_factory=LinearLRSchedulerFactory(sampling_config)
if lr_decay
else None,
),
)
.with_actor_factory_default(hidden_sizes)
.with_critic_factory_default(hidden_sizes)
.build()
)
actor_factory = ContinuousActorProbFactory(hidden_sizes)
critic_factory = ContinuousNetCriticFactory(hidden_sizes)
optim_factory = AdamOptimizerFactory()
lr_scheduler_factory = LinearLRSchedulerFactory(sampling_config) if lr_decay else None
agent_factory = PPOAgentFactory(
ppo_config,
sampling_config,
actor_factory,
critic_factory,
optim_factory,
dist_fn,
lr,
lr_scheduler_factory,
)
experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory)
experiment.run(log_name)

View File

@@ -7,18 +7,12 @@ from collections.abc import Sequence
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACAgentFactory, SACConfig
from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACConfig
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
RLExperiment,
RLExperimentConfig,
SACExperimentBuilder,
)
from tianshou.highlevel.logger import DefaultLoggerFactory
from tianshou.highlevel.module import (
ContinuousActorProbFactory,
ContinuousNetCriticFactory,
)
from tianshou.highlevel.optim import AdamOptimizerFactory
def main(
@@ -45,7 +39,6 @@ def main(
):
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "sac", str(experiment_config.seed), now)
logger_factory = DefaultLoggerFactory()
sampling_config = RLSamplingConfig(
num_epochs=epoch,
@@ -62,31 +55,25 @@ def main(
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
if auto_alpha:
alpha = DefaultAutoAlphaFactory(lr=alpha_lr)
sac_config = SACConfig(
tau=tau,
gamma=gamma,
alpha=alpha,
estimation_step=n_step,
actor_lr=actor_lr,
critic1_lr=critic_lr,
critic2_lr=critic_lr,
experiment = (
SACExperimentBuilder(experiment_config, env_factory, sampling_config)
.with_sac_params(
SACConfig(
tau=tau,
gamma=gamma,
alpha=DefaultAutoAlphaFactory(lr=alpha_lr) if auto_alpha else alpha,
estimation_step=n_step,
actor_lr=actor_lr,
critic1_lr=critic_lr,
critic2_lr=critic_lr,
),
)
.with_actor_factory_default(
hidden_sizes, continuous_unbounded=True, continuous_conditioned_sigma=True,
)
.with_common_critic_factory_default(hidden_sizes)
.build()
)
actor_factory = ContinuousActorProbFactory(hidden_sizes, conditioned_sigma=True)
critic_factory = ContinuousNetCriticFactory(hidden_sizes)
optim_factory = AdamOptimizerFactory()
agent_factory = SACAgentFactory(
sac_config,
sampling_config,
actor_factory,
critic_factory,
critic_factory,
optim_factory,
)
experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory)
experiment.run(log_name)

View File

@@ -161,6 +161,9 @@ class PPOConfig(PGConfig):
eps_clip: float = 0.2
dual_clip: float | None = None
recompute_adv: bool = True
dist_fn: Callable | None = None
lr: float = 1e-3
lr_scheduler_factory: LRSchedulerFactory | None = None
class PPOAgentFactory(OnpolicyAgentFactory):
@@ -171,26 +174,20 @@ class PPOAgentFactory(OnpolicyAgentFactory):
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
dist_fn,
lr: float,
lr_scheduler_factory: LRSchedulerFactory | None = None,
):
super().__init__(sampling_config)
self.optimizer_factory = optimizer_factory
self.critic_factory = critic_factory
self.actor_factory = actor_factory
self.config = config
self.lr = lr
self.lr_scheduler_factory = lr_scheduler_factory
self.dist_fn = dist_fn
def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
actor = self.actor_factory.create_module(envs, device)
critic = self.critic_factory.create_module(envs, device, use_action=False)
actor_critic = ActorCritic(actor, critic)
optim = self.optimizer_factory.create_optimizer(actor_critic, self.lr)
if self.lr_scheduler_factory is not None:
lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim)
optim = self.optimizer_factory.create_optimizer(actor_critic, self.config.lr)
if self.config.lr_scheduler_factory is not None:
lr_scheduler = self.config.lr_scheduler_factory.create_scheduler(optim)
else:
lr_scheduler = None
return PPOPolicy(
@@ -198,7 +195,7 @@ class PPOAgentFactory(OnpolicyAgentFactory):
actor,
critic,
optim,
dist_fn=self.dist_fn,
dist_fn=self.config.dist_fn,
lr_scheduler=lr_scheduler,
# env-stuff
action_space=envs.get_action_space(),

View File

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from enum import Enum
from typing import Any
import gymnasium as gym
@@ -9,6 +10,17 @@ from tianshou.env import BaseVectorEnv
TShape = int | Sequence[int]
class EnvType(Enum):
CONTINUOUS = "continuous"
DISCRETE = "discrete"
def is_discrete(self):
return self == EnvType.DISCRETE
def is_continuous(self):
return self == EnvType.CONTINUOUS
class Environments(ABC):
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
self.env = env
@@ -29,6 +41,10 @@
def get_action_space(self) -> gym.Space:
return self.env.action_space
@abstractmethod
def get_type(self) -> EnvType:
pass
class ContinuousEnvironments(Environments):
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
@@ -62,6 +78,9 @@ class ContinuousEnvironments(Environments):
def get_state_shape(self) -> TShape:
return self.state_shape
def get_type(self):
return EnvType.CONTINUOUS
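The new EnvType enum gives the default factories a uniform way to branch on the kind of environment. A small hedged sketch, assuming envs is an Environments instance obtained from an EnvFactory:

env_type = envs.get_type()
if env_type.is_continuous():
    ...  # e.g. build an MLP actor producing a Gaussian distribution
elif env_type.is_discrete():
    ...  # discrete counterparts (not yet provided by the default factories)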
class EnvFactory(ABC):
@abstractmethod

View File

@@ -1,3 +1,5 @@
from abc import abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from pprint import pprint
from typing import Generic, TypeVar
@@ -6,9 +8,17 @@ import numpy as np
import torch
from tianshou.data import Collector
from tianshou.highlevel.agent import AgentFactory
from tianshou.highlevel.agent import AgentFactory, PPOAgentFactory, PPOConfig, SACConfig
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.logger import LoggerFactory
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
from tianshou.highlevel.module import (
ActorFactory,
CriticFactory,
DefaultActorFactory,
DefaultCriticFactory,
)
from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory
from tianshou.policy import BasePolicy
from tianshou.trainer import BaseTrainer
@@ -38,13 +48,15 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
self,
config: RLExperimentConfig,
env_factory: EnvFactory,
logger_factory: LoggerFactory,
agent_factory: AgentFactory,
logger_factory: LoggerFactory | None = None,
):
if logger_factory is None:
logger_factory = DefaultLoggerFactory()
self.config = config
self.env_factory = env_factory
self.logger_factory = logger_factory
self.agent_factory = agent_factory
self.logger_factory = logger_factory
def _set_seed(self) -> None:
seed = self.config.seed
@@ -109,3 +121,214 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
test_collector.reset()
result = test_collector.collect(n_episode=num_episodes, render=render)
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder")
class RLExperimentBuilder:
def __init__(
self,
experiment_config: RLExperimentConfig,
env_factory: EnvFactory,
sampling_config: RLSamplingConfig,
):
self._config = experiment_config
self._env_factory = env_factory
self._sampling_config = sampling_config
self._logger_factory: LoggerFactory | None = None
self._optim_factory: OptimizerFactory | None = None
def with_logger_factory(self: TBuilder, logger_factory: LoggerFactory) -> TBuilder:
self._logger_factory = logger_factory
return self
def with_optim_factory(self: TBuilder, optim_factory: OptimizerFactory) -> TBuilder:
self._optim_factory = optim_factory
return self
def with_optim_factory_default(
self: TBuilder, betas=(0.9, 0.999), eps=1e-08, weight_decay=0,
) -> TBuilder:
"""Configures the use of the default optimizer, Adam, with the given parameters.
:param betas: coefficients used for computing running averages of gradient and its square
:param eps: term added to the denominator to improve numerical stability
:param weight_decay: weight decay (L2 penalty)
:return: the builder
"""
self._optim_factory = AdamOptimizerFactory(betas=betas, eps=eps, weight_decay=weight_decay)
return self
@abstractmethod
def _create_agent_factory(self) -> AgentFactory:
pass
def _get_optim_factory(self) -> OptimizerFactory:
if self._optim_factory is None:
return AdamOptimizerFactory()
else:
return self._optim_factory
def build(self) -> RLExperiment:
return RLExperiment(
self._config, self._env_factory, self._create_agent_factory(), self._logger_factory,
)
class _BuilderMixinActorFactory:
def __init__(self):
self._actor_factory: ActorFactory | None = None
def with_actor_factory(self: TBuilder, actor_factory: ActorFactory) -> TBuilder:
self: TBuilder | _BuilderMixinActorFactory
self._actor_factory = actor_factory
return self
def with_actor_factory_default(
self: TBuilder,
hidden_sizes: Sequence[int],
continuous_unbounded=False,
continuous_conditioned_sigma=False,
) -> TBuilder:
self: TBuilder | _BuilderMixinActorFactory
self._actor_factory = DefaultActorFactory(
hidden_sizes,
continuous_unbounded=continuous_unbounded,
continuous_conditioned_sigma=continuous_conditioned_sigma,
)
return self
def _get_actor_factory(self):
if self._actor_factory is None:
return DefaultActorFactory()
else:
return self._actor_factory
class _BuilderMixinCriticsFactory:
def __init__(self, num_critics: int):
self._critic_factories: list[CriticFactory | None] = [None] * num_critics
def _with_critic_factory(self, idx: int, critic_factory: CriticFactory):
self._critic_factories[idx] = critic_factory
return self
def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]):
self._critic_factories[idx] = DefaultCriticFactory(hidden_sizes)
return self
def _get_critic_factory(self, idx: int):
factory = self._critic_factories[idx]
if factory is None:
return DefaultCriticFactory()
else:
return factory
class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
def __init__(self):
super().__init__(1)
def with_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
self: TBuilder | "_BuilderMixinSingleCriticFactory"
self._with_critic_factory(0, critic_factory)
return self
def with_critic_factory_default(
self: TBuilder, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
) -> TBuilder:
self: TBuilder | "_BuilderMixinSingleCriticFactory"
self._with_critic_factory_default(0, hidden_sizes)
return self
class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
def __init__(self):
super().__init__(2)
def with_common_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory"
for i in range(len(self._critic_factories)):
self._with_critic_factory(i, critic_factory)
return self
def with_common_critic_factory_default(
self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory"
for i in range(len(self._critic_factories)):
self._with_critic_factory_default(i, hidden_sizes)
return self
def with_critic1_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory"
self._with_critic_factory(0, critic_factory)
return self
def with_critic1_factory_default(
self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory"
self._with_critic_factory_default(0, hidden_sizes)
return self
def with_critic2_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory"
self._with_critic_factory(1, critic_factory)
return self
def with_critic2_factory_default(
self, hidden_sizes: Sequence[int] = DefaultCriticFactory.DEFAULT_HIDDEN_SIZES,
) -> TBuilder:
self: TBuilder | "_BuilderMixinDualCriticFactory"
self._with_critic_factory_default(1, hidden_sizes)
return self
class PPOExperimentBuilder(
RLExperimentBuilder, _BuilderMixinActorFactory, _BuilderMixinSingleCriticFactory,
):
def __init__(
self,
experiment_config: RLExperimentConfig,
env_factory: EnvFactory,
sampling_config: RLSamplingConfig,
):
super().__init__(experiment_config, env_factory, sampling_config)
_BuilderMixinActorFactory.__init__(self)
_BuilderMixinSingleCriticFactory.__init__(self)
self._params: PPOConfig = PPOConfig()
def with_ppo_params(self, params: PPOConfig) -> "PPOExperimentBuilder":
self._params = params
return self
def _create_agent_factory(self) -> AgentFactory:
return PPOAgentFactory(
self._params,
self._sampling_config,
self._get_actor_factory(),
self._get_critic_factory(0),
self._get_optim_factory(),
)
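Note that PPOAgentFactory now reads dist_fn, lr and the scheduler factory from PPOConfig instead of taking them as separate constructor arguments. A hedged construction sketch mirroring _create_agent_factory above (hidden sizes and hyperparameters are placeholders; dist_fn is a user-supplied distribution function as in the PPO example):

agent_factory = PPOAgentFactory(
    PPOConfig(gamma=0.99, dist_fn=dist_fn, lr=3e-4,
              lr_scheduler_factory=LinearLRSchedulerFactory(sampling_config)),
    sampling_config,
    ContinuousActorProbFactory((64, 64)),
    ContinuousNetCriticFactory((64, 64)),
    AdamOptimizerFactory(),
)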
class SACExperimentBuilder(
RLExperimentBuilder, _BuilderMixinActorFactory, _BuilderMixinDualCriticFactory,
):
def __init__(
self,
experiment_config: RLExperimentConfig,
env_factory: EnvFactory,
sampling_config: RLSamplingConfig,
):
super().__init__(experiment_config, env_factory, sampling_config)
_BuilderMixinActorFactory.__init__(self)
_BuilderMixinDualCriticFactory.__init__(self)
self._params: SACConfig = SACConfig()
def with_sac_params(self, params: SACConfig) -> "SACExperimentBuilder":
self._params = params
return self
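For SAC, the dual-critic mixin allows a shared default as well as per-critic overrides. A hedged sketch (CustomCriticFactory is a hypothetical user-defined CriticFactory; hyperparameters are placeholders):

experiment = (
    SACExperimentBuilder(experiment_config, env_factory, sampling_config)
    .with_sac_params(SACConfig(tau=0.005, gamma=0.99))
    .with_actor_factory_default(
        (256, 256), continuous_unbounded=True, continuous_conditioned_sigma=True,
    )
    .with_critic1_factory_default((256, 256))
    .with_critic2_factory(CustomCriticFactory())
    .build()
)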

View File

@@ -5,7 +5,7 @@ import numpy as np
import torch
from torch import nn
from tianshou.highlevel.env import Environments
from tianshou.highlevel.env import Environments, EnvType
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb
from tianshou.utils.net.continuous import Critic as ContinuousCritic
@@ -46,8 +46,41 @@ class ActorFactory(ABC):
m.weight.data.copy_(0.01 * m.weight.data)
class DefaultActorFactory(ActorFactory):
"""An actor factory which, depending on the type of environment, creates a suitable MLP-based policy."""
DEFAULT_HIDDEN_SIZES = (64, 64)
def __init__(
self,
hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
continuous_unbounded=False,
continuous_conditioned_sigma=False,
):
self.continuous_unbounded = continuous_unbounded
self.continuous_conditioned_sigma = continuous_conditioned_sigma
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
env_type = envs.get_type()
if env_type == EnvType.CONTINUOUS:
factory = ContinuousActorProbFactory(
self.hidden_sizes,
unbounded=self.continuous_unbounded,
conditioned_sigma=self.continuous_conditioned_sigma,
)
return factory.create_module(envs, device)
elif env_type == EnvType.DISCRETE:
raise NotImplementedError
else:
raise ValueError(f"{env_type} not supported")
class ContinuousActorFactory(ActorFactory, ABC):
pass
"""Serves as a type bound for actor factories that are suitable for continuous action spaces."""
class ContinuousActorProbFactory(ContinuousActorFactory):
@@ -85,13 +118,31 @@ class CriticFactory(ABC):
pass
class DefaultCriticFactory(CriticFactory):
"""A critic factory which, depending on the type of environment, creates a suitable MLP-based critic."""
DEFAULT_HIDDEN_SIZES = (64, 64)
def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
env_type = envs.get_type()
if env_type == EnvType.CONTINUOUS:
factory = ContinuousNetCriticFactory(self.hidden_sizes)
return factory.create_module(envs, device, use_action)
elif env_type == EnvType.DISCRETE:
raise NotImplementedError
else:
raise ValueError(f"{env_type} not supported")
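Both default factories defer to a concrete continuous-space implementation once they have inspected the environment type. A hedged direct-use sketch, assuming envs is a ContinuousEnvironments instance and device is a torch device string such as "cpu":

actor = DefaultActorFactory((64, 64), continuous_unbounded=True).create_module(envs, device)
# dispatches to ContinuousActorProbFactory for continuous environments
critic = DefaultCriticFactory((64, 64)).create_module(envs, device, use_action=False)
# dispatches to ContinuousNetCriticFactory for continuous environments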
class ContinuousCriticFactory(CriticFactory, ABC):
pass
class ContinuousNetCriticFactory(ContinuousCriticFactory):
def __init__(self, hidden_sizes: Sequence[int], action_shape=0):
self.action_shape = action_shape
def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:

View File

@@ -29,8 +29,19 @@ class TorchOptimizerFactory(OptimizerFactory):
class AdamOptimizerFactory(OptimizerFactory):
def __init__(self, betas=(0.9, 0.999), eps=1e-08, weight_decay=0):
self.weight_decay = weight_decay
self.eps = eps
self.betas = betas
def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
return Adam(module.parameters(), lr=lr)
return Adam(
module.parameters(),
lr=lr,
betas=self.betas,
eps=self.eps,
weight_decay=self.weight_decay,
)
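AdamOptimizerFactory now forwards its betas, eps and weight_decay settings to torch.optim.Adam rather than only the learning rate. A hedged usage sketch (the actor_critic module and hyperparameter values are placeholders):

optim_factory = AdamOptimizerFactory(betas=(0.9, 0.98), eps=1e-5, weight_decay=1e-2)
optim = optim_factory.create_optimizer(actor_critic, lr=3e-4)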
class LRSchedulerFactory(ABC):