import os
from dataclasses import asdict, dataclass

import numpy as np

from tianshou.highlevel.experiment import Experiment

@dataclass
class RLiableExperimentResult:
    """The result of an experiment that can be used with the rliable library."""

    exp_dir: str
    """The base directory where each sub-directory contains the results of one experiment run."""

    test_episode_returns_RE: np.ndarray
    """The test episode returns for each run of the experiment, where each row corresponds to one run."""

    env_steps_E: np.ndarray
    """The number of environment steps at which the test episodes were evaluated."""

    @classmethod
    def load_from_disk(cls, exp_dir: str) -> "RLiableExperimentResult":
        """Load the experiment result from disk.

        :param exp_dir: The directory from which the experiment results are restored.
        """
        test_episode_returns = []
        test_data = None

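        # Each sub-directory of exp_dir is expected to contain one run of the
        # same experiment; hidden entries and regular files are skipped.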
        for entry in os.scandir(exp_dir):
            if entry.name.startswith(".") or not entry.is_dir():
                continue

            exp = Experiment.from_directory(entry.path)
            logger = exp.logger_factory.create_logger(
                entry.path,
                entry.name,
                None,
                asdict(exp.config),
            )
            data = logger.restore_logged_data(entry.path)

            test_data = data["test"]
            test_episode_returns.append(test_data["returns_stat"]["mean"])

        if test_data is None:
            raise ValueError(f"No experiment data found in {exp_dir}.")

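        # The env-step schedule is taken from the run read last; this assumes
        # that all runs were evaluated at the same environment steps.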
        env_step = test_data["env_step"]

        return cls(
            test_episode_returns_RE=np.array(test_episode_returns),
            env_steps_E=np.array(env_step),
            exp_dir=exp_dir,
        )

    def _get_rliable_data(
        self,
        algo_name: str | None = None,
        score_thresholds: np.ndarray | None = None,
    ) -> tuple[dict, np.ndarray, np.ndarray]:
        """Return the data in the format expected by the rliable library.

        :param algo_name: The name of the algorithm to be shown in the figure legend.
            If None, the name of the algorithm is set to the experiment dir.
        :param score_thresholds: The score thresholds for the performance profile.
            If None, the thresholds are inferred from the minimum and maximum test
            episode returns.

        :return: A tuple of the score dictionary, the env steps, and the score thresholds.
        """
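        # By default, infer 101 evenly spaced thresholds spanning the full
        # range of observed test returns.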
        if score_thresholds is None:
            score_thresholds = np.linspace(
                np.min(self.test_episode_returns_RE),
                np.max(self.test_episode_returns_RE),
                101,
            )

        if algo_name is None:
            algo_name = os.path.basename(self.exp_dir)

        score_dict = {algo_name: self.test_episode_returns_RE}

        return score_dict, self.env_steps_E, score_thresholds

    def eval_results(
        self,
        algo_name: str | None = None,
        score_thresholds: np.ndarray | None = None,
        save_figure: bool = False,
    ):
        """Evaluate the results of an experiment and create a sample efficiency curve and a performance profile.

        :param algo_name: The name of the algorithm to be shown in the figure legend.
            If None, the name of the algorithm is set to the experiment dir.
        :param score_thresholds: The score thresholds for the performance profile.
            If None, the thresholds are inferred from the minimum and maximum test
            episode returns.
        :param save_figure: If True, the figures are saved as PNG files to the
            experiment directory.

        :return: The created figures and axes.
        """
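        # Evaluation/plotting dependencies are imported lazily, so they are
        # needed only when results are actually evaluated.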
        import matplotlib.pyplot as plt
        import scipy.stats as sst
        from rliable import library as rly
        from rliable import plot_utils

        score_dict, env_steps, score_thresholds = self._get_rliable_data(
            algo_name,
            score_thresholds,
        )

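        # The interquartile mean (IQM) trims the lowest and highest 25% of the
        # scores across runs and averages the remaining middle 50%.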
        iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
        iqm_scores, iqm_cis = rly.get_interval_estimates(score_dict, iqm)

        # Plot IQM sample efficiency curve
        fig1, ax1 = plt.subplots(ncols=1, figsize=(7, 5))
        plot_utils.plot_sample_efficiency_curve(
            env_steps,
            iqm_scores,
            iqm_cis,
            algorithms=None,
            xlabel=r"Number of env steps",
            ylabel="IQM episode return",
            ax=ax1,
        )

        if save_figure:
            plt.savefig(os.path.join(self.exp_dir, "iqm_sample_efficiency_curve.png"))

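        # For the performance profile, keep only each run's final test return;
        # indexing with [-1] inside a list keeps the arrays two-dimensional.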
        final_score_dict = {algo: returns[:, [-1]] for algo, returns in score_dict.items()}
        score_distributions, score_distributions_cis = rly.create_performance_profile(
            final_score_dict,
            score_thresholds,
        )

        # Plot score distributions
        fig2, ax2 = plt.subplots(ncols=1, figsize=(7, 5))
        plot_utils.plot_performance_profiles(
            score_distributions,
            score_thresholds,
            performance_profile_cis=score_distributions_cis,
            xlabel=r"Episode return $(\tau)$",
            ax=ax2,
        )

        if save_figure:
            plt.savefig(os.path.join(self.exp_dir, "performance_profile.png"))

        return fig1, ax1, fig2, ax2
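

# Minimal usage sketch (hypothetical directory and algorithm names; assumes
# "log/my_experiment" holds one sub-directory per run of the same experiment):
#
#     result = RLiableExperimentResult.load_from_disk("log/my_experiment")
#     fig1, ax1, fig2, ax2 = result.eval_results(algo_name="PPO", save_figure=True)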