add mujoco example with multiple runs and performance plots

parent 5762d2c2e0
commit 6c1bd85521

examples/mujoco/mujoco_ppo_hl_multi.py (new file, 114 lines)

@@ -0,0 +1,114 @@
#!/usr/bin/env python3

import os
from collections.abc import Sequence
from functools import partial
from typing import Literal

import torch

from examples.mujoco.mujoco_env import MujocoEnvFactory
from examples.mujoco.tools import eval_results, RLiableExperimentResult
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
    ExperimentConfig,
    PPOExperimentBuilder,
)
from tianshou.highlevel.params.dist_fn import (
    DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
    experiment_config: ExperimentConfig,
    task: str = "Ant-v4",
    num_experiments: int = 5,
    buffer_size: int = 4096,
    hidden_sizes: Sequence[int] = (64, 64),
    lr: float = 3e-4,
    gamma: float = 0.99,
    epoch: int = 100,
    step_per_epoch: int = 30000,
    step_per_collect: int = 2048,
    repeat_per_collect: int = 10,
    batch_size: int = 64,
    training_num: int = 10,
    test_num: int = 10,
    rew_norm: bool = True,
    vf_coef: float = 0.25,
    ent_coef: float = 0.0,
    gae_lambda: float = 0.95,
    bound_action_method: Literal["clip", "tanh"] | None = "clip",
    lr_decay: bool = True,
    max_grad_norm: float = 0.5,
    eps_clip: float = 0.2,
    dual_clip: float | None = None,
    value_clip: bool = False,
    norm_adv: bool = False,
    recompute_adv: bool = True,
) -> str:
    log_name = os.path.join("log", task, "ppo", datetime_tag())
    experiment_config.persistence_base_dir = log_name

    sampling_config = SamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        batch_size=batch_size,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        step_per_collect=step_per_collect,
        repeat_per_collect=repeat_per_collect,
    )

    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)

    experiments = (
        PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
        .with_ppo_params(
            PPOParams(
                discount_factor=gamma,
                gae_lambda=gae_lambda,
                action_bound_method=bound_action_method,
                reward_normalization=rew_norm,
                ent_coef=ent_coef,
                vf_coef=vf_coef,
                max_grad_norm=max_grad_norm,
                value_clip=value_clip,
                advantage_normalization=norm_adv,
                eps_clip=eps_clip,
                dual_clip=dual_clip,
                recompute_advantage=recompute_adv,
                lr=lr,
                lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
                if lr_decay
                else None,
                dist_fn=DistributionFunctionFactoryIndependentGaussians(),
            ),
        )
        .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
        .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
        .build_default_seeded_experiments(num_experiments)
    )

    for experiment_name, experiment in experiments.items():
        experiment.run(experiment_name)

    return log_name


def eval_experiments(log_dir: str):
    results = RLiableExperimentResult.load_from_disk(log_dir, 'PPO', None)
    eval_results(results)


if __name__ == "__main__":
    # logging.run_cli(main)
    experiment_config = ExperimentConfig(watch=False)
    log_dir = logging.run_main(partial(main, experiment_config, epoch=2))
    # log_dir = <path/to/exp>
    eval_experiments(log_dir)
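The performance plots named in the commit title come from eval_experiments, which aggregates the test returns of all seeded runs with the interquartile mean (IQM) before plotting. A minimal, self-contained sketch of that aggregate follows (synthetic return values, numpy and scipy only; the actual curve and its confidence intervals are produced by rliable in the tools module below):

import numpy as np
from scipy import stats as sst

# Synthetic stand-in for RLiableExperimentResult.score_dict["PPO"]:
# one row per seeded run, one column per logged test point (n_epochs + 1).
scores = np.array([
    [0.0, 1200.0, 2100.0],
    [0.0, 900.0, 2500.0],
    [0.0, 1500.0, 1800.0],
    [0.0, 1100.0, 2300.0],
])

# Interquartile mean over runs at each test point: discard the top and
# bottom 25% of runs and average the rest -- the same statistic that
# eval_results passes to rly.get_interval_estimates.
iqm_per_checkpoint = sst.trim_mean(scores, proportiontocut=0.25, axis=0)
print(iqm_per_checkpoint)  # one IQM value per logged test point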
examples/mujoco/tools.py

@@ -5,11 +5,91 @@ import csv
import os
import re
from collections import defaultdict
from dataclasses import dataclass, asdict

import numpy as np
import tqdm
from tensorboard.backend.event_processing import event_accumulator

from tianshou.highlevel.experiment import Experiment


@dataclass
class RLiableExperimentResult:
    exp_dir: str
    algorithms: list[str]
    score_dict: dict[str, np.ndarray]  # (n_runs x n_epochs + 1)
    env_steps: np.ndarray  # (n_epochs + 1)
    score_thresholds: np.ndarray

    @staticmethod
    def load_from_disk(exp_dir: str, algo_name: str, score_thresholds: np.ndarray | None):
        """Load the experiment result from disk.

        :param exp_dir: The directory from where the experiment results are restored.
        :param algo_name: The name of the algorithm used in the figure legend.
        :param score_thresholds: The thresholds used to create the performance profile.
            If None, it will be created from the test episode returns.
        """
        test_episode_returns = []

        for entry in os.scandir(exp_dir):
            if entry.name.startswith('.'):
                continue

            exp = Experiment.from_directory(entry.path)
            logger = exp.logger_factory.create_logger(entry.path, entry.name, None, asdict(exp.config))
            data = logger.restore_logged_data(entry.path)

            test_data = data['test']

            test_episode_returns.append(test_data['returns_stat']['mean'])
            env_step = test_data['env_step']

        if score_thresholds is None:
            score_thresholds = np.linspace(0.0, np.max(test_episode_returns), 101)

        return RLiableExperimentResult(algorithms=[algo_name],
                                       score_dict={algo_name: np.array(test_episode_returns)},
                                       env_steps=np.array(env_step),
                                       score_thresholds=score_thresholds,
                                       exp_dir=exp_dir)


def eval_results(results: RLiableExperimentResult):
    import matplotlib.pyplot as plt
    import scipy.stats as sst
    import seaborn as sns
    from rliable import library as rly
    from rliable import plot_utils

    iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
    iqm_scores, iqm_cis = rly.get_interval_estimates(
        results.score_dict, iqm, reps=50000)

    # Plot IQM sample efficiency curve
    fig, ax = plt.subplots(ncols=1, figsize=(7, 5))
    plot_utils.plot_sample_efficiency_curve(
        results.env_steps, iqm_scores, iqm_cis, algorithms=results.algorithms,
        xlabel=r'Number of env steps',
        ylabel='IQM episode return',
        ax=ax)
    plt.savefig(os.path.join(results.exp_dir, 'iqm_sample_efficiency_curve.png'))

    final_score_dict = {algo: returns[:, [-1]] for algo, returns in results.score_dict.items()}
    score_distributions, score_distributions_cis = rly.create_performance_profile(
        final_score_dict, results.score_thresholds)

    # Plot score distributions
    fig, ax = plt.subplots(ncols=1, figsize=(7, 5))
    plot_utils.plot_performance_profiles(
        score_distributions, results.score_thresholds,
        performance_profile_cis=score_distributions_cis,
        colors=dict(zip(results.algorithms, sns.color_palette('colorblind'))),
        xlabel=r'Episode return $(\tau)$',
        ax=ax)
    plt.savefig(os.path.join(results.exp_dir, 'performance_profile.png'))


def find_all_files(root_dir, pattern):
    """Find all files under root_dir according to relative pattern."""
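For intuition on the second plot: a performance profile reports, for every threshold tau in score_thresholds, the fraction of runs whose final return exceeds tau. A rough sketch of that quantity with hypothetical final scores and plain numpy (rliable's create_performance_profile additionally adds bootstrap confidence bands):

import numpy as np

# Hypothetical final returns of five seeded runs, shaped (n_runs, 1) like
# the entries of final_score_dict in eval_results above.
final_scores = np.array([[3000.0], [3500.0], [2800.0], [4100.0], [3300.0]])
score_thresholds = np.linspace(0.0, 5000.0, 101)

# Empirical performance profile: fraction of runs scoring above each threshold.
profile = (final_scores[..., None] > score_thresholds).mean(axis=(0, 1))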