add mujoco example with multiple runs and performance plots

Maximilian Huettenrauch 2024-03-12 11:44:48 +01:00
parent 5762d2c2e0
commit 6c1bd85521
2 changed files with 194 additions and 0 deletions

View File

@@ -0,0 +1,114 @@
#!/usr/bin/env python3

import os
from collections.abc import Sequence
from functools import partial
from typing import Literal

import torch

from examples.mujoco.mujoco_env import MujocoEnvFactory
from examples.mujoco.tools import eval_results, RLiableExperimentResult
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
    ExperimentConfig,
    PPOExperimentBuilder,
)
from tianshou.highlevel.params.dist_fn import (
    DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
    experiment_config: ExperimentConfig,
    task: str = "Ant-v4",
    num_experiments: int = 5,
    buffer_size: int = 4096,
    hidden_sizes: Sequence[int] = (64, 64),
    lr: float = 3e-4,
    gamma: float = 0.99,
    epoch: int = 100,
    step_per_epoch: int = 30000,
    step_per_collect: int = 2048,
    repeat_per_collect: int = 10,
    batch_size: int = 64,
    training_num: int = 10,
    test_num: int = 10,
    rew_norm: bool = True,
    vf_coef: float = 0.25,
    ent_coef: float = 0.0,
    gae_lambda: float = 0.95,
    bound_action_method: Literal["clip", "tanh"] | None = "clip",
    lr_decay: bool = True,
    max_grad_norm: float = 0.5,
    eps_clip: float = 0.2,
    dual_clip: float | None = None,
    value_clip: bool = False,
    norm_adv: bool = False,
    recompute_adv: bool = True,
) -> str:
    """Train PPO on a MuJoCo task over num_experiments seeds and return the log directory."""
    log_name = os.path.join("log", task, "ppo", datetime_tag())
    experiment_config.persistence_base_dir = log_name

    # With the defaults, each run collects epoch * step_per_epoch
    # = 100 * 30_000 = 3_000_000 environment steps in total.
    sampling_config = SamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        batch_size=batch_size,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        step_per_collect=step_per_collect,
        repeat_per_collect=repeat_per_collect,
    )

    env_factory = MujocoEnvFactory(
        task,
        train_seed=sampling_config.train_seed,
        test_seed=sampling_config.test_seed,
        obs_norm=True,
    )

    experiments = (
        PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
        .with_ppo_params(
            PPOParams(
                discount_factor=gamma,
                gae_lambda=gae_lambda,
                action_bound_method=bound_action_method,
                reward_normalization=rew_norm,
                ent_coef=ent_coef,
                vf_coef=vf_coef,
                max_grad_norm=max_grad_norm,
                value_clip=value_clip,
                advantage_normalization=norm_adv,
                eps_clip=eps_clip,
                dual_clip=dual_clip,
                recompute_advantage=recompute_adv,
                lr=lr,
                lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
                if lr_decay
                else None,
                dist_fn=DistributionFunctionFactoryIndependentGaussians(),
            ),
        )
        .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
        .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
        .build_default_seeded_experiments(num_experiments)
    )

    # Run the seeded experiments sequentially; each persists into its own
    # subdirectory of log_name.
    for experiment_name, experiment in experiments.items():
        experiment.run(experiment_name)

    return log_name


def eval_experiments(log_dir: str) -> None:
    """Load all runs from the given log directory and create the rliable plots."""
    results = RLiableExperimentResult.load_from_disk(log_dir, "PPO", None)
    eval_results(results)


if __name__ == "__main__":
    # logging.run_cli(main)
    experiment_config = ExperimentConfig(watch=False)
    log_dir = logging.run_main(partial(main, experiment_config, epoch=2))
    # log_dir = <path/to/exp>
    eval_experiments(log_dir)
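
Because the __main__ block above runs only a two-epoch check before plotting, a finished run can also be re-plotted without retraining by pointing eval_experiments at its persisted log directory, which main lays out as log/<task>/ppo/<datetime_tag>. A minimal sketch, with a placeholder path:

# Sketch: re-plot an existing multi-seed run without retraining; the path is a placeholder.
eval_experiments("log/Ant-v4/ppo/<datetime_tag>")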

View File: examples/mujoco/tools.py

@@ -5,11 +5,91 @@ import csv
import os
import re
from collections import defaultdict
from dataclasses import dataclass, asdict

import numpy as np
import tqdm
from tensorboard.backend.event_processing import event_accumulator

from tianshou.highlevel.experiment import Experiment


@dataclass
class RLiableExperimentResult:
    """The result of multiple runs (e.g. different seeds) of a single algorithm."""

    exp_dir: str
    algorithms: list[str]
    score_dict: dict[str, np.ndarray]  # shape: (n_runs, n_epochs + 1)
    env_steps: np.ndarray  # shape: (n_epochs + 1,)
    score_thresholds: np.ndarray
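
    # Shape example, assuming the defaults of the PPO script above (5 seeds,
    # 100 epochs): score_dict["PPO"].shape == (5, 101), env_steps.shape == (101,).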
    @staticmethod
    def load_from_disk(
        exp_dir: str,
        algo_name: str,
        score_thresholds: np.ndarray | None,
    ) -> "RLiableExperimentResult":
        """Load the experiment result from disk.

        :param exp_dir: The directory from which the experiment results are restored.
        :param algo_name: The name of the algorithm, used in the figure legend.
        :param score_thresholds: The thresholds used to create the performance profile.
            If None, they are computed from the test episode returns.
        :return: The loaded result, with one row per run in score_dict.
        """
        test_episode_returns = []
        for entry in os.scandir(exp_dir):
            # Each non-hidden subdirectory is expected to hold one seeded run.
            if entry.name.startswith("."):
                continue
            exp = Experiment.from_directory(entry.path)
            logger = exp.logger_factory.create_logger(
                entry.path,
                entry.name,
                None,
                asdict(exp.config),
            )
            data = logger.restore_logged_data(entry.path)
            test_data = data["test"]
            test_episode_returns.append(test_data["returns_stat"]["mean"])
            env_step = test_data["env_step"]

        if score_thresholds is None:
            score_thresholds = np.linspace(0.0, np.max(test_episode_returns), 101)

        return RLiableExperimentResult(
            algorithms=[algo_name],
            score_dict={algo_name: np.array(test_episode_returns)},
            env_steps=np.array(env_step),
            score_thresholds=score_thresholds,
            exp_dir=exp_dir,
        )
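
# Directory layout assumed by load_from_disk (names are placeholders): one
# non-hidden subdirectory per seeded run under exp_dir, each restorable via
# Experiment.from_directory:
#   exp_dir/
#       <experiment_name_0>/
#       <experiment_name_1>/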


def eval_results(results: RLiableExperimentResult) -> None:
    """Plot an IQM sample-efficiency curve and a performance profile using rliable."""
    import matplotlib.pyplot as plt
    import scipy.stats as sst
    import seaborn as sns
    from rliable import library as rly
    from rliable import plot_utils

    # Interquartile mean (IQM) over runs: trim the top and bottom 25% of scores.
    iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
    iqm_scores, iqm_cis = rly.get_interval_estimates(results.score_dict, iqm, reps=50000)

    # Plot IQM sample efficiency curve
    fig, ax = plt.subplots(ncols=1, figsize=(7, 5))
    plot_utils.plot_sample_efficiency_curve(
        results.env_steps,
        iqm_scores,
        iqm_cis,
        algorithms=results.algorithms,
        xlabel=r"Number of env steps",
        ylabel="IQM episode return",
        ax=ax,
    )
    plt.savefig(os.path.join(results.exp_dir, "iqm_sample_efficiency_curve.png"))

    # The performance profile is computed from the final test returns only.
    final_score_dict = {algo: returns[:, [-1]] for algo, returns in results.score_dict.items()}
    score_distributions, score_distributions_cis = rly.create_performance_profile(
        final_score_dict,
        results.score_thresholds,
    )

    # Plot score distributions
    fig, ax = plt.subplots(ncols=1, figsize=(7, 5))
    plot_utils.plot_performance_profiles(
        score_distributions,
        results.score_thresholds,
        performance_profile_cis=score_distributions_cis,
        colors=dict(zip(results.algorithms, sns.color_palette("colorblind"))),
        xlabel=r"Episode return $(\tau)$",
        ax=ax,
    )
    plt.savefig(os.path.join(results.exp_dir, "performance_profile.png"))
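
# Note on the aggregation above: trim_mean with proportiontocut=0.25 is the
# interquartile mean (IQM): the top and bottom 25% of runs are discarded and
# the middle 50% averaged. For scores [1, 2, 3, 100] this gives (2 + 3) / 2 = 2.5,
# while the plain mean, 26.5, would be dominated by the single outlier run.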


def find_all_files(root_dir, pattern):
    """Find all files under root_dir according to relative pattern."""