Add MuJoCo example with multiple runs and performance plots

This commit is contained in:
Maximilian Huettenrauch 2024-03-12 11:44:48 +01:00
parent 5762d2c2e0
commit 6c1bd85521
2 changed files with 194 additions and 0 deletions

View File

@ -0,0 +1,114 @@
#!/usr/bin/env python3
import os
from collections.abc import Sequence
from functools import partial
from typing import Literal
import torch
from examples.mujoco.mujoco_env import MujocoEnvFactory
from examples.mujoco.tools import eval_results, RLiableExperimentResult
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
ExperimentConfig,
PPOExperimentBuilder,
)
from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag
def main(
    experiment_config: ExperimentConfig,
    task: str = "Ant-v4",
    num_experiments: int = 5,
    buffer_size: int = 4096,
    hidden_sizes: Sequence[int] = (64, 64),
    lr: float = 3e-4,
    gamma: float = 0.99,
    epoch: int = 100,
    step_per_epoch: int = 30000,
    step_per_collect: int = 2048,
    repeat_per_collect: int = 10,
    batch_size: int = 64,
    training_num: int = 10,
    test_num: int = 10,
    rew_norm: bool = True,
    vf_coef: float = 0.25,
    ent_coef: float = 0.0,
    gae_lambda: float = 0.95,
    bound_action_method: Literal["clip", "tanh"] | None = "clip",
    lr_decay: bool = True,
    max_grad_norm: float = 0.5,
    eps_clip: float = 0.2,
    dual_clip: float | None = None,
    value_clip: bool = False,
    norm_adv: bool = False,
    recompute_adv: bool = True,
) -> str:
    """Build and run several seeded PPO experiments on a MuJoCo task.

    All runs share one base directory, ``log/<task>/ppo/<datetime tag>``, which
    is written to ``experiment_config.persistence_base_dir`` (the passed config
    object is modified in place).

    :param experiment_config: The experiment configuration; its
        ``persistence_base_dir`` is overwritten by this function.
    :param task: The name of the environment handed to ``MujocoEnvFactory``.
    :param num_experiments: The number of experiments requested from
        ``build_default_seeded_experiments``.
    :param hidden_sizes: Hidden layer sizes used for both the actor and the critic.
    :param lr_decay: Whether to use a ``LRSchedulerFactoryLinear``; if False,
        no learning rate scheduler factory is passed.

    The sampling arguments (``epoch``, ``step_per_epoch``, ``batch_size``,
    ``training_num``, ``test_num``, ``buffer_size``, ``step_per_collect``,
    ``repeat_per_collect``) are forwarded to ``SamplingConfig``; all other
    arguments are forwarded to ``PPOParams``.

    :return: The base directory that contains the results of all runs.
    """
    log_name = os.path.join("log", task, "ppo", datetime_tag())
    experiment_config.persistence_base_dir = log_name

    sampling_config = SamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        batch_size=batch_size,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        step_per_collect=step_per_collect,
        repeat_per_collect=repeat_per_collect,
    )

    # The environment seeds are taken from the sampling configuration.
    env_factory = MujocoEnvFactory(
        task,
        train_seed=sampling_config.train_seed,
        test_seed=sampling_config.test_seed,
        obs_norm=True,
    )

    builder = PPOExperimentBuilder(env_factory, experiment_config, sampling_config)

    scheduler_factory = LRSchedulerFactoryLinear(sampling_config) if lr_decay else None
    ppo_params = PPOParams(
        discount_factor=gamma,
        gae_lambda=gae_lambda,
        action_bound_method=bound_action_method,
        reward_normalization=rew_norm,
        ent_coef=ent_coef,
        vf_coef=vf_coef,
        max_grad_norm=max_grad_norm,
        value_clip=value_clip,
        advantage_normalization=norm_adv,
        eps_clip=eps_clip,
        dual_clip=dual_clip,
        recompute_advantage=recompute_adv,
        lr=lr,
        lr_scheduler_factory=scheduler_factory,
        dist_fn=DistributionFunctionFactoryIndependentGaussians(),
    )

    builder = builder.with_ppo_params(ppo_params)
    builder = builder.with_actor_factory_default(
        hidden_sizes,
        torch.nn.Tanh,
        continuous_unbounded=True,
    )
    builder = builder.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
    experiments = builder.build_default_seeded_experiments(num_experiments)

    # Run the experiments one after another, each under its own name.
    for name, seeded_experiment in experiments.items():
        seeded_experiment.run(name)

    return log_name
def eval_experiments(log_dir: str):
    """Load the PPO results stored under ``log_dir`` and evaluate them.

    The results are loaded under the algorithm name ``PPO``; no score
    thresholds are passed, so ``load_from_disk`` derives them itself.

    :param log_dir: The directory that contains the results of the runs.
    """
    ppo_results = RLiableExperimentResult.load_from_disk(log_dir, "PPO", None)
    eval_results(ppo_results)
if __name__ == "__main__":
    # Alternative entry point, kept for reference:
    # logging.run_cli(main)
    experiment_config = ExperimentConfig(watch=False)
    # Only two epochs are trained here instead of main's default of 100.
    short_run = partial(main, experiment_config, epoch=2)
    log_dir = logging.run_main(short_run)
    # To evaluate previously stored results instead, set:
    # log_dir = <path/to/exp>
    eval_experiments(log_dir)

View File

@ -5,11 +5,91 @@ import csv
import os import os
import re import re
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, asdict
import numpy as np import numpy as np
import tqdm import tqdm
from tensorboard.backend.event_processing import event_accumulator from tensorboard.backend.event_processing import event_accumulator
from tianshou.highlevel.experiment import Experiment
@dataclass
class RLiableExperimentResult:
exp_dir: str
algorithms: list[str]
score_dict: dict[str, np.ndarray] # (n_runs x n_epochs + 1)
env_steps: np.ndarray # (n_epochs + 1)
score_thresholds: np.ndarray
@staticmethod
def load_from_disk(exp_dir: str, algo_name: str, score_thresholds: np.ndarray | None):
"""Load the experiment result from disk.
:param exp_dir: The directory from where the experiment results are restored.
:param algo_name: The name of the algorithm used in the figure legend.
:param score_thresholds: The thresholds used to create the performance profile.
If None, it will be created from the test episode returns.
"""
test_episode_returns = []
for entry in os.scandir(exp_dir):
if entry.name.startswith('.'):
continue
exp = Experiment.from_directory(entry.path)
logger = exp.logger_factory.create_logger(entry.path, entry.name, None, asdict(exp.config))
data = logger.restore_logged_data(entry.path)
test_data = data['test']
test_episode_returns.append(test_data['returns_stat']['mean'])
env_step = test_data['env_step']
if score_thresholds is None:
score_thresholds = np.linspace(0.0, np.max(test_episode_returns), 101)
return RLiableExperimentResult(algorithms=[algo_name],
score_dict={algo_name: np.array(test_episode_returns)},
env_steps=np.array(env_step),
score_thresholds=score_thresholds,
exp_dir=exp_dir)
def eval_results(results: RLiableExperimentResult, reps: int = 50000):
    """Plot the IQM sample efficiency curve and the performance profile.

    Both figures are saved as PNG files into ``results.exp_dir``
    (``iqm_sample_efficiency_curve.png`` and ``performance_profile.png``).
    The figures are left open after saving.

    :param results: The experiment results to evaluate.
    :param reps: The ``reps`` argument passed to rliable's
        ``get_interval_estimates`` when estimating the IQM intervals.
    """
    import matplotlib.pyplot as plt
    import scipy.stats as sst
    import seaborn as sns
    from rliable import library as rly
    from rliable import plot_utils

    def iqm(scores):
        # Trimmed mean along axis 0, cutting 25% of the values at each end.
        return sst.trim_mean(scores, proportiontocut=0.25, axis=0)

    iqm_scores, iqm_cis = rly.get_interval_estimates(results.score_dict, iqm, reps=reps)

    # Plot IQM sample efficiency curve
    fig, ax = plt.subplots(ncols=1, figsize=(7, 5))
    plot_utils.plot_sample_efficiency_curve(
        results.env_steps,
        iqm_scores,
        iqm_cis,
        algorithms=results.algorithms,
        xlabel="Number of env steps",
        ylabel="IQM episode return",
        ax=ax,
    )
    # Save through the figure object instead of relying on the current figure.
    fig.savefig(os.path.join(results.exp_dir, "iqm_sample_efficiency_curve.png"))

    # Only the last logged return of every run enters the performance profile;
    # indexing with [-1] in a list keeps the arrays two-dimensional.
    final_score_dict = {algo: returns[:, [-1]] for algo, returns in results.score_dict.items()}
    score_distributions, score_distributions_cis = rly.create_performance_profile(
        final_score_dict,
        results.score_thresholds,
    )

    # Plot score distributions
    fig, ax = plt.subplots(ncols=1, figsize=(7, 5))
    plot_utils.plot_performance_profiles(
        score_distributions,
        results.score_thresholds,
        performance_profile_cis=score_distributions_cis,
        colors=dict(zip(results.algorithms, sns.color_palette("colorblind"))),
        xlabel=r"Episode return $(\tau)$",
        ax=ax,
    )
    fig.savefig(os.path.join(results.exp_dir, "performance_profile.png"))
def find_all_files(root_dir, pattern): def find_all_files(root_dir, pattern):
"""Find all files under root_dir according to relative pattern.""" """Find all files under root_dir according to relative pattern."""