Add A2C high-level API
* Add common based class for A2C and PPO agent factories * Add default for dist_fn parameter, adding corresponding factories * Add example mujoco_a2c_hl
This commit is contained in:
parent
acd89fa3b0
commit
cd79cf8661
85
examples/mujoco/mujoco_a2c_hl.py
Normal file
85
examples/mujoco/mujoco_a2c_hl.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from jsonargparse import CLI
|
||||||
|
|
||||||
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
|
from tianshou.highlevel.config import RLSamplingConfig
|
||||||
|
from tianshou.highlevel.experiment import (
|
||||||
|
A2CExperimentBuilder,
|
||||||
|
RLExperimentConfig,
|
||||||
|
)
|
||||||
|
from tianshou.highlevel.optim import OptimizerFactoryRMSprop
|
||||||
|
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
||||||
|
from tianshou.highlevel.params.policy_params import A2CParams
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
experiment_config: RLExperimentConfig,
|
||||||
|
task: str = "Ant-v3",
|
||||||
|
buffer_size: int = 4096,
|
||||||
|
hidden_sizes: Sequence[int] = (64, 64),
|
||||||
|
lr: float = 7e-4,
|
||||||
|
gamma: float = 0.99,
|
||||||
|
epoch: int = 100,
|
||||||
|
step_per_epoch: int = 30000,
|
||||||
|
step_per_collect: int = 80,
|
||||||
|
repeat_per_collect: int = 1,
|
||||||
|
batch_size: int = 99999,
|
||||||
|
training_num: int = 16,
|
||||||
|
test_num: int = 10,
|
||||||
|
rew_norm: bool = True,
|
||||||
|
vf_coef: float = 0.5,
|
||||||
|
ent_coef: float = 0.01,
|
||||||
|
gae_lambda: float = 0.95,
|
||||||
|
bound_action_method: Literal["clip", "tanh"] = "clip",
|
||||||
|
lr_decay: bool = True,
|
||||||
|
max_grad_norm: float = 0.5,
|
||||||
|
):
|
||||||
|
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||||
|
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
||||||
|
|
||||||
|
sampling_config = RLSamplingConfig(
|
||||||
|
num_epochs=epoch,
|
||||||
|
step_per_epoch=step_per_epoch,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_train_envs=training_num,
|
||||||
|
num_test_envs=test_num,
|
||||||
|
buffer_size=buffer_size,
|
||||||
|
step_per_collect=step_per_collect,
|
||||||
|
repeat_per_collect=repeat_per_collect,
|
||||||
|
)
|
||||||
|
|
||||||
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||||
|
|
||||||
|
experiment = (
|
||||||
|
A2CExperimentBuilder(experiment_config, env_factory, sampling_config)
|
||||||
|
.with_a2c_params(
|
||||||
|
A2CParams(
|
||||||
|
discount_factor=gamma,
|
||||||
|
gae_lambda=gae_lambda,
|
||||||
|
action_bound_method=bound_action_method,
|
||||||
|
reward_normalization=rew_norm,
|
||||||
|
ent_coef=ent_coef,
|
||||||
|
vf_coef=vf_coef,
|
||||||
|
max_grad_norm=max_grad_norm,
|
||||||
|
lr=lr,
|
||||||
|
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
|
||||||
|
if lr_decay
|
||||||
|
else None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.with_optim_factory(OptimizerFactoryRMSprop(eps=1e-5, alpha=0.99))
|
||||||
|
.with_actor_factory_default(hidden_sizes)
|
||||||
|
.with_critic_factory_default(hidden_sizes)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
experiment.run(log_name)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
CLI(main)
|
@ -65,8 +65,8 @@ def main(
|
|||||||
return Independent(Normal(*logits), 1)
|
return Independent(Normal(*logits), 1)
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
PPOExperimentBuilder(experiment_config, env_factory, sampling_config, dist_fn)
|
PPOExperimentBuilder(experiment_config, env_factory, sampling_config)
|
||||||
.with_params(
|
.with_ppo_params(
|
||||||
PPOParams(
|
PPOParams(
|
||||||
discount_factor=gamma,
|
discount_factor=gamma,
|
||||||
gae_lambda=gae_lambda,
|
gae_lambda=gae_lambda,
|
||||||
@ -84,6 +84,7 @@ def main(
|
|||||||
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
|
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
|
||||||
if lr_decay
|
if lr_decay
|
||||||
else None,
|
else None,
|
||||||
|
dist_fn=dist_fn,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.with_actor_factory_default(hidden_sizes)
|
.with_actor_factory_default(hidden_sizes)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -19,18 +20,21 @@ from tianshou.highlevel.module import (
|
|||||||
)
|
)
|
||||||
from tianshou.highlevel.optim import OptimizerFactory
|
from tianshou.highlevel.optim import OptimizerFactory
|
||||||
from tianshou.highlevel.params.policy_params import (
|
from tianshou.highlevel.params.policy_params import (
|
||||||
|
A2CParams,
|
||||||
|
Params,
|
||||||
ParamTransformerData,
|
ParamTransformerData,
|
||||||
PPOParams,
|
PPOParams,
|
||||||
SACParams,
|
SACParams,
|
||||||
TD3Params,
|
TD3Params,
|
||||||
)
|
)
|
||||||
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy, TD3Policy
|
from tianshou.policy import A2CPolicy, BasePolicy, PPOPolicy, SACPolicy, TD3Policy
|
||||||
from tianshou.policy.modelfree.pg import TDistParams
|
|
||||||
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
|
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
|
||||||
from tianshou.utils.net.common import ActorCritic
|
from tianshou.utils.net.common import ActorCritic
|
||||||
|
|
||||||
CHECKPOINT_DICT_KEY_MODEL = "model"
|
CHECKPOINT_DICT_KEY_MODEL = "model"
|
||||||
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
|
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
|
||||||
|
TParams = TypeVar("TParams", bound=Params)
|
||||||
|
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
||||||
|
|
||||||
|
|
||||||
class AgentFactory(ABC):
|
class AgentFactory(ABC):
|
||||||
@ -219,15 +223,20 @@ class _ActorAndDualCriticsMixin(_ActorAndCriticMixin):
|
|||||||
return self.critic2_module_opt_factory.create_module_opt(envs, device, lr)
|
return self.critic2_module_opt_factory.create_module_opt(envs, device, lr)
|
||||||
|
|
||||||
|
|
||||||
class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
|
class ActorCriticAgentFactory(
|
||||||
|
Generic[TParams, TPolicy],
|
||||||
|
OnpolicyAgentFactory,
|
||||||
|
_ActorCriticMixin,
|
||||||
|
ABC,
|
||||||
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: PPOParams,
|
params: TParams,
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: RLSamplingConfig,
|
||||||
actor_factory: ActorFactory,
|
actor_factory: ActorFactory,
|
||||||
critic_factory: CriticFactory,
|
critic_factory: CriticFactory,
|
||||||
optimizer_factory: OptimizerFactory,
|
optimizer_factory: OptimizerFactory,
|
||||||
dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
|
policy_class: type[TPolicy],
|
||||||
):
|
):
|
||||||
super().__init__(sampling_config)
|
super().__init__(sampling_config)
|
||||||
_ActorCriticMixin.__init__(
|
_ActorCriticMixin.__init__(
|
||||||
@ -238,10 +247,14 @@ class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
|
|||||||
critic_use_action=False,
|
critic_use_action=False,
|
||||||
)
|
)
|
||||||
self.params = params
|
self.params = params
|
||||||
self.dist_fn = dist_fn
|
self.policy_class = policy_class
|
||||||
|
|
||||||
def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
|
@abstractmethod
|
||||||
actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _create_kwargs(self, envs: Environments, device: TDevice):
|
||||||
|
actor_critic = self._create_actor_critic(envs, device)
|
||||||
kwargs = self.params.create_kwargs(
|
kwargs = self.params.create_kwargs(
|
||||||
ParamTransformerData(
|
ParamTransformerData(
|
||||||
envs=envs,
|
envs=envs,
|
||||||
@ -250,15 +263,59 @@ class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
|
|||||||
optim=actor_critic.optim,
|
optim=actor_critic.optim,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return PPOPolicy(
|
kwargs["actor"] = actor_critic.actor
|
||||||
actor=actor_critic.actor,
|
kwargs["critic"] = actor_critic.critic
|
||||||
critic=actor_critic.critic,
|
kwargs["optim"] = actor_critic.optim
|
||||||
optim=actor_critic.optim,
|
kwargs["action_space"] = envs.get_action_space()
|
||||||
dist_fn=self.dist_fn,
|
return kwargs
|
||||||
action_space=envs.get_action_space(),
|
|
||||||
**kwargs,
|
def create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
|
||||||
|
return self.policy_class(**self._create_kwargs(envs, device))
|
||||||
|
|
||||||
|
|
||||||
|
class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
params: A2CParams,
|
||||||
|
sampling_config: RLSamplingConfig,
|
||||||
|
actor_factory: ActorFactory,
|
||||||
|
critic_factory: CriticFactory,
|
||||||
|
optimizer_factory: OptimizerFactory,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
params,
|
||||||
|
sampling_config,
|
||||||
|
actor_factory,
|
||||||
|
critic_factory,
|
||||||
|
optimizer_factory,
|
||||||
|
A2CPolicy,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
||||||
|
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
||||||
|
|
||||||
|
|
||||||
|
class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
params: PPOParams,
|
||||||
|
sampling_config: RLSamplingConfig,
|
||||||
|
actor_factory: ActorFactory,
|
||||||
|
critic_factory: CriticFactory,
|
||||||
|
optimizer_factory: OptimizerFactory,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
params,
|
||||||
|
sampling_config,
|
||||||
|
actor_factory,
|
||||||
|
critic_factory,
|
||||||
|
optimizer_factory,
|
||||||
|
PPOPolicy,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
||||||
|
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
||||||
|
|
||||||
|
|
||||||
class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -9,6 +9,7 @@ import torch
|
|||||||
|
|
||||||
from tianshou.data import Collector
|
from tianshou.data import Collector
|
||||||
from tianshou.highlevel.agent import (
|
from tianshou.highlevel.agent import (
|
||||||
|
A2CAgentFactory,
|
||||||
AgentFactory,
|
AgentFactory,
|
||||||
PPOAgentFactory,
|
PPOAgentFactory,
|
||||||
SACAgentFactory,
|
SACAgentFactory,
|
||||||
@ -25,10 +26,14 @@ from tianshou.highlevel.module import (
|
|||||||
CriticFactoryDefault,
|
CriticFactoryDefault,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
|
from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
|
||||||
from tianshou.highlevel.params.policy_params import PPOParams, SACParams, TD3Params
|
from tianshou.highlevel.params.policy_params import (
|
||||||
|
A2CParams,
|
||||||
|
PPOParams,
|
||||||
|
SACParams,
|
||||||
|
TD3Params,
|
||||||
|
)
|
||||||
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
||||||
from tianshou.policy import BasePolicy
|
from tianshou.policy import BasePolicy
|
||||||
from tianshou.policy.modelfree.pg import TDistParams
|
|
||||||
from tianshou.trainer import BaseTrainer
|
from tianshou.trainer import BaseTrainer
|
||||||
|
|
||||||
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
||||||
@ -234,7 +239,7 @@ class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
|
|||||||
"""Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""
|
"""Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(ContinuousActorType.DETERMINISTIC)
|
super().__init__(ContinuousActorType.GAUSSIAN)
|
||||||
|
|
||||||
def with_actor_factory_default(
|
def with_actor_factory_default(
|
||||||
self,
|
self,
|
||||||
@ -343,6 +348,39 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class A2CExperimentBuilder(
|
||||||
|
RLExperimentBuilder,
|
||||||
|
_BuilderMixinActorFactory_ContinuousGaussian,
|
||||||
|
_BuilderMixinSingleCriticFactory,
|
||||||
|
):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
experiment_config: RLExperimentConfig,
|
||||||
|
env_factory: EnvFactory,
|
||||||
|
sampling_config: RLSamplingConfig,
|
||||||
|
env_config: PersistableConfigProtocol | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(experiment_config, env_factory, sampling_config)
|
||||||
|
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||||
|
_BuilderMixinSingleCriticFactory.__init__(self)
|
||||||
|
self._params: A2CParams = A2CParams()
|
||||||
|
self._env_config = env_config
|
||||||
|
|
||||||
|
def with_a2c_params(self, params: A2CParams) -> Self:
|
||||||
|
self._params = params
|
||||||
|
return self
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _create_agent_factory(self) -> AgentFactory:
|
||||||
|
return A2CAgentFactory(
|
||||||
|
self._params,
|
||||||
|
self._sampling_config,
|
||||||
|
self._get_actor_factory(),
|
||||||
|
self._get_critic_factory(0),
|
||||||
|
self._get_optim_factory(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PPOExperimentBuilder(
|
class PPOExperimentBuilder(
|
||||||
RLExperimentBuilder,
|
RLExperimentBuilder,
|
||||||
_BuilderMixinActorFactory_ContinuousGaussian,
|
_BuilderMixinActorFactory_ContinuousGaussian,
|
||||||
@ -353,14 +391,12 @@ class PPOExperimentBuilder(
|
|||||||
experiment_config: RLExperimentConfig,
|
experiment_config: RLExperimentConfig,
|
||||||
env_factory: EnvFactory,
|
env_factory: EnvFactory,
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: RLSamplingConfig,
|
||||||
dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
|
|
||||||
env_config: PersistableConfigProtocol | None = None,
|
env_config: PersistableConfigProtocol | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(experiment_config, env_factory, sampling_config, env_config=env_config)
|
super().__init__(experiment_config, env_factory, sampling_config)
|
||||||
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||||
_BuilderMixinSingleCriticFactory.__init__(self)
|
_BuilderMixinSingleCriticFactory.__init__(self)
|
||||||
self._params: PPOParams = PPOParams()
|
self._params: PPOParams = PPOParams()
|
||||||
self._dist_fn = dist_fn
|
|
||||||
self._env_config = env_config
|
self._env_config = env_config
|
||||||
|
|
||||||
def with_ppo_params(self, params: PPOParams) -> Self:
|
def with_ppo_params(self, params: PPOParams) -> Self:
|
||||||
@ -375,7 +411,6 @@ class PPOExperimentBuilder(
|
|||||||
self._get_actor_factory(),
|
self._get_actor_factory(),
|
||||||
self._get_critic_factory(0),
|
self._get_critic_factory(0),
|
||||||
self._get_optim_factory(),
|
self._get_optim_factory(),
|
||||||
self._dist_fn,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam, RMSprop
|
||||||
|
|
||||||
|
|
||||||
class OptimizerFactory(ABC):
|
class OptimizerFactory(ABC):
|
||||||
@ -43,3 +43,23 @@ class OptimizerFactoryAdam(OptimizerFactory):
|
|||||||
eps=self.eps,
|
eps=self.eps,
|
||||||
weight_decay=self.weight_decay,
|
weight_decay=self.weight_decay,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizerFactoryRMSprop(OptimizerFactory):
|
||||||
|
def __init__(self, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False):
|
||||||
|
self.alpha = alpha
|
||||||
|
self.momentum = momentum
|
||||||
|
self.centered = centered
|
||||||
|
self.weight_decay = weight_decay
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def create_optimizer(self, module: torch.nn.Module, lr: float) -> RMSprop:
|
||||||
|
return RMSprop(
|
||||||
|
module.parameters(),
|
||||||
|
lr=lr,
|
||||||
|
alpha=self.alpha,
|
||||||
|
eps=self.eps,
|
||||||
|
weight_decay=self.weight_decay,
|
||||||
|
momentum=self.momentum,
|
||||||
|
centered=self.centered,
|
||||||
|
)
|
||||||
|
35
tianshou/highlevel/params/dist_fn.py
Normal file
35
tianshou/highlevel/params/dist_fn.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import TypeAlias
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from tianshou.highlevel.env import Environments, EnvType
|
||||||
|
from tianshou.policy.modelfree.pg import TDistParams
|
||||||
|
|
||||||
|
TDistributionFunction: TypeAlias = Callable[[TDistParams], torch.distributions.Distribution]
|
||||||
|
|
||||||
|
|
||||||
|
class DistributionFunctionFactory(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _dist_fn_categorical(p):
|
||||||
|
return torch.distributions.Categorical(logits=p)
|
||||||
|
|
||||||
|
|
||||||
|
def _dist_fn_gaussian(*p):
|
||||||
|
return torch.distributions.Independent(torch.distributions.Normal(*p), 1)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
|
||||||
|
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
|
||||||
|
match envs.get_type():
|
||||||
|
case EnvType.DISCRETE:
|
||||||
|
return _dist_fn_categorical
|
||||||
|
case EnvType.CONTINUOUS:
|
||||||
|
return _dist_fn_gaussian
|
||||||
|
case _:
|
||||||
|
raise ValueError(envs.get_type())
|
@ -9,6 +9,11 @@ from tianshou.highlevel.env import Environments
|
|||||||
from tianshou.highlevel.module import ModuleOpt, TDevice
|
from tianshou.highlevel.module import ModuleOpt, TDevice
|
||||||
from tianshou.highlevel.optim import OptimizerFactory
|
from tianshou.highlevel.optim import OptimizerFactory
|
||||||
from tianshou.highlevel.params.alpha import AutoAlphaFactory
|
from tianshou.highlevel.params.alpha import AutoAlphaFactory
|
||||||
|
from tianshou.highlevel.params.dist_fn import (
|
||||||
|
DistributionFunctionFactory,
|
||||||
|
DistributionFunctionFactoryDefault,
|
||||||
|
TDistributionFunction,
|
||||||
|
)
|
||||||
from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
|
from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
|
||||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
|
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
|
||||||
from tianshou.highlevel.params.noise import NoiseFactory
|
from tianshou.highlevel.params.noise import NoiseFactory
|
||||||
@ -34,6 +39,12 @@ class ParamTransformerData:
|
|||||||
|
|
||||||
|
|
||||||
class ParamTransformer(ABC):
|
class ParamTransformer(ABC):
|
||||||
|
"""Transforms one or more parameters from the representation used by the high-level API
|
||||||
|
to the representation required by the (low-level) policy implementation.
|
||||||
|
It operates directly on a dictionary of keyword arguments, which is initially
|
||||||
|
generated from the parameter dataclass (subclass of `Params`).
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
|
def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
|
||||||
pass
|
pass
|
||||||
@ -159,6 +170,18 @@ class ParamTransformerFloatEnvParamFactory(ParamTransformer):
|
|||||||
kwargs[self.key] = value.create_value(data.envs)
|
kwargs[self.key] = value.create_value(data.envs)
|
||||||
|
|
||||||
|
|
||||||
|
class ParamTransformerDistributionFunction(ParamTransformer):
|
||||||
|
def __init__(self, key: str):
|
||||||
|
self.key = key
|
||||||
|
|
||||||
|
def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
|
||||||
|
value = kwargs[self.key]
|
||||||
|
if value == "default":
|
||||||
|
kwargs[self.key] = DistributionFunctionFactoryDefault().create_dist_fn(data.envs)
|
||||||
|
elif isinstance(value, DistributionFunctionFactory):
|
||||||
|
kwargs[self.key] = value.create_dist_fn(data.envs)
|
||||||
|
|
||||||
|
|
||||||
class GetParamTransformersProtocol(Protocol):
|
class GetParamTransformersProtocol(Protocol):
|
||||||
def _get_param_transformers(self) -> list[ParamTransformer]:
|
def _get_param_transformers(self) -> list[ParamTransformer]:
|
||||||
pass
|
pass
|
||||||
@ -200,16 +223,23 @@ class PGParams(Params):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class A2CParams(PGParams):
|
class A2CParams(PGParams, ParamsMixinLearningRateWithScheduler):
|
||||||
vf_coef: float = 0.5
|
vf_coef: float = 0.5
|
||||||
ent_coef: float = 0.01
|
ent_coef: float = 0.01
|
||||||
max_grad_norm: float | None = None
|
max_grad_norm: float | None = None
|
||||||
gae_lambda: float = 0.95
|
gae_lambda: float = 0.95
|
||||||
max_batchsize: int = 256
|
max_batchsize: int = 256
|
||||||
|
dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default"
|
||||||
|
|
||||||
|
def _get_param_transformers(self) -> list[ParamTransformer]:
|
||||||
|
transformers = super()._get_param_transformers()
|
||||||
|
transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self))
|
||||||
|
transformers.append(ParamTransformerDistributionFunction("dist_fn"))
|
||||||
|
return transformers
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PPOParams(A2CParams, ParamsMixinLearningRateWithScheduler):
|
class PPOParams(A2CParams):
|
||||||
"""PPO specific config."""
|
"""PPO specific config."""
|
||||||
|
|
||||||
eps_clip: float = 0.2
|
eps_clip: float = 0.2
|
||||||
|
Loading…
x
Reference in New Issue
Block a user