restructured and moved RLiableExperimentResult

Maximilian Huettenrauch 2024-03-27 12:03:31 +01:00
parent 18d8ffa576
commit 6d9b697efe
3 changed files with 142 additions and 119 deletions
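The change moves RLiableExperimentResult out of examples.mujoco.tools and into tianshou.highlevel.evaluation, turning the free function eval_results into a method on the dataclass. A minimal usage sketch of the new call pattern, assuming a log directory produced by the PPO example changed below (the path and the algo_name value are illustrative only):

from tianshou.highlevel.evaluation import RLiableExperimentResult

# Aggregate all runs found in the sub-directories of the experiment log dir
result = RLiableExperimentResult.load_from_disk("log/Ant-v4/ppo/<datetime_tag>")
# Produce the IQM sample efficiency curve and the performance profile figures
result.eval_results(algo_name="ppo", save_figure=True)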

View File

@@ -7,7 +7,8 @@ from typing import Literal
 import torch
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from examples.mujoco.tools import RLiableExperimentResult, eval_results
+from tianshou.highlevel.env import VectorEnvType
+from tianshou.highlevel.evaluation import RLiableExperimentResult
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     ExperimentConfig,
@@ -25,12 +26,12 @@ from tianshou.utils.logging import datetime_tag
 def main(
     experiment_config: ExperimentConfig,
     task: str = "Ant-v4",
-    num_experiments: int = 2,
+    num_experiments: int = 5,
     buffer_size: int = 4096,
     hidden_sizes: Sequence[int] = (64, 64),
     lr: float = 3e-4,
     gamma: float = 0.99,
-    epoch: int = 1,
+    epoch: int = 100,
     step_per_epoch: int = 30000,
     step_per_collect: int = 2048,
     repeat_per_collect: int = 10,
@@ -56,6 +57,7 @@ def main(
     """
     log_name = os.path.join("log", task, "ppo", datetime_tag())
     experiment_config.persistence_base_dir = log_name
+    experiment_config.watch = False
 
     sampling_config = SamplingConfig(
         num_epochs=epoch,
@@ -73,6 +75,7 @@ def main(
         train_seed=sampling_config.train_seed,
         test_seed=sampling_config.test_seed,
         obs_norm=True,
+        venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT,
     )
 
     experiments = (
@@ -110,8 +113,9 @@ def main(
 def eval_experiments(log_dir: str):
-    results = RLiableExperimentResult.load_from_disk(log_dir, "PPO")
-    eval_results(results, save_figure=True)
+    """Evaluate the experiments in the given log directory using the rliable API."""
+    rliable_result = RLiableExperimentResult.load_from_disk(log_dir)
+    rliable_result.eval_results(save_figure=True)
 
 
 if __name__ == "__main__":

View File

@@ -5,125 +5,11 @@ import csv
 import os
 import re
 from collections import defaultdict
-from dataclasses import asdict, dataclass
 
 import numpy as np
 import tqdm
 from tensorboard.backend.event_processing import event_accumulator
 
-from tianshou.highlevel.experiment import Experiment
-
-
-@dataclass
-class RLiableExperimentResult:
-    exp_dir: str
-    algorithms: list[str]
-    score_dict: dict[str, np.ndarray]  # (n_runs x n_epochs + 1)
-    env_steps: np.ndarray  # (n_epochs + 1)
-    score_thresholds: np.ndarray
-
-    @staticmethod
-    def load_from_disk(exp_dir: str, algo_name: str, score_thresholds: np.ndarray | None = None):
-        """Load the experiment result from disk.
-
-        :param exp_dir: The directory from where the experiment results are restored.
-        :param algo_name: The name of the algorithm used in the figure legend.
-        :param score_thresholds: The thresholds used to create the performance profile.
-            If None, it will be created from the test episode returns.
-        """
-        test_episode_returns = []
-        for entry in os.scandir(exp_dir):
-            if entry.name.startswith(".") or not entry.is_dir():
-                continue
-
-            exp = Experiment.from_directory(entry.path)
-            logger = exp.logger_factory.create_logger(
-                entry.path,
-                entry.name,
-                None,
-                asdict(exp.config),
-            )
-            data = logger.restore_logged_data(entry.path)
-            test_data = data["test"]
-            test_episode_returns.append(test_data["returns_stat"]["mean"])
-            env_step = test_data["env_step"]
-
-        if score_thresholds is None:
-            score_thresholds = np.linspace(0.0, np.max(test_episode_returns), 101)
-
-        return RLiableExperimentResult(
-            algorithms=[algo_name],
-            score_dict={algo_name: np.array(test_episode_returns)},
-            env_steps=np.array(env_step),
-            score_thresholds=score_thresholds,
-            exp_dir=exp_dir,
-        )
-
-
-def eval_results(results: RLiableExperimentResult, save_figure=False):
-    """Evaluate the results of an experiment and create the performance profile and sample efficiency curve.
-
-    :param results: The results of the experiment. Needs to be compatible with the rliable API. This can be
-        achieved by calling the method `load_from_disk` from the RLiableExperimentResult class.
-    :param save_figure: Whether to save the figures as png to the experiment directory.
-
-    :return: The axes of the created figures.
-    """
-    import matplotlib.pyplot as plt
-    import scipy.stats as sst
-    import seaborn as sns
-    from rliable import library as rly
-    from rliable import plot_utils
-
-    iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
-    iqm_scores, iqm_cis = rly.get_interval_estimates(results.score_dict, iqm, reps=50000)
-
-    # Plot IQM sample efficiency curve
-    fig, ax1 = plt.subplots(ncols=1, figsize=(7, 5))
-    plot_utils.plot_sample_efficiency_curve(
-        results.env_steps,
-        iqm_scores,
-        iqm_cis,
-        algorithms=results.algorithms,
-        xlabel=r"Number of env steps",
-        ylabel="IQM episode return",
-        ax=ax1,
-    )
-    if save_figure:
-        plt.savefig(os.path.join(results.exp_dir, "iqm_sample_efficiency_curve.png"))
-
-    final_score_dict = {algo: returns[:, [-1]] for algo, returns in results.score_dict.items()}
-    score_distributions, score_distributions_cis = rly.create_performance_profile(
-        final_score_dict,
-        results.score_thresholds,
-    )
-
-    # Plot score distributions
-    fig, ax2 = plt.subplots(ncols=1, figsize=(7, 5))
-    plot_utils.plot_performance_profiles(
-        score_distributions,
-        results.score_thresholds,
-        performance_profile_cis=score_distributions_cis,
-        colors=dict(
-            zip(
-                results.algorithms,
-                sns.color_palette("colorblind", n_colors=len(results.algorithms)),
-                strict=True,
-            ),
-        ),
-        xlabel=r"Episode return $(\tau)$",
-        ax=ax2,
-    )
-    if save_figure:
-        plt.savefig(os.path.join(results.exp_dir, "performance_profile.png"))
-
-    return ax1, ax2
 
 
 def find_all_files(root_dir, pattern):
     """Find all files under root_dir according to relative pattern."""

View File

@@ -0,0 +1,133 @@
+import os
+from dataclasses import asdict, dataclass
+
+import numpy as np
+
+from tianshou.highlevel.experiment import Experiment
+
+
+@dataclass
+class RLiableExperimentResult:
+    """The result of an experiment that can be used with the rliable library."""
+
+    exp_dir: str
+    """The base directory where each sub-directory contains the results of one experiment run."""
+
+    test_episode_returns_RE: np.ndarray
+    """The test episode returns for each run of the experiment, where each row corresponds to one run."""
+
+    env_steps_E: np.ndarray
+    """The number of environment steps at which the test episodes were evaluated."""
+
+    @classmethod
+    def load_from_disk(cls, exp_dir: str) -> "RLiableExperimentResult":
+        """Load the experiment result from disk.
+
+        :param exp_dir: The directory from where the experiment results are restored.
+        """
+        test_episode_returns = []
+        test_data = None
+        for entry in os.scandir(exp_dir):
+            if entry.name.startswith(".") or not entry.is_dir():
+                continue
+
+            exp = Experiment.from_directory(entry.path)
+            logger = exp.logger_factory.create_logger(
+                entry.path,
+                entry.name,
+                None,
+                asdict(exp.config),
+            )
+            data = logger.restore_logged_data(entry.path)
+            test_data = data["test"]
+            test_episode_returns.append(test_data["returns_stat"]["mean"])
+
+        if test_data is None:
+            raise ValueError(f"No experiment data found in {exp_dir}.")
+
+        env_step = test_data["env_step"]
+
+        return cls(
+            test_episode_returns_RE=np.array(test_episode_returns),
+            env_steps_E=np.array(env_step),
+            exp_dir=exp_dir,
+        )
+
+    def _get_rliable_data(
+        self,
+        algo_name: str | None = None,
+        score_thresholds: np.ndarray | None = None,
+    ) -> tuple[dict, np.ndarray, np.ndarray]:
+        """Return the data in the format expected by the rliable library.
+
+        :param algo_name: The name of the algorithm to be shown in the figure legend.
+            If None, the name of the algorithm is set to the experiment dir.
+        :param score_thresholds: The score thresholds for the performance profile.
+            If None, the thresholds are inferred from the minimum and maximum test episode returns.
+
+        :return: A tuple of the score dict, env steps, and score thresholds.
+        """
+        if score_thresholds is None:
+            score_thresholds = np.linspace(
+                np.min(self.test_episode_returns_RE),
+                np.max(self.test_episode_returns_RE),
+                101,
+            )
+
+        if algo_name is None:
+            algo_name = os.path.basename(self.exp_dir)
+
+        score_dict = {algo_name: self.test_episode_returns_RE}
+
+        return score_dict, self.env_steps_E, score_thresholds
+
+    def eval_results(
+        self,
+        algo_name: str | None = None,
+        score_thresholds: np.ndarray | None = None,
+        save_figure: bool = False,
+    ):
+        """Evaluate the results of an experiment and create a sample efficiency curve and a performance profile.
+
+        :param algo_name: The name of the algorithm to be shown in the figure legend.
+            If None, the name of the algorithm is set to the experiment dir.
+        :param score_thresholds: The score thresholds for the performance profile.
+            If None, the thresholds are inferred from the minimum and maximum test episode returns.
+        :param save_figure: Whether to save the figures as png to the experiment directory.
+
+        :return: The created figures and axes.
+        """
+        import matplotlib.pyplot as plt
+        import scipy.stats as sst
+        from rliable import library as rly
+        from rliable import plot_utils
+
+        score_dict, env_steps, score_thresholds = self._get_rliable_data(algo_name, score_thresholds)
+
+        iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
+        iqm_scores, iqm_cis = rly.get_interval_estimates(score_dict, iqm)
+
+        # Plot IQM sample efficiency curve
+        fig1, ax1 = plt.subplots(ncols=1, figsize=(7, 5))
+        plot_utils.plot_sample_efficiency_curve(
+            env_steps,
+            iqm_scores,
+            iqm_cis,
+            algorithms=None,
+            xlabel=r"Number of env steps",
+            ylabel="IQM episode return",
+            ax=ax1,
+        )
+        if save_figure:
+            plt.savefig(os.path.join(self.exp_dir, "iqm_sample_efficiency_curve.png"))
+
+        final_score_dict = {algo: returns[:, [-1]] for algo, returns in score_dict.items()}
+        score_distributions, score_distributions_cis = rly.create_performance_profile(
+            final_score_dict,
+            score_thresholds,
+        )
+
+        # Plot score distributions
+        fig2, ax2 = plt.subplots(ncols=1, figsize=(7, 5))
+        plot_utils.plot_performance_profiles(
+            score_distributions,
+            score_thresholds,
+            performance_profile_cis=score_distributions_cis,
+            xlabel=r"Episode return $(\tau)$",
+            ax=ax2,
+        )
+        if save_figure:
+            plt.savefig(os.path.join(self.exp_dir, "performance_profile.png"))
+
+        return fig1, ax1, fig2, ax2