Tianshou/examples/mujoco/mujoco_ppo_hl.py
Dominik Jain 78b6dd1f49 Adapt class naming scheme
* Use prefix convention (subclasses have superclass names as prefix) to
  facilitate discoverability of relevant classes via IDE autocompletion
* Use dual naming, adding an alternative concise name that omits the
  precise OO semantics and retains only the essential part of the name
  (which can be more pleasing to users not accustomed to
  convoluted OO naming)
2023-10-18 20:44:16 +02:00

98 lines
2.9 KiB
Python

#!/usr/bin/env python3
import datetime
import os
from collections.abc import Sequence
from typing import Literal
from jsonargparse import CLI
from torch.distributions import Independent, Normal
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
PPOExperimentBuilder,
RLExperimentConfig,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
def main(
experiment_config: RLExperimentConfig,
task: str = "Ant-v4",
buffer_size: int = 4096,
hidden_sizes: Sequence[int] = (64, 64),
lr: float = 3e-4,
gamma: float = 0.99,
epoch: int = 100,
step_per_epoch: int = 30000,
step_per_collect: int = 2048,
repeat_per_collect: int = 10,
batch_size: int = 64,
training_num: int = 64,
test_num: int = 10,
rew_norm: bool = True,
vf_coef: float = 0.25,
ent_coef: float = 0.0,
gae_lambda: float = 0.95,
bound_action_method: Literal["clip", "tanh"] | None = "clip",
lr_decay: bool = True,
max_grad_norm: float = 0.5,
eps_clip: float = 0.2,
dual_clip: float | None = None,
value_clip: bool = False,
norm_adv: bool = False,
recompute_adv: bool = True,
):
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
sampling_config = RLSamplingConfig(
num_epochs=epoch,
step_per_epoch=step_per_epoch,
batch_size=batch_size,
num_train_envs=training_num,
num_test_envs=test_num,
buffer_size=buffer_size,
step_per_collect=step_per_collect,
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
def dist_fn(*logits):
return Independent(Normal(*logits), 1)
experiment = (
PPOExperimentBuilder(experiment_config, env_factory, sampling_config, dist_fn)
.with_params(
PPOParams(
discount_factor=gamma,
gae_lambda=gae_lambda,
action_bound_method=bound_action_method,
reward_normalization=rew_norm,
ent_coef=ent_coef,
vf_coef=vf_coef,
max_grad_norm=max_grad_norm,
value_clip=value_clip,
advantage_normalization=norm_adv,
eps_clip=eps_clip,
dual_clip=dual_clip,
recompute_advantage=recompute_adv,
lr=lr,
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
if lr_decay
else None,
),
)
.with_actor_factory_default(hidden_sizes)
.with_critic_factory_default(hidden_sizes)
.build()
)
experiment.run(log_name)
if __name__ == "__main__":
CLI(main)