Support PG/Reinforce in high-level API
* Add example mujoco_reinforce_hl
* Extended functionality of ActorFactory to support creation of ModuleOpt
parent 4e93c12afa
commit 6bb3abb2f0
@@ -42,7 +42,7 @@ def main(
     max_grad_norm: float = 0.5,
 ):
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
-    log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
+    log_name = os.path.join(task, "a2c", str(experiment_config.seed), now)
 
     sampling_config = SamplingConfig(
         num_epochs=epoch,
@@ -75,7 +75,7 @@ def main(
             ),
         )
         .with_optim_factory(OptimizerFactoryRMSprop(eps=1e-5, alpha=0.99))
-        .with_actor_factory_default(hidden_sizes)
+        .with_actor_factory_default(hidden_sizes, continuous_unbounded=True)
         .with_critic_factory_default(hidden_sizes)
         .build()
     )
@@ -87,7 +87,7 @@ def main(
                 dist_fn=DistributionFunctionFactoryIndependentGaussians(),
             ),
         )
-        .with_actor_factory_default(hidden_sizes)
+        .with_actor_factory_default(hidden_sizes, continuous_unbounded=True)
         .with_critic_factory_default(hidden_sizes)
         .build()
     )
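Note on the hunks above: continuous_unbounded=True makes the default continuous actor emit an unbounded Gaussian mean, leaving the bounding of actions to the policy's action_bound_method. A rough sketch of what the builder call resolves to, assuming the ActorFactoryDefault / ContinuousActorType names from tianshou.highlevel.module.actor (constructor signature assumed, not taken from this commit):

    from tianshou.highlevel.module.actor import ActorFactoryDefault, ContinuousActorType

    # Roughly what .with_actor_factory_default(hidden_sizes, continuous_unbounded=True)
    # selects for a continuous action space:
    actor_factory = ActorFactoryDefault(
        ContinuousActorType.GAUSSIAN,
        hidden_sizes=(64, 64),
        continuous_unbounded=True,  # do not squash the Gaussian mean through tanh
    )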
examples/mujoco/mujoco_reinforce_hl.py (new file, 74 lines)
@@ -0,0 +1,74 @@
+#!/usr/bin/env python3
+
+import os
+from collections.abc import Sequence
+from typing import Literal
+
+from jsonargparse import CLI
+
+from examples.mujoco.mujoco_env import MujocoEnvFactory
+from tianshou.highlevel.config import SamplingConfig
+from tianshou.highlevel.experiment import (
+    ExperimentConfig,
+    PGExperimentBuilder,
+)
+from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
+from tianshou.highlevel.params.policy_params import PGParams
+from tianshou.utils import logging
+from tianshou.utils.logging import datetime_tag
+
+
+def main(
+    experiment_config: ExperimentConfig,
+    task: str = "Ant-v3",
+    buffer_size: int = 4096,
+    hidden_sizes: Sequence[int] = (64, 64),
+    lr: float = 1e-3,
+    gamma: float = 0.99,
+    epoch: int = 100,
+    step_per_epoch: int = 30000,
+    step_per_collect: int = 2048,
+    repeat_per_collect: int = 1,
+    batch_size: int = 99999,
+    training_num: int = 64,
+    test_num: int = 10,
+    rew_norm: bool = True,
+    action_bound_method: Literal["clip", "tanh"] = "tanh",
+    lr_decay: bool = True,
+):
+    log_name = os.path.join(task, "reinforce", str(experiment_config.seed), datetime_tag())
+
+    sampling_config = SamplingConfig(
+        num_epochs=epoch,
+        step_per_epoch=step_per_epoch,
+        batch_size=batch_size,
+        num_train_envs=training_num,
+        num_test_envs=test_num,
+        buffer_size=buffer_size,
+        step_per_collect=step_per_collect,
+        repeat_per_collect=repeat_per_collect,
+    )
+
+    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
+
+    experiment = (
+        PGExperimentBuilder(env_factory, experiment_config, sampling_config)
+        .with_pg_params(
+            PGParams(
+                discount_factor=gamma,
+                action_bound_method=action_bound_method,
+                reward_normalization=rew_norm,
+                lr=lr,
+                lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
+                if lr_decay
+                else None,
+            ),
+        )
+        .with_actor_factory_default(hidden_sizes, continuous_unbounded=True)
+        .build()
+    )
+    experiment.run(log_name)
+
+
+if __name__ == "__main__":
+    logging.run_main(lambda: CLI(main))
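The example is a thin jsonargparse CLI around main, so it can also be driven programmatically; a minimal smoke-run sketch (assuming ExperimentConfig's remaining defaults are usable as-is):

    from examples.mujoco.mujoco_reinforce_hl import main
    from tianshou.highlevel.experiment import ExperimentConfig

    # Two epochs on HalfCheetah; every other parameter keeps the defaults above.
    main(ExperimentConfig(seed=42), task="HalfCheetah-v3", epoch=2)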
@@ -12,13 +12,12 @@ from tianshou.highlevel.env import Environments
 from tianshou.highlevel.logger import Logger
 from tianshou.highlevel.module.actor import (
     ActorFactory,
+    ActorModuleOptFactory,
 )
 from tianshou.highlevel.module.core import TDevice
-from tianshou.highlevel.module.critic import CriticFactory
+from tianshou.highlevel.module.critic import CriticFactory, CriticModuleOptFactory
 from tianshou.highlevel.module.module_opt import (
     ActorCriticModuleOpt,
-    ActorModuleOptFactory,
-    CriticModuleOptFactory,
     ModuleOpt,
 )
 from tianshou.highlevel.optim import OptimizerFactory
@@ -28,6 +27,7 @@ from tianshou.highlevel.params.policy_params import (
     DQNParams,
     Params,
     ParamTransformerData,
+    PGParams,
     PPOParams,
     SACParams,
     TD3Params,
@@ -39,6 +39,7 @@ from tianshou.policy import (
     BasePolicy,
     DDPGPolicy,
     DQNPolicy,
+    PGPolicy,
     PPOPolicy,
     SACPolicy,
     TD3Policy,
@@ -355,6 +356,41 @@ class _ActorAndDualCriticsMixin(_ActorAndCriticMixin):
         return self.critic2_module_opt_factory.create_module_opt(envs, device, lr)
 
 
+class PGAgentFactory(OnpolicyAgentFactory, _ActorMixin):
+    def __init__(
+        self,
+        params: PGParams,
+        sampling_config: SamplingConfig,
+        actor_factory: ActorFactory,
+        optim_factory: OptimizerFactory,
+    ):
+        super().__init__(sampling_config, optim_factory)
+        _ActorMixin.__init__(self, actor_factory, optim_factory)
+        self.params = params
+        self.actor_factory = actor_factory
+        self.optim_factory = optim_factory
+
+    def _create_policy(self, envs: Environments, device: TDevice) -> PGPolicy:
+        actor = self.actor_factory.create_module_opt(
+            envs, device, self.optim_factory, self.params.lr,
+        )
+        kwargs = self.params.create_kwargs(
+            ParamTransformerData(
+                envs=envs,
+                device=device,
+                optim=actor.optim,
+                optim_factory=self.optim_factory,
+            ),
+        )
+        return PGPolicy(
+            actor=actor.module,
+            optim=actor.optim,
+            action_space=envs.get_action_space(),
+            observation_space=envs.get_observation_space(),
+            **kwargs,
+        )
+
+
 class ActorCriticAgentFactory(
     Generic[TParams, TPolicy],
     OnpolicyAgentFactory,
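PGAgentFactory follows the pattern of the existing factories: it builds the actor together with its optimizer via the new create_module_opt, then hands the transformed parameters to PGPolicy. A hedged sketch of the flow (in practice envs and device come from the experiment machinery, and _create_policy is an internal hook; sampling_config, actor_factory, and optim_factory are assumed to be constructed elsewhere):

    factory = PGAgentFactory(
        PGParams(lr=1e-3),
        sampling_config,
        actor_factory,   # any ActorFactory
        optim_factory,   # any OptimizerFactory
    )
    policy = factory._create_policy(envs, "cpu")  # -> tianshou PGPolicy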
@@ -14,6 +14,7 @@ from tianshou.highlevel.agent import (
     AgentFactory,
     DDPGAgentFactory,
     DQNAgentFactory,
+    PGAgentFactory,
     PPOAgentFactory,
     SACAgentFactory,
     TD3AgentFactory,
@@ -32,6 +33,7 @@ from tianshou.highlevel.params.policy_params import (
     A2CParams,
     DDPGParams,
     DQNParams,
+    PGParams,
     PPOParams,
     SACParams,
     TD3Params,
@@ -280,7 +282,9 @@ class _BuilderMixinActorFactory:
 
 
 class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
-    """Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""
+    """Specialization of the actor mixin where, in the continuous case, the actor component outputs
+    Gaussian distribution parameters.
+    """
 
     def __init__(self) -> None:
         super().__init__(ContinuousActorType.GAUSSIAN)
@@ -395,6 +399,34 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
         return self
 
 
+class PGExperimentBuilder(
+    ExperimentBuilder,
+    _BuilderMixinActorFactory_ContinuousGaussian,
+):
+    def __init__(
+        self,
+        env_factory: EnvFactory,
+        experiment_config: ExperimentConfig | None = None,
+        sampling_config: SamplingConfig | None = None,
+    ):
+        super().__init__(env_factory, experiment_config, sampling_config)
+        _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
+        self._params: PGParams = PGParams()
+        self._env_config = None
+
+    def with_pg_params(self, params: PGParams) -> Self:
+        self._params = params
+        return self
+
+    def _create_agent_factory(self) -> AgentFactory:
+        return PGAgentFactory(
+            self._params,
+            self._sampling_config,
+            self._get_actor_factory(),
+            self._get_optim_factory(),
+        )
+
+
 class A2CExperimentBuilder(
     ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,
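With this builder in place, a REINFORCE experiment needs only an actor and no critic factory; the minimal chain mirrors the mujoco_reinforce_hl example above. A sketch (env_factory, experiment_config, and sampling_config as in that example):

    experiment = (
        PGExperimentBuilder(env_factory, experiment_config, sampling_config)
        .with_pg_params(PGParams(lr=1e-3))
        .with_actor_factory_default((64, 64), continuous_unbounded=True)
        .build()
    )
    experiment.run("reinforce-demo")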
@@ -7,6 +7,8 @@ from torch import nn
 
 from tianshou.highlevel.env import Environments, EnvType
 from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
+from tianshou.highlevel.module.module_opt import ModuleOpt
+from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.utils.net import continuous, discrete
 from tianshou.utils.net.common import BaseActor, Net
 from tianshou.utils.string import ToStringMixin
@@ -23,6 +25,21 @@ class ActorFactory(ToStringMixin, ABC):
     def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
         pass
 
+    def create_module_opt(
+        self, envs: Environments, device: TDevice, optim_factory: OptimizerFactory, lr: float,
+    ) -> ModuleOpt:
+        """Creates the actor module along with its optimizer for the given learning rate.
+
+        :param envs: the environments
+        :param device: the torch device
+        :param optim_factory: the optimizer factory
+        :param lr: the learning rate
+        :return: a container with the actor module and its optimizer
+        """
+        module = self.create_module(envs, device)
+        optim = optim_factory.create_optimizer(module, lr)
+        return ModuleOpt(module, optim)
+
     @staticmethod
     def _init_linear(actor: torch.nn.Module) -> None:
         """Initializes linear layers of an actor module using default mechanisms.
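Every ActorFactory subclass thus inherits a working create_module_opt: create the module, then let the optimizer factory bind an optimizer to its parameters. A hedged call-site sketch (OptimizerFactoryAdam is an assumed name, by analogy with the OptimizerFactoryRMSprop used in the a2c example; envs and actor_factory are assumed to exist):

    from tianshou.highlevel.optim import OptimizerFactoryAdam  # name assumed

    module_opt = actor_factory.create_module_opt(envs, "cpu", OptimizerFactoryAdam(), lr=1e-3)
    module_opt.module  # the actor network (torch.nn.Module)
    module_opt.optim   # the torch optimizer bound to its parameters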
@@ -154,3 +171,14 @@ class ActorFactoryDiscreteNet(ActorFactory):
             hidden_sizes=(),
             device=device,
         ).to(device)
+
+
+class ActorModuleOptFactory(ToStringMixin):
+    def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
+        self.actor_factory = actor_factory
+        self.optim_factory = optim_factory
+
+    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
+        actor = self.actor_factory.create_module(envs, device)
+        opt = self.optim_factory.create_optimizer(actor, lr)
+        return ModuleOpt(actor, opt)
@@ -5,6 +5,8 @@ from torch import nn
 
 from tianshou.highlevel.env import Environments, EnvType
 from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
+from tianshou.highlevel.module.module_opt import ModuleOpt
+from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.utils.net import continuous, discrete
 from tianshou.utils.net.common import Net
 from tianshou.utils.string import ToStringMixin
@@ -78,3 +80,20 @@ class CriticFactoryDiscreteNet(CriticFactory):
         critic = discrete.Critic(net_c, device=device).to(device)
         init_linear_orthogonal(critic)
         return critic
+
+
+class CriticModuleOptFactory(ToStringMixin):
+    def __init__(
+        self,
+        critic_factory: CriticFactory,
+        optim_factory: OptimizerFactory,
+        use_action: bool,
+    ):
+        self.critic_factory = critic_factory
+        self.optim_factory = optim_factory
+        self.use_action = use_action
+
+    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
+        critic = self.critic_factory.create_module(envs, device, self.use_action)
+        opt = self.optim_factory.create_optimizer(critic, lr)
+        return ModuleOpt(critic, opt)
@@ -2,13 +2,7 @@ from dataclasses import dataclass
 
 import torch
 
-from tianshou.highlevel.env import Environments
-from tianshou.highlevel.module.actor import ActorFactory
-from tianshou.highlevel.module.core import TDevice
-from tianshou.highlevel.module.critic import CriticFactory
-from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.utils.net.common import ActorCritic
-from tianshou.utils.string import ToStringMixin
 
 
 @dataclass
@@ -29,31 +23,3 @@ class ActorCriticModuleOpt:
     @property
     def critic(self) -> torch.nn.Module:
         return self.actor_critic_module.critic
-
-
-class ActorModuleOptFactory(ToStringMixin):
-    def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
-        self.actor_factory = actor_factory
-        self.optim_factory = optim_factory
-
-    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
-        actor = self.actor_factory.create_module(envs, device)
-        opt = self.optim_factory.create_optimizer(actor, lr)
-        return ModuleOpt(actor, opt)
-
-
-class CriticModuleOptFactory(ToStringMixin):
-    def __init__(
-        self,
-        critic_factory: CriticFactory,
-        optim_factory: OptimizerFactory,
-        use_action: bool,
-    ):
-        self.critic_factory = critic_factory
-        self.optim_factory = optim_factory
-        self.use_action = use_action
-
-    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
-        critic = self.critic_factory.create_module(envs, device, self.use_action)
-        opt = self.optim_factory.create_optimizer(critic, lr)
-        return ModuleOpt(critic, opt)
@@ -15,10 +15,10 @@ from tianshou.highlevel.params.dist_fn import (
     DistributionFunctionFactory,
     DistributionFunctionFactoryDefault,
 )
-from tianshou.policy.modelfree.pg import TDistributionFunction
 from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
 from tianshou.highlevel.params.noise import NoiseFactory
+from tianshou.policy.modelfree.pg import TDistributionFunction
 from tianshou.utils import MultipleLRSchedulers
 
 
@@ -277,42 +277,34 @@ class ParamsMixinActorAndCritic(GetParamTransformersProtocol):
 
 
 @dataclass
-class PGParams(Params):
-    """Config of general policy-gradient algorithms."""
-
+class PGParams(Params, ParamsMixinLearningRateWithScheduler):
     discount_factor: float = 0.99
     reward_normalization: bool = False
     deterministic_eval: bool = False
     action_scaling: bool | Literal["default"] = "default"
     """whether to apply action scaling; when set to "default", it will be enabled for continuous action spaces"""
     action_bound_method: Literal["clip", "tanh"] | None = "clip"
+    dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default"
 
     def _get_param_transformers(self) -> list[ParamTransformer]:
         transformers = super()._get_param_transformers()
+        transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self))
         transformers.append(ParamTransformerActionScaling("action_scaling"))
+        transformers.append(ParamTransformerDistributionFunction("dist_fn"))
         return transformers
 
 
 @dataclass
-class A2CParams(PGParams, ParamsMixinLearningRateWithScheduler):
+class A2CParams(PGParams):
     vf_coef: float = 0.5
     ent_coef: float = 0.01
     max_grad_norm: float | None = None
     gae_lambda: float = 0.95
     max_batchsize: int = 256
-    dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default"
-
-    def _get_param_transformers(self) -> list[ParamTransformer]:
-        transformers = super()._get_param_transformers()
-        transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self))
-        transformers.append(ParamTransformerDistributionFunction("dist_fn"))
-        return transformers
 
 
 @dataclass
 class PPOParams(A2CParams):
-    """PPO specific config."""
-
     eps_clip: float = 0.2
     dual_clip: float | None = None
     value_clip: bool = False
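Net effect of this refactoring: the learning-rate/scheduler mixin and dist_fn move up from A2CParams into PGParams, so plain policy-gradient configs stand on their own while A2CParams and PPOParams inherit those fields unchanged. A sketch:

    # PGParams now accepts the optimizer-related fields directly:
    pg = PGParams(lr=1e-3, discount_factor=0.99)
    # A2CParams keeps only its A2C-specific fields and inherits the rest:
    a2c = A2CParams(lr=7e-4, vf_coef=0.5, gae_lambda=0.95)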
@@ -1,4 +1,3 @@
-from collections.abc import Callable
 from typing import Any, Literal
 
 import gymnasium as gym
@@ -1,4 +1,3 @@
-from collections.abc import Callable
 from typing import Any, Literal, cast
 
 import gymnasium as gym
@@ -1,5 +1,6 @@
 import warnings
-from typing import Any, Literal, cast, TypeAlias, Callable
+from collections.abc import Callable
+from typing import Any, Literal, TypeAlias, cast
 
 import gymnasium as gym
 import numpy as np