Support NPG in high-level API and add example mujoco_npg_hl
This commit is contained in:
parent
73a6d15eee
commit
383a4a6083
87
examples/mujoco/mujoco_npg_hl.py
Normal file
87
examples/mujoco/mujoco_npg_hl.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import os
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from jsonargparse import CLI
|
||||||
|
|
||||||
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
|
from tianshou.highlevel.experiment import (
|
||||||
|
ExperimentConfig,
|
||||||
|
NPGExperimentBuilder,
|
||||||
|
)
|
||||||
|
from tianshou.highlevel.params.dist_fn import (
|
||||||
|
DistributionFunctionFactoryIndependentGaussians,
|
||||||
|
)
|
||||||
|
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
||||||
|
from tianshou.highlevel.params.policy_params import NPGParams
|
||||||
|
from tianshou.utils import logging
|
||||||
|
from tianshou.utils.logging import datetime_tag
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
experiment_config: ExperimentConfig,
|
||||||
|
task: str = "Ant-v3",
|
||||||
|
buffer_size: int = 4096,
|
||||||
|
hidden_sizes: Sequence[int] = (64, 64),
|
||||||
|
lr: float = 1e-3,
|
||||||
|
gamma: float = 0.99,
|
||||||
|
epoch: int = 100,
|
||||||
|
step_per_epoch: int = 30000,
|
||||||
|
step_per_collect: int = 1024,
|
||||||
|
repeat_per_collect: int = 1,
|
||||||
|
batch_size: int = 99999,
|
||||||
|
training_num: int = 16,
|
||||||
|
test_num: int = 10,
|
||||||
|
rew_norm: bool = True,
|
||||||
|
gae_lambda: float = 0.95,
|
||||||
|
bound_action_method: Literal["clip", "tanh"] = "clip",
|
||||||
|
lr_decay: bool = True,
|
||||||
|
norm_adv: bool = True,
|
||||||
|
optim_critic_iters: int = 20,
|
||||||
|
actor_step_size: float = 0.1,
|
||||||
|
):
|
||||||
|
log_name = os.path.join(task, "npg", str(experiment_config.seed), datetime_tag())
|
||||||
|
|
||||||
|
sampling_config = SamplingConfig(
|
||||||
|
num_epochs=epoch,
|
||||||
|
step_per_epoch=step_per_epoch,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_train_envs=training_num,
|
||||||
|
num_test_envs=test_num,
|
||||||
|
buffer_size=buffer_size,
|
||||||
|
step_per_collect=step_per_collect,
|
||||||
|
repeat_per_collect=repeat_per_collect,
|
||||||
|
)
|
||||||
|
|
||||||
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||||
|
|
||||||
|
experiment = (
|
||||||
|
NPGExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
|
.with_npg_params(
|
||||||
|
NPGParams(
|
||||||
|
discount_factor=gamma,
|
||||||
|
gae_lambda=gae_lambda,
|
||||||
|
action_bound_method=bound_action_method,
|
||||||
|
reward_normalization=rew_norm,
|
||||||
|
advantage_normalization=norm_adv,
|
||||||
|
optim_critic_iters=optim_critic_iters,
|
||||||
|
actor_step_size=actor_step_size,
|
||||||
|
lr=lr,
|
||||||
|
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
|
||||||
|
if lr_decay
|
||||||
|
else None,
|
||||||
|
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.with_actor_factory_default(hidden_sizes, continuous_unbounded=True)
|
||||||
|
.with_critic_factory_default(hidden_sizes)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
experiment.run(log_name)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging.run_main(lambda: CLI(main))
|
@ -23,6 +23,7 @@ from tianshou.highlevel.params.policy_params import (
|
|||||||
A2CParams,
|
A2CParams,
|
||||||
DDPGParams,
|
DDPGParams,
|
||||||
DQNParams,
|
DQNParams,
|
||||||
|
NPGParams,
|
||||||
Params,
|
Params,
|
||||||
ParamTransformerData,
|
ParamTransformerData,
|
||||||
PGParams,
|
PGParams,
|
||||||
@ -37,6 +38,7 @@ from tianshou.policy import (
|
|||||||
BasePolicy,
|
BasePolicy,
|
||||||
DDPGPolicy,
|
DDPGPolicy,
|
||||||
DQNPolicy,
|
DQNPolicy,
|
||||||
|
NPGPolicy,
|
||||||
PGPolicy,
|
PGPolicy,
|
||||||
PPOPolicy,
|
PPOPolicy,
|
||||||
SACPolicy,
|
SACPolicy,
|
||||||
@ -429,6 +431,30 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
|
|||||||
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
||||||
|
|
||||||
|
|
||||||
|
class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
params: NPGParams,
|
||||||
|
sampling_config: SamplingConfig,
|
||||||
|
actor_factory: ActorFactory,
|
||||||
|
critic_factory: CriticFactory,
|
||||||
|
optimizer_factory: OptimizerFactory,
|
||||||
|
critic_use_actor_module: bool,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
params,
|
||||||
|
sampling_config,
|
||||||
|
actor_factory,
|
||||||
|
critic_factory,
|
||||||
|
optimizer_factory,
|
||||||
|
NPGPolicy,
|
||||||
|
critic_use_actor_module,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
|
||||||
|
return self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
||||||
|
|
||||||
|
|
||||||
class DQNAgentFactory(OffpolicyAgentFactory):
|
class DQNAgentFactory(OffpolicyAgentFactory):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -14,6 +14,7 @@ from tianshou.highlevel.agent import (
|
|||||||
AgentFactory,
|
AgentFactory,
|
||||||
DDPGAgentFactory,
|
DDPGAgentFactory,
|
||||||
DQNAgentFactory,
|
DQNAgentFactory,
|
||||||
|
NPGAgentFactory,
|
||||||
PGAgentFactory,
|
PGAgentFactory,
|
||||||
PPOAgentFactory,
|
PPOAgentFactory,
|
||||||
SACAgentFactory,
|
SACAgentFactory,
|
||||||
@ -33,6 +34,7 @@ from tianshou.highlevel.params.policy_params import (
|
|||||||
A2CParams,
|
A2CParams,
|
||||||
DDPGParams,
|
DDPGParams,
|
||||||
DQNParams,
|
DQNParams,
|
||||||
|
NPGParams,
|
||||||
PGParams,
|
PGParams,
|
||||||
PPOParams,
|
PPOParams,
|
||||||
SACParams,
|
SACParams,
|
||||||
@ -494,6 +496,38 @@ class PPOExperimentBuilder(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NPGExperimentBuilder(
|
||||||
|
ExperimentBuilder,
|
||||||
|
_BuilderMixinActorFactory_ContinuousGaussian,
|
||||||
|
_BuilderMixinSingleCriticCanUseActorFactory,
|
||||||
|
):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
env_factory: EnvFactory,
|
||||||
|
experiment_config: ExperimentConfig | None = None,
|
||||||
|
sampling_config: SamplingConfig | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(env_factory, experiment_config, sampling_config)
|
||||||
|
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||||
|
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
||||||
|
self._params: NPGParams = NPGParams()
|
||||||
|
|
||||||
|
def with_npg_params(self, params: NPGParams) -> Self:
|
||||||
|
self._params = params
|
||||||
|
return self
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _create_agent_factory(self) -> AgentFactory:
|
||||||
|
return NPGAgentFactory(
|
||||||
|
self._params,
|
||||||
|
self._sampling_config,
|
||||||
|
self._get_actor_factory(),
|
||||||
|
self._get_critic_factory(0),
|
||||||
|
self._get_optim_factory(),
|
||||||
|
self._critic_use_actor_module,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DQNExperimentBuilder(
|
class DQNExperimentBuilder(
|
||||||
ExperimentBuilder,
|
ExperimentBuilder,
|
||||||
_BuilderMixinActorFactory,
|
_BuilderMixinActorFactory,
|
||||||
|
@ -312,6 +312,15 @@ class PPOParams(A2CParams):
|
|||||||
recompute_advantage: bool = False
|
recompute_advantage: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NPGParams(PGParams):
|
||||||
|
optim_critic_iters: int = 5
|
||||||
|
actor_step_size: float = 0.5
|
||||||
|
advantage_normalization: bool = True
|
||||||
|
gae_lambda: float = 0.95
|
||||||
|
max_batchsize: int = 256
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol):
|
class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol):
|
||||||
actor_lr: float = 1e-3
|
actor_lr: float = 1e-3
|
||||||
|
Loading…
x
Reference in New Issue
Block a user