From 6d9b697efe338c35509a85459923975b8173054d Mon Sep 17 00:00:00 2001
From: Maximilian Huettenrauch
Date: Wed, 27 Mar 2024 12:03:31 +0100
Subject: [PATCH] restructured and moved RLiableExperimentResult

---
 examples/mujoco/mujoco_ppo_hl_multi.py |  14 ++-
 examples/mujoco/tools.py               | 114 ---------------------
 tianshou/highlevel/evaluation.py       | 133 +++++++++++++++++++++++++
 3 files changed, 142 insertions(+), 119 deletions(-)
 create mode 100644 tianshou/highlevel/evaluation.py

diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py
index c5edb1f..f25317f 100644
--- a/examples/mujoco/mujoco_ppo_hl_multi.py
+++ b/examples/mujoco/mujoco_ppo_hl_multi.py
@@ -7,7 +7,8 @@ from typing import Literal
 import torch
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from examples.mujoco.tools import RLiableExperimentResult, eval_results
+from tianshou.highlevel.env import VectorEnvType
+from tianshou.highlevel.evaluation import RLiableExperimentResult
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     ExperimentConfig,
@@ -25,12 +26,12 @@ from tianshou.utils.logging import datetime_tag
 def main(
     experiment_config: ExperimentConfig,
     task: str = "Ant-v4",
-    num_experiments: int = 2,
+    num_experiments: int = 5,
     buffer_size: int = 4096,
     hidden_sizes: Sequence[int] = (64, 64),
     lr: float = 3e-4,
     gamma: float = 0.99,
-    epoch: int = 1,
+    epoch: int = 100,
     step_per_epoch: int = 30000,
     step_per_collect: int = 2048,
     repeat_per_collect: int = 10,
@@ -56,6 +57,7 @@ def main(
     """
     log_name = os.path.join("log", task, "ppo", datetime_tag())
     experiment_config.persistence_base_dir = log_name
+    experiment_config.watch = False
 
     sampling_config = SamplingConfig(
         num_epochs=epoch,
@@ -73,6 +75,7 @@ def main(
         train_seed=sampling_config.train_seed,
         test_seed=sampling_config.test_seed,
         obs_norm=True,
+        venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT,
     )
 
     experiments = (
@@ -110,8 +113,9 @@ def main(
 
 
 def eval_experiments(log_dir: str):
-    results = RLiableExperimentResult.load_from_disk(log_dir, "PPO")
-    eval_results(results, save_figure=True)
+    """Evaluate the experiments in the given log directory using the rliable API."""
+    rliable_result = RLiableExperimentResult.load_from_disk(log_dir)
+    rliable_result.eval_results(save_figure=True)
 
 
 if __name__ == "__main__":
diff --git a/examples/mujoco/tools.py b/examples/mujoco/tools.py
index ea402c8..be289e3 100755
--- a/examples/mujoco/tools.py
+++ b/examples/mujoco/tools.py
@@ -5,125 +5,11 @@ import csv
 import os
 import re
 from collections import defaultdict
-from dataclasses import asdict, dataclass
 
 import numpy as np
 import tqdm
 from tensorboard.backend.event_processing import event_accumulator
 
-from tianshou.highlevel.experiment import Experiment
-
-
-@dataclass
-class RLiableExperimentResult:
-    exp_dir: str
-    algorithms: list[str]
-    score_dict: dict[str, np.ndarray]  # (n_runs x n_epochs + 1)
-    env_steps: np.ndarray  # (n_epochs + 1)
-    score_thresholds: np.ndarray
-
-    @staticmethod
-    def load_from_disk(exp_dir: str, algo_name: str, score_thresholds: np.ndarray | None = None):
-        """Load the experiment result from disk.
-
-        :param exp_dir: The directory from where the experiment results are restored.
-        :param algo_name: The name of the algorithm used in the figure legend.
-        :param score_thresholds: The thresholds used to create the performance profile.
-            If None, it will be created from the test episode returns.
-        """
-        test_episode_returns = []
-
-        for entry in os.scandir(exp_dir):
-            if entry.name.startswith(".") or not entry.is_dir():
-                continue
-
-            exp = Experiment.from_directory(entry.path)
-            logger = exp.logger_factory.create_logger(
-                entry.path,
-                entry.name,
-                None,
-                asdict(exp.config),
-            )
-            data = logger.restore_logged_data(entry.path)
-
-            test_data = data["test"]
-
-            test_episode_returns.append(test_data["returns_stat"]["mean"])
-            env_step = test_data["env_step"]
-
-        if score_thresholds is None:
-            score_thresholds = np.linspace(0.0, np.max(test_episode_returns), 101)
-
-        return RLiableExperimentResult(
-            algorithms=[algo_name],
-            score_dict={algo_name: np.array(test_episode_returns)},
-            env_steps=np.array(env_step),
-            score_thresholds=score_thresholds,
-            exp_dir=exp_dir,
-        )
-
-
-def eval_results(results: RLiableExperimentResult, save_figure=False):
-    """Evaluate the results of an experiment and create the performance profile and sample efficiency curve.
-
-    :param results: The results of the experiment. Needs to be compatible with the rliable API. This can be achieved by
-        calling the method `load_from_disk` from the RLiableExperimentResult class.
-    :param save_figure: Whether to save the figures as png to the experiment directory.
-
-    :return: The axes of the created figures.
-    """
-    import matplotlib.pyplot as plt
-    import scipy.stats as sst
-    import seaborn as sns
-    from rliable import library as rly
-    from rliable import plot_utils
-
-    iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
-    iqm_scores, iqm_cis = rly.get_interval_estimates(results.score_dict, iqm, reps=50000)
-
-    # Plot IQM sample efficiency curve
-    fig, ax1 = plt.subplots(ncols=1, figsize=(7, 5))
-    plot_utils.plot_sample_efficiency_curve(
-        results.env_steps,
-        iqm_scores,
-        iqm_cis,
-        algorithms=results.algorithms,
-        xlabel=r"Number of env steps",
-        ylabel="IQM episode return",
-        ax=ax1,
-    )
-
-    if save_figure:
-        plt.savefig(os.path.join(results.exp_dir, "iqm_sample_efficiency_curve.png"))
-
-    final_score_dict = {algo: returns[:, [-1]] for algo, returns in results.score_dict.items()}
-    score_distributions, score_distributions_cis = rly.create_performance_profile(
-        final_score_dict,
-        results.score_thresholds,
-    )
-
-    # Plot score distributions
-    fig, ax2 = plt.subplots(ncols=1, figsize=(7, 5))
-    plot_utils.plot_performance_profiles(
-        score_distributions,
-        results.score_thresholds,
-        performance_profile_cis=score_distributions_cis,
-        colors=dict(
-            zip(
-                results.algorithms,
-                sns.color_palette("colorblind", n_colors=len(results.algorithms)),
-                strict=True,
-            ),
-        ),
-        xlabel=r"Episode return $(\tau)$",
-        ax=ax2,
-    )
-
-    if save_figure:
-        plt.savefig(os.path.join(results.exp_dir, "performance_profile.png"))
-
-    return ax1, ax2
-
 
 def find_all_files(root_dir, pattern):
     """Find all files under root_dir according to relative pattern."""
diff --git a/tianshou/highlevel/evaluation.py b/tianshou/highlevel/evaluation.py
new file mode 100644
index 0000000..cb0a1e7
--- /dev/null
+++ b/tianshou/highlevel/evaluation.py
@@ -0,0 +1,133 @@
+import os
+from dataclasses import asdict, dataclass
+
+import numpy as np
+
+from tianshou.highlevel.experiment import Experiment
+
+
+@dataclass
+class RLiableExperimentResult:
+    """The result of an experiment that can be used with the rliable library."""
+
+    exp_dir: str
+    """The base directory where each sub-directory contains the results of one experiment run."""
+
+    test_episode_returns_RE: np.ndarray
+    """The test episode returns for each run of the experiment, where each
+    row corresponds to one run."""
+
+    env_steps_E: np.ndarray
+    """The number of environment steps at which the test episodes were evaluated."""
+
+    @classmethod
+    def load_from_disk(cls, exp_dir: str) -> "RLiableExperimentResult":
+        """Load the experiment result from disk.
+
+        :param exp_dir: The directory from where the experiment results are restored.
+        """
+        test_episode_returns = []
+        test_data = None
+
+        for entry in os.scandir(exp_dir):
+            if entry.name.startswith(".") or not entry.is_dir():
+                continue
+
+            exp = Experiment.from_directory(entry.path)
+            logger = exp.logger_factory.create_logger(
+                entry.path,
+                entry.name,
+                None,
+                asdict(exp.config),
+            )
+            data = logger.restore_logged_data(entry.path)
+
+            test_data = data["test"]
+
+            test_episode_returns.append(test_data["returns_stat"]["mean"])
+
+        if test_data is None:
+            raise ValueError(f"No experiment data found in {exp_dir}.")
+
+        env_step = test_data["env_step"]
+
+        return cls(
+            test_episode_returns_RE=np.array(test_episode_returns),
+            env_steps_E=np.array(env_step),
+            exp_dir=exp_dir,
+        )
+
+    def _get_rliable_data(self, algo_name: str | None = None, score_thresholds: np.ndarray | None = None) -> tuple[dict, np.ndarray, np.ndarray]:
+        """Return the data in the format expected by the rliable library.
+
+        :param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm
+            is set to the experiment dir.
+        :param score_thresholds: The score thresholds for the performance profile. If None, the thresholds are inferred
+            from the minimum and maximum test episode returns.
+
+        :return: A tuple of the score dict, the env steps, and the score thresholds.
+        """
+        if score_thresholds is None:
+            score_thresholds = np.linspace(np.min(self.test_episode_returns_RE), np.max(self.test_episode_returns_RE), 101)
+
+        if algo_name is None:
+            algo_name = os.path.basename(self.exp_dir)
+
+        score_dict = {algo_name: self.test_episode_returns_RE}
+
+        return score_dict, self.env_steps_E, score_thresholds
+
+    def eval_results(self, algo_name: str | None = None, score_thresholds: np.ndarray | None = None, save_figure: bool = False):
+        """Evaluate the results of an experiment and create a sample efficiency curve and a performance profile.
+
+        :param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm
+            is set to the experiment dir.
+        :param score_thresholds: The score thresholds for the performance profile. If None, the thresholds are inferred
+            from the minimum and maximum test episode returns.
+
+        :return: The created figures and axes.
+        """
+        import matplotlib.pyplot as plt
+        import scipy.stats as sst
+        from rliable import library as rly
+        from rliable import plot_utils
+
+        score_dict, env_steps, score_thresholds = self._get_rliable_data(algo_name, score_thresholds)
+
+        iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
+        iqm_scores, iqm_cis = rly.get_interval_estimates(score_dict, iqm)
+
+        # Plot IQM sample efficiency curve
+        fig1, ax1 = plt.subplots(ncols=1, figsize=(7, 5))
+        plot_utils.plot_sample_efficiency_curve(
+            env_steps,
+            iqm_scores,
+            iqm_cis,
+            algorithms=None,
+            xlabel=r"Number of env steps",
+            ylabel="IQM episode return",
+            ax=ax1,
+        )
+
+        if save_figure:
+            plt.savefig(os.path.join(self.exp_dir, "iqm_sample_efficiency_curve.png"))
+
+        final_score_dict = {algo: returns[:, [-1]] for algo, returns in score_dict.items()}
+        score_distributions, score_distributions_cis = rly.create_performance_profile(
+            final_score_dict,
+            score_thresholds,
+        )
+
+        # Plot score distributions
+        fig2, ax2 = plt.subplots(ncols=1, figsize=(7, 5))
+        plot_utils.plot_performance_profiles(
+            score_distributions,
+            score_thresholds,
+            performance_profile_cis=score_distributions_cis,
+            xlabel=r"Episode return $(\tau)$",
+            ax=ax2,
+        )
+
+        if save_figure:
+            plt.savefig(os.path.join(self.exp_dir, "performance_profile.png"))
+
+        return fig1, ax1, fig2, ax2
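
Usage sketch (illustrative, not part of the diff): with the patch applied, persisted runs can be evaluated
directly through the new class, mirroring eval_experiments in mujoco_ppo_hl_multi.py. The log path and the
threshold range below are placeholders, not values taken from the patch.

    import numpy as np

    from tianshou.highlevel.evaluation import RLiableExperimentResult

    # Hypothetical log directory produced by the example script; it must contain
    # one sub-directory per experiment run.
    log_dir = "log/Ant-v4/ppo/<timestamp>"

    result = RLiableExperimentResult.load_from_disk(log_dir)

    # Default call: the legend name falls back to the directory name and the score
    # thresholds are inferred from the min/max test episode returns.
    fig1, ax1, fig2, ax2 = result.eval_results(save_figure=True)

    # Explicit legend name and custom thresholds for the performance profile.
    thresholds = np.linspace(0.0, 6000.0, 101)  # placeholder return range
    result.eval_results(algo_name="PPO", score_thresholds=thresholds, save_figure=False)

Because matplotlib, scipy, and rliable are imported inside eval_results, they are only needed when the
plots are actually produced.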