Support PG/Reinforce in high-level API

* Add example mujoco_reinforce_hl
* Extended functionality of ActorFactory to support creation of ModuleOpt
Dominik Jain 2023-10-10 12:55:25 +02:00
parent 4e93c12afa
commit 6bb3abb2f0
12 changed files with 211 additions and 64 deletions
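In user code, the feature added by this commit boils down to the following (a minimal sketch distilled from the new example; env_factory stands for any EnvFactory, e.g. the MujocoEnvFactory used below, and all hyperparameter values are illustrative):

# Hypothetical usage sketch, not part of the diff below.
from tianshou.highlevel.experiment import PGExperimentBuilder
from tianshou.highlevel.params.policy_params import PGParams

experiment = (
    PGExperimentBuilder(env_factory)
    .with_pg_params(PGParams(discount_factor=0.99, lr=1e-3))
    .with_actor_factory_default((64, 64), continuous_unbounded=True)
    .build()
)
experiment.run("log/reinforce-sketch")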

View File

@@ -42,7 +42,7 @@ def main(
max_grad_norm: float = 0.5,
):
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
log_name = os.path.join(task, "a2c", str(experiment_config.seed), now)
sampling_config = SamplingConfig(
num_epochs=epoch,
@@ -75,7 +75,7 @@ def main(
),
)
.with_optim_factory(OptimizerFactoryRMSprop(eps=1e-5, alpha=0.99))
- .with_actor_factory_default(hidden_sizes)
+ .with_actor_factory_default(hidden_sizes, continuous_unbounded=True)
.with_critic_factory_default(hidden_sizes)
.build()
)

View File

@@ -87,7 +87,7 @@ def main(
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
),
)
- .with_actor_factory_default(hidden_sizes)
+ .with_actor_factory_default(hidden_sizes, continuous_unbounded=True)
.with_critic_factory_default(hidden_sizes)
.build()
)

View File

@@ -0,0 +1,74 @@
#!/usr/bin/env python3
import os
from collections.abc import Sequence
from typing import Literal
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
ExperimentConfig,
PGExperimentBuilder,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PGParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag
def main(
experiment_config: ExperimentConfig,
task: str = "Ant-v3",
buffer_size: int = 4096,
hidden_sizes: Sequence[int] = (64, 64),
lr: float = 1e-3,
gamma: float = 0.99,
epoch: int = 100,
step_per_epoch: int = 30000,
step_per_collect: int = 2048,
repeat_per_collect: int = 1,
batch_size: int = 99999,
training_num: int = 64,
test_num: int = 10,
rew_norm: bool = True,
action_bound_method: Literal["clip", "tanh"] = "tanh",
lr_decay: bool = True,
):
log_name = os.path.join(task, "reinforce", str(experiment_config.seed), datetime_tag())
sampling_config = SamplingConfig(
num_epochs=epoch,
step_per_epoch=step_per_epoch,
batch_size=batch_size,
num_train_envs=training_num,
num_test_envs=test_num,
buffer_size=buffer_size,
step_per_collect=step_per_collect,
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
experiment = (
PGExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_pg_params(
PGParams(
discount_factor=gamma,
action_bound_method=action_bound_method,
reward_normalization=rew_norm,
lr=lr,
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
if lr_decay
else None,
),
)
.with_actor_factory_default(hidden_sizes, continuous_unbounded=True)
.build()
)
experiment.run(log_name)
if __name__ == "__main__":
logging.run_main(lambda: CLI(main))
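Besides the jsonargparse CLI entry point above, the example's main() can also be invoked programmatically; a sketch (it assumes ExperimentConfig is constructible with its defaults and that the script lives at examples/mujoco/mujoco_reinforce_hl.py, as the commit message suggests):

# Hypothetical direct invocation, bypassing the CLI.
from examples.mujoco.mujoco_reinforce_hl import main
from tianshou.highlevel.experiment import ExperimentConfig

main(ExperimentConfig(), task="Ant-v3", epoch=1, step_per_epoch=5000)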

View File

@@ -12,13 +12,12 @@ from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module.actor import (
ActorFactory,
+ ActorModuleOptFactory,
)
from tianshou.highlevel.module.core import TDevice
- from tianshou.highlevel.module.critic import CriticFactory
+ from tianshou.highlevel.module.critic import CriticFactory, CriticModuleOptFactory
from tianshou.highlevel.module.module_opt import (
ActorCriticModuleOpt,
- ActorModuleOptFactory,
- CriticModuleOptFactory,
ModuleOpt,
)
from tianshou.highlevel.optim import OptimizerFactory
@@ -28,6 +27,7 @@ from tianshou.highlevel.params.policy_params import (
DQNParams,
Params,
ParamTransformerData,
+ PGParams,
PPOParams,
SACParams,
TD3Params,
@@ -39,6 +39,7 @@ from tianshou.policy import (
BasePolicy,
DDPGPolicy,
DQNPolicy,
+ PGPolicy,
PPOPolicy,
SACPolicy,
TD3Policy,
@@ -355,6 +356,41 @@ class _ActorAndDualCriticsMixin(_ActorAndCriticMixin):
return self.critic2_module_opt_factory.create_module_opt(envs, device, lr)
class PGAgentFactory(OnpolicyAgentFactory, _ActorMixin):
def __init__(
self,
params: PGParams,
sampling_config: SamplingConfig,
actor_factory: ActorFactory,
optim_factory: OptimizerFactory,
):
super().__init__(sampling_config, optim_factory)
_ActorMixin.__init__(self, actor_factory, optim_factory)
self.params = params
self.actor_factory = actor_factory
self.optim_factory = optim_factory
def _create_policy(self, envs: Environments, device: TDevice) -> PGPolicy:
actor = self.actor_factory.create_module_opt(
envs, device, self.optim_factory, self.params.lr,
)
kwargs = self.params.create_kwargs(
ParamTransformerData(
envs=envs,
device=device,
optim=actor.optim,
optim_factory=self.optim_factory,
),
)
return PGPolicy(
actor=actor.module,
optim=actor.optim,
action_space=envs.get_action_space(),
observation_space=envs.get_observation_space(),
**kwargs,
)
class ActorCriticAgentFactory(
Generic[TParams, TPolicy],
OnpolicyAgentFactory,

View File

@@ -14,6 +14,7 @@ from tianshou.highlevel.agent import (
AgentFactory,
DDPGAgentFactory,
DQNAgentFactory,
+ PGAgentFactory,
PPOAgentFactory,
SACAgentFactory,
TD3AgentFactory,
@@ -32,6 +33,7 @@ from tianshou.highlevel.params.policy_params import (
A2CParams,
DDPGParams,
DQNParams,
+ PGParams,
PPOParams,
SACParams,
TD3Params,
@@ -280,7 +282,9 @@ class _BuilderMixinActorFactory:
class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
"""Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""
"""Specialization of the actor mixin where, in the continuous case, the actor component outputs
Gaussian distribution parameters.
"""
def __init__(self) -> None:
super().__init__(ContinuousActorType.GAUSSIAN)
@@ -395,6 +399,35 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
return self
class PGExperimentBuilder(
ExperimentBuilder,
_BuilderMixinActorFactory_ContinuousGaussian,
):
def __init__(
self,
env_factory: EnvFactory,
experiment_config: ExperimentConfig | None = None,
sampling_config: SamplingConfig | None = None,
):
super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
self._params: PGParams = PGParams()
self._env_config = None
def with_pg_params(self, params: PGParams) -> Self:
self._params = params
return self
def _create_agent_factory(self) -> AgentFactory:
return PGAgentFactory(
self._params,
self._sampling_config,
self._get_actor_factory(),
self._get_optim_factory(),
)
class A2CExperimentBuilder(
ExperimentBuilder,
_BuilderMixinActorFactory_ContinuousGaussian,

View File

@@ -7,6 +7,8 @@ from torch import nn
from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
+ from tianshou.highlevel.module.module_opt import ModuleOpt
+ from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import BaseActor, Net
from tianshou.utils.string import ToStringMixin
@@ -23,6 +25,21 @@ class ActorFactory(ToStringMixin, ABC):
def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
pass
def create_module_opt(
self, envs: Environments, device: TDevice, optim_factory: OptimizerFactory, lr: float,
) -> ModuleOpt:
"""Creates the actor module along with its optimizer for the given learning rate.
:param envs: the environments
:param device: the torch device
:param optim_factory: the optimizer factory
:param lr: the learning rate
:return: a container with the actor module and its optimizer
"""
module = self.create_module(envs, device)
optim = optim_factory.create_optimizer(module, lr)
return ModuleOpt(module, optim)
@staticmethod
def _init_linear(actor: torch.nn.Module) -> None:
"""Initializes linear layers of an actor module using default mechanisms.
@@ -154,3 +171,14 @@ class ActorFactoryDiscreteNet(ActorFactory):
hidden_sizes=(),
device=device,
).to(device)
class ActorModuleOptFactory(ToStringMixin):
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
self.actor_factory = actor_factory
self.optim_factory = optim_factory
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
actor = self.actor_factory.create_module(envs, device)
opt = self.optim_factory.create_optimizer(actor, lr)
return ModuleOpt(actor, opt)
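A sketch of how agent factories consume the new ActorFactory.create_module_opt (envs, device, optim_factory and actor_factory stand for objects an agent factory already holds; ModuleOpt simply bundles a module with its optimizer):

# Hypothetical call site, mirroring PGAgentFactory._create_policy above.
module_opt = actor_factory.create_module_opt(envs, device, optim_factory, lr=1e-3)
actor_net, actor_optim = module_opt.module, module_opt.optim  # ready to hand to a policy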

View File

@@ -5,6 +5,8 @@ from torch import nn
from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
+ from tianshou.highlevel.module.module_opt import ModuleOpt
+ from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import Net
from tianshou.utils.string import ToStringMixin
@@ -78,3 +80,20 @@ class CriticFactoryDiscreteNet(CriticFactory):
critic = discrete.Critic(net_c, device=device).to(device)
init_linear_orthogonal(critic)
return critic
class CriticModuleOptFactory(ToStringMixin):
def __init__(
self,
critic_factory: CriticFactory,
optim_factory: OptimizerFactory,
use_action: bool,
):
self.critic_factory = critic_factory
self.optim_factory = optim_factory
self.use_action = use_action
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
critic = self.critic_factory.create_module(envs, device, self.use_action)
opt = self.optim_factory.create_optimizer(critic, lr)
return ModuleOpt(critic, opt)
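Analogously, a sketch for the critic counterpart (critic_factory, optim_factory, envs and device assumed given; use_action selects whether the critic also receives the action as input):

# Hypothetical usage of the relocated CriticModuleOptFactory.
critic_opt_factory = CriticModuleOptFactory(critic_factory, optim_factory, use_action=False)
critic_module_opt = critic_opt_factory.create_module_opt(envs, device, lr=1e-3)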

View File

@@ -2,13 +2,7 @@ from dataclasses import dataclass
import torch
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.actor import ActorFactory
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.critic import CriticFactory
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.string import ToStringMixin
@dataclass
@@ -29,31 +23,3 @@ class ActorCriticModuleOpt:
@property
def critic(self) -> torch.nn.Module:
return self.actor_critic_module.critic
class ActorModuleOptFactory(ToStringMixin):
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
self.actor_factory = actor_factory
self.optim_factory = optim_factory
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
actor = self.actor_factory.create_module(envs, device)
opt = self.optim_factory.create_optimizer(actor, lr)
return ModuleOpt(actor, opt)
class CriticModuleOptFactory(ToStringMixin):
def __init__(
self,
critic_factory: CriticFactory,
optim_factory: OptimizerFactory,
use_action: bool,
):
self.critic_factory = critic_factory
self.optim_factory = optim_factory
self.use_action = use_action
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
critic = self.critic_factory.create_module(envs, device, self.use_action)
opt = self.optim_factory.create_optimizer(critic, lr)
return ModuleOpt(critic, opt)

View File

@@ -15,10 +15,10 @@ from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactory,
DistributionFunctionFactoryDefault,
)
- from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory
+ from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.utils import MultipleLRSchedulers
@@ -277,42 +277,34 @@ class ParamsMixinActorAndCritic(GetParamTransformersProtocol):
@dataclass
- class PGParams(Params):
- """Config of general policy-gradient algorithms."""
+ class PGParams(Params, ParamsMixinLearningRateWithScheduler):
discount_factor: float = 0.99
reward_normalization: bool = False
deterministic_eval: bool = False
action_scaling: bool | Literal["default"] = "default"
"""whether to apply action scaling; when set to "default", it will be enabled for continuous action spaces"""
action_bound_method: Literal["clip", "tanh"] | None = "clip"
- def _get_param_transformers(self) -> list[ParamTransformer]:
- transformers = super()._get_param_transformers()
- transformers.append(ParamTransformerActionScaling("action_scaling"))
- return transformers
- @dataclass
- class A2CParams(PGParams, ParamsMixinLearningRateWithScheduler):
- vf_coef: float = 0.5
- ent_coef: float = 0.01
- max_grad_norm: float | None = None
- gae_lambda: float = 0.95
- max_batchsize: int = 256
dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default"
def _get_param_transformers(self) -> list[ParamTransformer]:
transformers = super()._get_param_transformers()
transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self))
transformers.append(ParamTransformerActionScaling("action_scaling"))
transformers.append(ParamTransformerDistributionFunction("dist_fn"))
return transformers
@dataclass
- class PPOParams(A2CParams):
- """PPO specific config."""
+ class A2CParams(PGParams):
+ vf_coef: float = 0.5
+ ent_coef: float = 0.01
+ max_grad_norm: float | None = None
+ gae_lambda: float = 0.95
+ max_batchsize: int = 256
+ @dataclass
+ class PPOParams(A2CParams):
eps_clip: float = 0.2
dual_clip: float | None = None
value_clip: bool = False
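Since PGParams now inherits ParamsMixinLearningRateWithScheduler, a plain REINFORCE configuration can carry its optimizer settings directly; a sketch with illustrative values:

# Hypothetical REINFORCE configuration using the restructured params.
params = PGParams(
    discount_factor=0.99,
    action_bound_method="tanh",
    lr=1e-3,  # field contributed by ParamsMixinLearningRateWithScheduler
)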

View File

@@ -1,4 +1,3 @@
- from collections.abc import Callable
from typing import Any, Literal
import gymnasium as gym

View File

@@ -1,4 +1,3 @@
- from collections.abc import Callable
from typing import Any, Literal, cast
import gymnasium as gym

View File

@@ -1,5 +1,6 @@
import warnings
- from typing import Any, Literal, cast, TypeAlias, Callable
+ from collections.abc import Callable
+ from typing import Any, Literal, TypeAlias, cast
import gymnasium as gym
import numpy as np