Improve high-level policy parametrisation

Policy objects are now parametrised by converting the parameter
dataclass instances to kwargs, applying injectable conversions
(param transformers) along the way.
Dominik Jain 2023-09-25 17:56:37 +02:00
parent 37dc07e487
commit 367778d37f
11 changed files with 254 additions and 165 deletions
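
In essence, the pattern introduced by this commit looks as follows. This is a minimal, self-contained sketch; ExampleParams and DropLR are hypothetical stand-ins for the actual PPOParams/SACParams and the transformers defined in agent.py below.

from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from typing import Any


class ParamTransformer(ABC):
    """Hook that rewrites the kwargs derived from a parameter dataclass."""

    @abstractmethod
    def transform(self, kwargs: dict[str, Any]) -> None:
        ...


@dataclass
class ExampleParams:
    """Hypothetical stand-in for PPOParams/SACParams."""

    lr: float = 1e-3

    def create_kwargs(self, *transformers: ParamTransformer) -> dict[str, Any]:
        kwargs = asdict(self)
        for t in transformers:
            t.transform(kwargs)  # each transformer may inject, rewrite or drop entries
        return kwargs


class DropLR(ParamTransformer):
    """Hypothetical transformer: lr was already consumed by the optimizer."""

    def transform(self, kwargs: dict[str, Any]) -> None:
        del kwargs["lr"]


print(ExampleParams(lr=3e-4).create_kwargs(DropLR()))  # -> {}

The remaining keys are then passed straight to the policy constructor via **kwargs, which is what the agent factories in the diff below do.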

View File

@@ -9,13 +9,13 @@ from jsonargparse import CLI
from torch.distributions import Independent, Normal
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.agent import PPOConfig
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
PPOExperimentBuilder,
RLExperimentConfig,
)
from tianshou.highlevel.optim import LinearLRSchedulerFactory
from tianshou.highlevel.params.lr_scheduler import LinearLRSchedulerFactory
from tianshou.highlevel.params.policy_params import PPOParams
def main(
@@ -65,22 +65,21 @@ def main(
return Independent(Normal(*logits), 1)
experiment = (
PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
PPOExperimentBuilder(experiment_config, env_factory, sampling_config, dist_fn)
.with_ppo_params(
PPOConfig(
gamma=gamma,
PPOParams(
discount_factor=gamma,
gae_lambda=gae_lambda,
action_bound_method=bound_action_method,
rew_norm=rew_norm,
reward_normalization=rew_norm,
ent_coef=ent_coef,
vf_coef=vf_coef,
max_grad_norm=max_grad_norm,
value_clip=value_clip,
norm_adv=norm_adv,
advantage_normalization=norm_adv,
eps_clip=eps_clip,
dual_clip=dual_clip,
recompute_adv=recompute_adv,
dist_fn=dist_fn,
recompute_advantage=recompute_adv,
lr=lr,
lr_scheduler_factory=LinearLRSchedulerFactory(sampling_config)
if lr_decay

View File

@@ -7,7 +7,8 @@ from collections.abc import Sequence
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACConfig
from tianshou.highlevel.params.policy_params import SACParams
from tianshou.highlevel.params.alpha import DefaultAutoAlphaFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
RLExperimentConfig,
@@ -58,7 +59,7 @@ def main(
experiment = (
SACExperimentBuilder(experiment_config, env_factory, sampling_config)
.with_sac_params(
SACConfig(
SACParams(
tau=tau,
gamma=gamma,
alpha=DefaultAutoAlphaFactory(lr=alpha_lr) if auto_alpha else alpha,

View File

@@ -1,10 +1,8 @@
import os
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import Literal
from typing import Dict, Any, List, Tuple
import numpy as np
import torch
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
@@ -13,9 +11,14 @@ from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.policy_params import PPOParams, ParamTransformer, SACParams
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy
from tianshou.policy.modelfree.pg import TDistParams
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils import MultipleLRSchedulers
from tianshou.utils.net.common import ActorCritic
CHECKPOINT_DICT_KEY_MODEL = "model"
@@ -128,178 +131,132 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
)
@dataclass
class RLAgentConfig:
"""Config common to most RL algorithms."""
class ParamTransformerDrop(ParamTransformer):
def __init__(self, *keys: str):
self.keys = keys
gamma: float = 0.99
"""Discount factor"""
gae_lambda: float = 0.95
"""For Generalized Advantage Estimate (equivalent to TD(lambda))"""
action_bound_method: Literal["clip", "tanh"] | None = "clip"
"""How to map original actions in range (-inf, inf) to [-1, 1]"""
rew_norm: bool = True
"""Whether to normalize rewards"""
def transform(self, kwargs: Dict[str, Any]) -> None:
for k in self.keys:
del kwargs[k]
@dataclass
class PGConfig(RLAgentConfig):
"""Config of general policy-gradient algorithms."""
class ParamTransformerLRScheduler(ParamTransformer):
def __init__(self, optim: torch.optim.Optimizer):
self.optim = optim
ent_coef: float = 0.0
vf_coef: float = 0.25
max_grad_norm: float = 0.5
@dataclass
class PPOConfig(PGConfig):
"""PPO specific config."""
value_clip: bool = False
norm_adv: bool = False
"""Whether to normalize advantages"""
eps_clip: float = 0.2
dual_clip: float | None = None
recompute_adv: bool = True
dist_fn: Callable = None
lr: float = 1e-3
lr_scheduler_factory: LRSchedulerFactory | None = None
def transform(self, kwargs: Dict[str, Any]) -> None:
factory: LRSchedulerFactory | None = self.get(kwargs, "lr_scheduler_factory", drop=True)
kwargs["lr_scheduler"] = factory.create_scheduler(self.optim) if factory is not None else None
class PPOAgentFactory(OnpolicyAgentFactory):
def __init__(
self,
config: PPOConfig,
params: PPOParams,
sampling_config: RLSamplingConfig,
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
):
super().__init__(sampling_config)
self.optimizer_factory = optimizer_factory
self.critic_factory = critic_factory
self.actor_factory = actor_factory
self.config = config
self.config = params
self.dist_fn = dist_fn
def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
actor = self.actor_factory.create_module(envs, device)
critic = self.critic_factory.create_module(envs, device, use_action=False)
actor_critic = ActorCritic(actor, critic)
optim = self.optimizer_factory.create_optimizer(actor_critic, self.config.lr)
if self.config.lr_scheduler_factory is not None:
lr_scheduler = self.config.lr_scheduler_factory.create_scheduler(optim)
else:
lr_scheduler = None
kwargs = self.config.create_kwargs(
ParamTransformerDrop("lr"),
ParamTransformerLRScheduler(optim))
return PPOPolicy(
# nn-stuff
actor,
critic,
optim,
dist_fn=self.config.dist_fn,
lr_scheduler=lr_scheduler,
# env-stuff
actor=actor,
critic=critic,
optim=optim,
dist_fn=self.dist_fn,
action_space=envs.get_action_space(),
action_scaling=True,
# general_config
discount_factor=self.config.gamma,
gae_lambda=self.config.gae_lambda,
reward_normalization=self.config.rew_norm,
action_bound_method=self.config.action_bound_method,
# pg_config
max_grad_norm=self.config.max_grad_norm,
vf_coef=self.config.vf_coef,
ent_coef=self.config.ent_coef,
# ppo_config
eps_clip=self.config.eps_clip,
value_clip=self.config.value_clip,
dual_clip=self.config.dual_clip,
advantage_normalization=self.config.norm_adv,
recompute_advantage=self.config.recompute_adv,
**kwargs
)
class AutoAlphaFactory(ABC):
@abstractmethod
def create_auto_alpha(
self,
envs: Environments,
optim_factory: OptimizerFactory,
device: TDevice,
):
pass
class ParamTransformerAlpha(ParamTransformer):
def __init__(self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice):
self.envs = envs
self.optim_factory = optim_factory
self.device = device
def transform(self, kwargs: Dict[str, Any]) -> None:
key = "alpha"
alpha = self.get(kwargs, key)
if isinstance(alpha, AutoAlphaFactory):
kwargs[key] = alpha.create_auto_alpha(self.envs, self.optim_factory, self.device)
class DefaultAutoAlphaFactory(AutoAlphaFactory): # TODO better name?
def __init__(self, lr: float = 3e-4):
self.lr = lr
class ParamTransformerMultiLRScheduler(ParamTransformer):
def __init__(self, optim_key_list: List[Tuple[torch.optim.Optimizer, str]]):
self.optim_key_list = optim_key_list
def create_auto_alpha(
self,
envs: Environments,
optim_factory: OptimizerFactory,
device: TDevice,
) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
target_entropy = -np.prod(envs.get_action_shape())
log_alpha = torch.zeros(1, requires_grad=True, device=device)
alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
return target_entropy, log_alpha, alpha_optim
@dataclass
class SACConfig:
tau: float = 0.005
gamma: float = 0.99
alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
reward_normalization: bool = False
estimation_step: int = 1
deterministic_eval: bool = True
actor_lr: float = 1e-3
critic1_lr: float = 1e-3
critic2_lr: float = 1e-3
def transform(self, kwargs: Dict[str, Any]) -> None:
lr_schedulers = []
for optim, lr_scheduler_factory_key in self.optim_key_list:
lr_scheduler_factory: LRSchedulerFactory | None = self.get(kwargs, lr_scheduler_factory_key, drop=True)
if lr_scheduler_factory is not None:
lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
match len(lr_schedulers):
case 0:
lr_scheduler = None
case 1:
lr_scheduler = lr_schedulers[0]
case _:
lr_scheduler = MultipleLRSchedulers(*lr_schedulers)
kwargs["lr_scheduler"] = lr_scheduler
class SACAgentFactory(OffpolicyAgentFactory):
def __init__(
self,
config: SACConfig,
params: SACParams,
sampling_config: RLSamplingConfig,
actor_factory: ActorFactory,
critic1_factory: CriticFactory,
critic2_factory: CriticFactory,
optim_factory: OptimizerFactory,
exploration_noise: BaseNoise | None = None,
):
super().__init__(sampling_config)
self.critic2_factory = critic2_factory
self.critic1_factory = critic1_factory
self.actor_factory = actor_factory
self.exploration_noise = exploration_noise
self.optim_factory = optim_factory
self.config = config
self.params = params
def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
actor = self.actor_factory.create_module(envs, device)
critic1 = self.critic1_factory.create_module(envs, device, use_action=True)
critic2 = self.critic2_factory.create_module(envs, device, use_action=True)
actor_optim = self.optim_factory.create_optimizer(actor, lr=self.config.actor_lr)
critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.config.critic1_lr)
critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.config.critic2_lr)
if isinstance(self.config.alpha, AutoAlphaFactory):
alpha = self.config.alpha.create_auto_alpha(envs, self.optim_factory, device)
else:
alpha = self.config.alpha
actor_optim = self.optim_factory.create_optimizer(actor, lr=self.params.actor_lr)
critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.params.critic1_lr)
critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.params.critic2_lr)
kwargs = self.params.create_kwargs(
ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
ParamTransformerMultiLRScheduler([
(actor_optim, "actor_lr_scheduler_factory"),
(critic1_optim, "critic1_lr_scheduler_factory"),
(critic2_optim, "critic2_lr_scheduler_factory")]
),
ParamTransformerAlpha(envs, optim_factory=self.optim_factory, device=device))
return SACPolicy(
actor,
actor_optim,
critic1,
critic1_optim,
critic2,
critic2_optim,
tau=self.config.tau,
gamma=self.config.gamma,
alpha=alpha,
estimation_step=self.config.estimation_step,
actor=actor,
actor_optim=actor_optim,
critic=critic1,
critic_optim=critic1_optim,
critic2=critic2,
critic2_optim=critic2_optim,
action_space=envs.get_action_space(),
deterministic_eval=self.config.deterministic_eval,
exploration_noise=self.exploration_noise,
observation_space=envs.get_observation_space(),
**kwargs
)

View File

@@ -28,19 +28,22 @@ class Environments(ABC):
self.test_envs = test_envs
def info(self) -> dict[str, Any]:
return {"action_shape": self.get_action_shape(), "state_shape": self.get_state_shape()}
return {"action_shape": self.get_action_shape(), "state_shape": self.get_observation_shape()}
@abstractmethod
def get_action_shape(self) -> TShape:
pass
@abstractmethod
def get_state_shape(self) -> TShape:
def get_observation_shape(self) -> TShape:
pass
def get_action_space(self) -> gym.Space:
return self.env.action_space
def get_observation_space(self) -> gym.Space:
return self.env.observation_space
@abstractmethod
def get_type(self) -> EnvType:
pass
@@ -75,7 +78,7 @@ class ContinuousEnvironments(Environments):
def get_action_shape(self) -> TShape:
return self.action_shape
def get_state_shape(self) -> TShape:
def get_observation_shape(self) -> TShape:
return self.state_shape
def get_type(self):

View File

@@ -2,13 +2,13 @@ from abc import abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from pprint import pprint
from typing import Generic, TypeVar
from typing import Generic, TypeVar, Callable
import numpy as np
import torch
from tianshou.data import Collector
from tianshou.highlevel.agent import AgentFactory, PPOAgentFactory, PPOConfig, SACConfig
from tianshou.highlevel.agent import AgentFactory, PPOAgentFactory, SACAgentFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
@@ -19,7 +19,9 @@ from tianshou.highlevel.module import (
DefaultCriticFactory,
)
from tianshou.highlevel.optim import AdamOptimizerFactory, OptimizerFactory
from tianshou.highlevel.params.policy_params import PPOParams, SACParams
from tianshou.policy import BasePolicy
from tianshou.policy.modelfree.pg import TDistParams
from tianshou.trainer import BaseTrainer
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
@@ -294,13 +296,15 @@ class PPOExperimentBuilder(
experiment_config: RLExperimentConfig,
env_factory: EnvFactory,
sampling_config: RLSamplingConfig,
dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
):
super().__init__(experiment_config, env_factory, sampling_config)
_BuilderMixinActorFactory.__init__(self)
_BuilderMixinSingleCriticFactory.__init__(self)
self._params: PPOConfig = PPOConfig()
self._params: PPOParams = PPOParams()
self._dist_fn = dist_fn
def with_ppo_params(self, params: PPOConfig) -> "PPOExperimentBuilder":
def with_ppo_params(self, params: PPOParams) -> "PPOExperimentBuilder":
self._params = params
return self
@@ -312,6 +316,7 @@ class PPOExperimentBuilder(
self._get_actor_factory(),
self._get_critic_factory(0),
self._get_optim_factory(),
self._dist_fn
)
@@ -327,8 +332,12 @@ class SACExperimentBuilder(
super().__init__(experiment_config, env_factory, sampling_config)
_BuilderMixinActorFactory.__init__(self)
_BuilderMixinDualCriticFactory.__init__(self)
self._params: SACConfig = SACConfig()
self._params: SACParams = SACParams()
def with_sac_params(self, params: SACConfig) -> "SACExperimentBuilder":
def with_sac_params(self, params: SACParams) -> "SACExperimentBuilder":
self._params = params
return self
def _create_agent_factory(self) -> AgentFactory:
return SACAgentFactory(self._params, self._sampling_config, self._get_actor_factory(),
self._get_critic_factory(0), self._get_critic_factory(1), self._get_optim_factory())

View File

@@ -91,7 +91,7 @@ class ContinuousActorProbFactory(ContinuousActorFactory):
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
net_a = Net(
envs.get_state_shape(),
envs.get_observation_shape(),
hidden_sizes=self.hidden_sizes,
activation=nn.Tanh,
device=device,
@@ -148,7 +148,7 @@ class ContinuousNetCriticFactory(ContinuousCriticFactory):
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
action_shape = envs.get_action_shape() if use_action else 0
net_c = Net(
envs.get_state_shape(),
envs.get_observation_shape(),
action_shape=action_shape,
hidden_sizes=self.hidden_sizes,
concat=use_action,

View File

@@ -2,13 +2,9 @@ from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Any
import numpy as np
import torch
from torch import Tensor
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from tianshou.highlevel.config import RLSamplingConfig
TParams = Iterable[Tensor] | Iterable[dict[str, Any]]
@@ -44,19 +40,3 @@ class AdamOptimizerFactory(OptimizerFactory):
)
class LRSchedulerFactory(ABC):
@abstractmethod
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
pass
class LinearLRSchedulerFactory(LRSchedulerFactory):
def __init__(self, sampling_config: RLSamplingConfig):
self.sampling_config = sampling_config
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
max_update_num = (
np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect)
* self.sampling_config.num_epochs
)
return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

View File

View File

@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod
import numpy as np
import torch
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module import TDevice
from tianshou.highlevel.optim import OptimizerFactory
class AutoAlphaFactory(ABC):
@abstractmethod
def create_auto_alpha(
self,
envs: Environments,
optim_factory: OptimizerFactory,
device: TDevice,
) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
pass
class DefaultAutoAlphaFactory(AutoAlphaFactory): # TODO better name?
def __init__(self, lr: float = 3e-4):
self.lr = lr
def create_auto_alpha(
self,
envs: Environments,
optim_factory: OptimizerFactory,
device: TDevice,
) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
target_entropy = -np.prod(envs.get_action_shape())
log_alpha = torch.zeros(1, requires_grad=True, device=device)
alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
return target_entropy, log_alpha, alpha_optim
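
A usage sketch (mirroring the mujoco_sac example above): the factory is stored in the params object and only resolved to the (target_entropy, log_alpha, alpha_optim) tuple by ParamTransformerAlpha once the environments and device are known.

from tianshou.highlevel.params.alpha import DefaultAutoAlphaFactory
from tianshou.highlevel.params.policy_params import SACParams

# alpha may be a fixed float or a factory that defers creation of the
# auto-tuned entropy coefficient until envs/optim_factory/device are available
params = SACParams(tau=0.005, gamma=0.99, alpha=DefaultAutoAlphaFactory(lr=3e-4))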

View File

@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
import numpy as np
import torch
from torch.optim.lr_scheduler import LRScheduler, LambdaLR
from tianshou.highlevel.config import RLSamplingConfig
class LRSchedulerFactory(ABC):
@abstractmethod
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
pass
class LinearLRSchedulerFactory(LRSchedulerFactory):
def __init__(self, sampling_config: RLSamplingConfig):
self.sampling_config = sampling_config
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
max_update_num = (
np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect)
* self.sampling_config.num_epochs
)
return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
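
Usage on the parameter side looks like this (as in the mujoco_ppo example above); ParamTransformerLRScheduler in agent.py later turns the factory into a LambdaLR bound to the actual optimizer. Constructing RLSamplingConfig with defaults is an assumption here; the examples fill in the epoch/step settings explicitly.

from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.params.lr_scheduler import LinearLRSchedulerFactory
from tianshou.highlevel.params.policy_params import PPOParams

sampling_config = RLSamplingConfig()  # assumed defaults; examples set epochs/steps explicitly
params = PPOParams(
    lr=3e-4,
    lr_scheduler_factory=LinearLRSchedulerFactory(sampling_config),
)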

View File

@@ -0,0 +1,80 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict
from typing import Dict, Any, Literal
import torch
from tianshou.exploration import BaseNoise
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
class ParamTransformer(ABC):
@abstractmethod
def transform(self, kwargs: Dict[str, Any]) -> None:
pass
@staticmethod
def get(d: Dict[str, Any], key: str, drop: bool = False) -> Any:
value = d[key]
if drop:
del d[key]
return value
@dataclass
class Params:
def create_kwargs(self, *transformers: ParamTransformer) -> Dict[str, Any]:
d = asdict(self)
for transformer in transformers:
transformer.transform(d)
return d
@dataclass
class PGParams(Params):
"""Config of general policy-gradient algorithms."""
discount_factor: float = 0.99
reward_normalization: bool = False
deterministic_eval: bool = False
action_scaling: bool = True
action_bound_method: Literal["clip", "tanh"] | None = "clip"
@dataclass
class A2CParams(PGParams):
vf_coef: float = 0.5
ent_coef: float = 0.01
max_grad_norm: float | None = None
gae_lambda: float = 0.95
max_batchsize: int = 256
@dataclass
class PPOParams(A2CParams):
"""PPO specific config."""
eps_clip: float = 0.2
dual_clip: float | None = None
value_clip: bool = False
advantage_normalization: bool = True
recompute_advantage: bool = False
lr: float = 1e-3
lr_scheduler_factory: LRSchedulerFactory | None = None
@dataclass
class SACParams(Params):
tau: float = 0.005
gamma: float = 0.99
alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
estimation_step: int = 1
exploration_noise: BaseNoise | Literal["default"] | None = None
deterministic_eval: bool = True
action_scaling: bool = True
action_bound_method: Literal["clip"] | None = "clip"
actor_lr: float = 1e-3
critic1_lr: float = 1e-3
critic2_lr: float = 1e-3
actor_lr_scheduler_factory: LRSchedulerFactory | None = None
critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
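
Putting the pieces together, the agent factories above consume these dataclasses roughly as follows. This is a sketch with a dummy single-parameter optimizer; the real PPOAgentFactory builds the optimizer from the ActorCritic module.

import torch

from tianshou.highlevel.agent import ParamTransformerDrop, ParamTransformerLRScheduler
from tianshou.highlevel.params.policy_params import PPOParams

params = PPOParams(lr=3e-4, eps_clip=0.2)
optim = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=params.lr)  # dummy parameter
kwargs = params.create_kwargs(
    ParamTransformerDrop("lr"),          # lr was already consumed by the optimizer
    ParamTransformerLRScheduler(optim),  # lr_scheduler_factory -> lr_scheduler
)
# kwargs now holds the keyword arguments forwarded to PPOPolicy; actor, critic,
# optim, dist_fn and action_space are still passed explicitly by the factory.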