Add A2C high-level API

* Add common base class for A2C and PPO agent factories
* Add default for the dist_fn parameter, along with corresponding factories
* Add example mujoco_a2c_hl
Dominik Jain 2023-09-28 14:28:03 +02:00
parent acd89fa3b0
commit cd79cf8661
7 changed files with 290 additions and 27 deletions

View File

@@ -0,0 +1,85 @@
#!/usr/bin/env python3
import datetime
import os
from collections.abc import Sequence
from typing import Literal

from jsonargparse import CLI

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
    A2CExperimentBuilder,
    RLExperimentConfig,
)
from tianshou.highlevel.optim import OptimizerFactoryRMSprop
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import A2CParams


def main(
    experiment_config: RLExperimentConfig,
    task: str = "Ant-v3",
    buffer_size: int = 4096,
    hidden_sizes: Sequence[int] = (64, 64),
    lr: float = 7e-4,
    gamma: float = 0.99,
    epoch: int = 100,
    step_per_epoch: int = 30000,
    step_per_collect: int = 80,
    repeat_per_collect: int = 1,
    batch_size: int = 99999,
    training_num: int = 16,
    test_num: int = 10,
    rew_norm: bool = True,
    vf_coef: float = 0.5,
    ent_coef: float = 0.01,
    gae_lambda: float = 0.95,
    bound_action_method: Literal["clip", "tanh"] = "clip",
    lr_decay: bool = True,
    max_grad_norm: float = 0.5,
):
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    log_name = os.path.join(task, "a2c", str(experiment_config.seed), now)

    sampling_config = RLSamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        batch_size=batch_size,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        step_per_collect=step_per_collect,
        repeat_per_collect=repeat_per_collect,
    )

    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)

    experiment = (
        A2CExperimentBuilder(experiment_config, env_factory, sampling_config)
        .with_a2c_params(
            A2CParams(
                discount_factor=gamma,
                gae_lambda=gae_lambda,
                action_bound_method=bound_action_method,
                reward_normalization=rew_norm,
                ent_coef=ent_coef,
                vf_coef=vf_coef,
                max_grad_norm=max_grad_norm,
                lr=lr,
                lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
                if lr_decay
                else None,
            ),
        )
        .with_optim_factory(OptimizerFactoryRMSprop(eps=1e-5, alpha=0.99))
        .with_actor_factory_default(hidden_sizes)
        .with_critic_factory_default(hidden_sizes)
        .build()
    )
    experiment.run(log_name)


if __name__ == "__main__":
    CLI(main)
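
Usage note (an assumption based on jsonargparse's standard behavior, not shown in the diff): because main is handed to jsonargparse's CLI, each keyword parameter becomes a command-line option, so an invocation along the lines of python mujoco_a2c_hl.py --task Hopper-v4 --epoch 10 --lr_decay false is expected to work, with RLExperimentConfig fields exposed as nested options such as --experiment_config.seed.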

View File

@@ -65,8 +65,8 @@ def main(
         return Independent(Normal(*logits), 1)

     experiment = (
-        PPOExperimentBuilder(experiment_config, env_factory, sampling_config, dist_fn)
-        .with_params(
+        PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
+        .with_ppo_params(
             PPOParams(
                 discount_factor=gamma,
                 gae_lambda=gae_lambda,
@@ -84,6 +84,7 @@ def main(
                 lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
                 if lr_decay
                 else None,
+                dist_fn=dist_fn,
             ),
         )
         .with_actor_factory_default(hidden_sizes)

View File

@@ -1,6 +1,7 @@
 import os
 from abc import ABC, abstractmethod
 from collections.abc import Callable
+from typing import Generic, TypeVar

 import torch
@@ -19,18 +20,21 @@ from tianshou.highlevel.module import (
 )
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.highlevel.params.policy_params import (
+    A2CParams,
+    Params,
     ParamTransformerData,
     PPOParams,
     SACParams,
     TD3Params,
 )
-from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy, TD3Policy
-from tianshou.policy.modelfree.pg import TDistParams
+from tianshou.policy import A2CPolicy, BasePolicy, PPOPolicy, SACPolicy, TD3Policy
 from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
 from tianshou.utils.net.common import ActorCritic

 CHECKPOINT_DICT_KEY_MODEL = "model"
 CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"

+TParams = TypeVar("TParams", bound=Params)
+TPolicy = TypeVar("TPolicy", bound=BasePolicy)


 class AgentFactory(ABC):
@@ -219,15 +223,20 @@ class _ActorAndDualCriticsMixin(_ActorAndCriticMixin):
         return self.critic2_module_opt_factory.create_module_opt(envs, device, lr)


-class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
+class ActorCriticAgentFactory(
+    Generic[TParams, TPolicy],
+    OnpolicyAgentFactory,
+    _ActorCriticMixin,
+    ABC,
+):
     def __init__(
         self,
-        params: PPOParams,
+        params: TParams,
         sampling_config: RLSamplingConfig,
         actor_factory: ActorFactory,
         critic_factory: CriticFactory,
         optimizer_factory: OptimizerFactory,
-        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
+        policy_class: type[TPolicy],
     ):
         super().__init__(sampling_config)
         _ActorCriticMixin.__init__(
@@ -238,10 +247,14 @@ class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
             critic_use_action=False,
         )
         self.params = params
-        self.dist_fn = dist_fn
+        self.policy_class = policy_class

-    def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
-        actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
+    @abstractmethod
+    def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
+        pass
+
+    def _create_kwargs(self, envs: Environments, device: TDevice):
+        actor_critic = self._create_actor_critic(envs, device)
         kwargs = self.params.create_kwargs(
             ParamTransformerData(
                 envs=envs,
@@ -250,15 +263,59 @@ class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
                 optim=actor_critic.optim,
             ),
         )
-        return PPOPolicy(
-            actor=actor_critic.actor,
-            critic=actor_critic.critic,
-            optim=actor_critic.optim,
-            dist_fn=self.dist_fn,
-            action_space=envs.get_action_space(),
-            **kwargs,
-        )
+        kwargs["actor"] = actor_critic.actor
+        kwargs["critic"] = actor_critic.critic
+        kwargs["optim"] = actor_critic.optim
+        kwargs["action_space"] = envs.get_action_space()
+        return kwargs
+
+    def create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
+        return self.policy_class(**self._create_kwargs(envs, device))
+
+
+class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
+    def __init__(
+        self,
+        params: A2CParams,
+        sampling_config: RLSamplingConfig,
+        actor_factory: ActorFactory,
+        critic_factory: CriticFactory,
+        optimizer_factory: OptimizerFactory,
+    ):
+        super().__init__(
+            params,
+            sampling_config,
+            actor_factory,
+            critic_factory,
+            optimizer_factory,
+            A2CPolicy,
+        )
+
+    def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
+        return self.create_actor_critic_module_opt(envs, device, self.params.lr)
+
+
+class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
+    def __init__(
+        self,
+        params: PPOParams,
+        sampling_config: RLSamplingConfig,
+        actor_factory: ActorFactory,
+        critic_factory: CriticFactory,
+        optimizer_factory: OptimizerFactory,
+    ):
+        super().__init__(
+            params,
+            sampling_config,
+            actor_factory,
+            critic_factory,
+            optimizer_factory,
+            PPOPolicy,
+        )
+
+    def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
+        return self.create_actor_critic_module_opt(envs, device, self.params.lr)


 class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
     def __init__(
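
With the generic base class, a concrete on-policy actor-critic algorithm only has to supply its parameter dataclass, its policy class, and how the actor-critic module is built. A hypothetical sketch of a further subclass following the same pattern (NPGPolicy exists in tianshou, but NPGParams and this factory are illustrative only and not part of this commit):

class NPGAgentFactory(ActorCriticAgentFactory["NPGParams", NPGPolicy]):
    # Illustrative only: assumes a hypothetical NPGParams dataclass analogous to A2CParams
    # (in particular, that it carries an `lr` field).
    def __init__(
        self,
        params: "NPGParams",
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optimizer_factory: OptimizerFactory,
    ):
        super().__init__(
            params,
            sampling_config,
            actor_factory,
            critic_factory,
            optimizer_factory,
            NPGPolicy,
        )

    def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
        # Same wiring as A2C/PPO: one shared optimizer over the joint actor-critic parameters.
        return self.create_actor_critic_module_opt(envs, device, self.params.lr)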

View File

@@ -9,6 +9,7 @@ import torch
 from tianshou.data import Collector
 from tianshou.highlevel.agent import (
+    A2CAgentFactory,
     AgentFactory,
     PPOAgentFactory,
     SACAgentFactory,
@@ -25,10 +26,14 @@ from tianshou.highlevel.module import (
     CriticFactoryDefault,
 )
 from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
-from tianshou.highlevel.params.policy_params import PPOParams, SACParams, TD3Params
+from tianshou.highlevel.params.policy_params import (
+    A2CParams,
+    PPOParams,
+    SACParams,
+    TD3Params,
+)
 from tianshou.highlevel.persistence import PersistableConfigProtocol
 from tianshou.policy import BasePolicy
-from tianshou.policy.modelfree.pg import TDistParams
 from tianshou.trainer import BaseTrainer

 TPolicy = TypeVar("TPolicy", bound=BasePolicy)
@@ -234,7 +239,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
     """Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""

     def __init__(self):
-        super().__init__(ContinuousActorType.DETERMINISTIC)
+        super().__init__(ContinuousActorType.GAUSSIAN)

     def with_actor_factory_default(
         self,
@@ -343,6 +348,39 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
         return self
+
+class A2CExperimentBuilder(
+    RLExperimentBuilder,
+    _BuilderMixinActorFactory_ContinuousGaussian,
+    _BuilderMixinSingleCriticFactory,
+):
+    def __init__(
+        self,
+        experiment_config: RLExperimentConfig,
+        env_factory: EnvFactory,
+        sampling_config: RLSamplingConfig,
+        env_config: PersistableConfigProtocol | None = None,
+    ):
+        super().__init__(experiment_config, env_factory, sampling_config)
+        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
+        _BuilderMixinSingleCriticFactory.__init__(self)
+        self._params: A2CParams = A2CParams()
+        self._env_config = env_config
+
+    def with_a2c_params(self, params: A2CParams) -> Self:
+        self._params = params
+        return self
+
+    def _create_agent_factory(self) -> AgentFactory:
+        return A2CAgentFactory(
+            self._params,
+            self._sampling_config,
+            self._get_actor_factory(),
+            self._get_critic_factory(0),
+            self._get_optim_factory(),
+        )
+
 class PPOExperimentBuilder(
     RLExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,
@@ -353,14 +391,12 @@ class PPOExperimentBuilder(
         experiment_config: RLExperimentConfig,
         env_factory: EnvFactory,
         sampling_config: RLSamplingConfig,
-        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
         env_config: PersistableConfigProtocol | None = None,
     ):
-        super().__init__(experiment_config, env_factory, sampling_config, env_config=env_config)
+        super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
         _BuilderMixinSingleCriticFactory.__init__(self)
         self._params: PPOParams = PPOParams()
-        self._dist_fn = dist_fn
         self._env_config = env_config

     def with_ppo_params(self, params: PPOParams) -> Self:
@@ -375,7 +411,6 @@ class PPOExperimentBuilder(
             self._get_actor_factory(),
             self._get_critic_factory(0),
             self._get_optim_factory(),
-            self._dist_fn,
         )
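
In short, the builder no longer receives the distribution function; it travels with the parameter object instead. A minimal sketch of the call site after this change (same names as in the examples above; the explicit dist_fn can be omitted to fall back on the new "default" marker):

experiment = (
    PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
    .with_ppo_params(PPOParams(lr=3e-4))  # dist_fn left at "default": resolved from the env type at build time
    .with_actor_factory_default(hidden_sizes)
    .with_critic_factory_default(hidden_sizes)
    .build()
)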

View File

@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 from typing import Any

 import torch
-from torch.optim import Adam
+from torch.optim import Adam, RMSprop


 class OptimizerFactory(ABC):
@@ -43,3 +43,23 @@ class OptimizerFactoryAdam(OptimizerFactory):
             eps=self.eps,
             weight_decay=self.weight_decay,
         )
+
+
+class OptimizerFactoryRMSprop(OptimizerFactory):
+    def __init__(self, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False):
+        self.alpha = alpha
+        self.momentum = momentum
+        self.centered = centered
+        self.weight_decay = weight_decay
+        self.eps = eps
+
+    def create_optimizer(self, module: torch.nn.Module, lr: float) -> RMSprop:
+        return RMSprop(
+            module.parameters(),
+            lr=lr,
+            alpha=self.alpha,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            momentum=self.momentum,
+            centered=self.centered,
+        )
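
The constructor defaults mirror torch.optim.RMSprop's own defaults, so the factory only needs explicit arguments where a script deviates from them (eps=1e-5 in the example above). A small standalone sketch of how the factory is consumed (the Linear module is just a stand-in for the actor-critic network):

import torch

factory = OptimizerFactoryRMSprop(eps=1e-5, alpha=0.99)
optim = factory.create_optimizer(torch.nn.Linear(4, 2), lr=7e-4)  # -> torch.optim.RMSprop over the module's parameters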

View File

@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import TypeAlias

import torch

from tianshou.highlevel.env import Environments, EnvType
from tianshou.policy.modelfree.pg import TDistParams

TDistributionFunction: TypeAlias = Callable[[TDistParams], torch.distributions.Distribution]


class DistributionFunctionFactory(ABC):
    @abstractmethod
    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        pass


def _dist_fn_categorical(p):
    return torch.distributions.Categorical(logits=p)


def _dist_fn_gaussian(*p):
    return torch.distributions.Independent(torch.distributions.Normal(*p), 1)


class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        match envs.get_type():
            case EnvType.DISCRETE:
                return _dist_fn_categorical
            case EnvType.CONTINUOUS:
                return _dist_fn_gaussian
            case _:
                raise ValueError(envs.get_type())
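
For reference, the default continuous dist_fn wraps the Normal in an Independent over the last dimension, so per-dimension log-probabilities are summed into a single log-probability per action. A quick sketch of the expected shapes (the (mu, sigma) pair stands in for the actor's output):

import torch

mu, sigma = torch.zeros(3, 2), torch.ones(3, 2)  # batch of 3, action dim 2
dist = _dist_fn_gaussian(mu, sigma)              # Independent(Normal(mu, sigma), 1)
act = dist.sample()                              # shape (3, 2)
logp = dist.log_prob(act)                        # shape (3,), one value per batch element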

View File

@@ -9,6 +9,11 @@ from tianshou.highlevel.env import Environments
 from tianshou.highlevel.module import ModuleOpt, TDevice
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.highlevel.params.alpha import AutoAlphaFactory
+from tianshou.highlevel.params.dist_fn import (
+    DistributionFunctionFactory,
+    DistributionFunctionFactoryDefault,
+    TDistributionFunction,
+)
 from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
 from tianshou.highlevel.params.noise import NoiseFactory
@@ -34,6 +39,12 @@ class ParamTransformerData:
 class ParamTransformer(ABC):
+    """Transforms one or more parameters from the representation used by the high-level API
+    to the representation required by the (low-level) policy implementation.
+
+    It operates directly on a dictionary of keyword arguments, which is initially
+    generated from the parameter dataclass (subclass of `Params`).
+    """
+
     @abstractmethod
     def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
         pass
@@ -159,6 +170,18 @@ class ParamTransformerFloatEnvParamFactory(ParamTransformer):
         kwargs[self.key] = value.create_value(data.envs)


+class ParamTransformerDistributionFunction(ParamTransformer):
+    def __init__(self, key: str):
+        self.key = key
+
+    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
+        value = kwargs[self.key]
+        if value == "default":
+            kwargs[self.key] = DistributionFunctionFactoryDefault().create_dist_fn(data.envs)
+        elif isinstance(value, DistributionFunctionFactory):
+            kwargs[self.key] = value.create_dist_fn(data.envs)
+
+
 class GetParamTransformersProtocol(Protocol):
     def _get_param_transformers(self) -> list[ParamTransformer]:
         pass
@@ -200,16 +223,23 @@ class PGParams(Params):
 @dataclass
-class A2CParams(PGParams):
+class A2CParams(PGParams, ParamsMixinLearningRateWithScheduler):
     vf_coef: float = 0.5
     ent_coef: float = 0.01
     max_grad_norm: float | None = None
     gae_lambda: float = 0.95
     max_batchsize: int = 256
+    dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default"
+
+    def _get_param_transformers(self) -> list[ParamTransformer]:
+        transformers = super()._get_param_transformers()
+        transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self))
+        transformers.append(ParamTransformerDistributionFunction("dist_fn"))
+        return transformers


 @dataclass
-class PPOParams(A2CParams, ParamsMixinLearningRateWithScheduler):
+class PPOParams(A2CParams):
     """PPO specific config."""

     eps_clip: float = 0.2
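
With this transformer in place, dist_fn can be left at its "default" marker and the concrete distribution function is only materialized once the environments are known, inside create_kwargs. A sketch of the accepted forms (assuming the usual torch.distributions imports; all three are equivalent entry points into the same transformer):

A2CParams(lr=7e-4)                                                   # dist_fn="default": resolved per env type
A2CParams(lr=7e-4, dist_fn=DistributionFunctionFactoryDefault())     # explicit factory, resolved the same way
A2CParams(lr=7e-4, dist_fn=lambda *p: Independent(Normal(*p), 1))    # a plain callable passes through unchanged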