Tianshou/tianshou/highlevel/evaluation.py
import os
from dataclasses import asdict, dataclass

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as sst
from rliable import library as rly
from rliable import plot_utils

from tianshou.highlevel.experiment import Experiment


@dataclass
class LoggedSummaryData:
    """Mean, standard deviation, maximum and minimum of a logged quantity, one entry per logged step."""

    mean: np.ndarray
    std: np.ndarray
    max: np.ndarray
    min: np.ndarray


@dataclass
class LoggedCollectStats:
    """Collector statistics (e.g. of the test collector) restored from the logger, one entry per logged step."""

    env_step: np.ndarray
    n_collected_episodes: np.ndarray
    n_collected_steps: np.ndarray
    collect_time: np.ndarray
    collect_speed: np.ndarray
    returns_stat: LoggedSummaryData
    lens_stat: LoggedSummaryData

    @classmethod
    def from_data_dict(cls, data: dict) -> "LoggedCollectStats":
        """Create an instance from a dictionary of logged data, converting lists to numpy arrays."""
        return cls(
            env_step=np.array(data["env_step"]),
            n_collected_episodes=np.array(data["n_collected_episodes"]),
            n_collected_steps=np.array(data["n_collected_steps"]),
            collect_time=np.array(data["collect_time"]),
            collect_speed=np.array(data["collect_speed"]),
            returns_stat=LoggedSummaryData(**data["returns_stat"]),
            lens_stat=LoggedSummaryData(**data["lens_stat"]),
        )
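

# A sketch of the "test" data dictionary that LoggedCollectStats.from_data_dict expects,
# inferred from the keys accessed above; all concrete values are illustrative only:
#
#     {
#         "env_step": [10000, 20000],
#         "n_collected_episodes": [10, 10],
#         "n_collected_steps": [2000, 2100],
#         "collect_time": [1.2, 1.3],
#         "collect_speed": [1666.7, 1615.4],
#         "returns_stat": {"mean": [100.0, 150.0], "std": [5.0, 4.0], "max": [110.0, 160.0], "min": [90.0, 140.0]},
#         "lens_stat": {"mean": [200.0, 210.0], "std": [3.0, 2.0], "max": [205.0, 215.0], "min": [195.0, 205.0]},
#     }

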
@dataclass
class RLiableExperimentResult:
    """The result of an experiment that can be used with the rliable library."""

    exp_dir: str
    """The base directory where each sub-directory contains the results of one experiment run."""

    test_episode_returns_RE: np.ndarray
    """The test episode returns for each run of the experiment, where each row corresponds to one run."""

    env_steps_E: np.ndarray
    """The number of environment steps at which the test episodes were evaluated."""

    @classmethod
    def load_from_disk(cls, exp_dir: str) -> "RLiableExperimentResult":
        """Load the experiment result from disk.

        :param exp_dir: The directory from which the experiment results are restored.
        """
        test_episode_returns = []
        env_step_at_test = None

        for entry in os.scandir(exp_dir):
            # Each run of the experiment is stored in its own sub-directory; skip hidden entries and plain files.
            if entry.name.startswith(".") or not entry.is_dir():
                continue

            # Re-create the run's logger in order to restore the data it logged.
            exp = Experiment.from_directory(entry.path)
            logger = exp.logger_factory.create_logger(
                entry.path,
                entry.name,
                None,
                asdict(exp.config),
            )
            data = logger.restore_logged_data(entry.path)

            # Runs that have not logged any test data are skipped.
            if "test" not in data or not data["test"]:
                continue

            test_data = LoggedCollectStats.from_data_dict(data["test"])
            test_episode_returns.append(test_data.returns_stat.mean)
            # All runs are assumed to have been evaluated at the same environment steps.
            env_step_at_test = test_data.env_step

        if not test_episode_returns or env_step_at_test is None:
            raise ValueError(f"No experiment data found in {exp_dir}.")

        return cls(
            test_episode_returns_RE=np.array(test_episode_returns),
            env_steps_E=np.array(env_step_at_test),
            exp_dir=exp_dir,
        )
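
    # Expected on-disk layout (sketch): exp_dir holds one sub-directory per experiment run,
    # each restorable via Experiment.from_directory and containing that run's logged data.
    # The run names below are purely illustrative:
    #
    #     exp_dir/
    #         run_0/
    #         run_1/
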
    def _get_rliable_data(
        self,
        algo_name: str | None = None,
        score_thresholds: np.ndarray | None = None,
    ) -> tuple[dict, np.ndarray, np.ndarray]:
        """Return the data in the format expected by the rliable library.

        :param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the
            algorithm is set to the base name of the experiment directory.
        :param score_thresholds: The score thresholds for the performance profile. If None, the thresholds are
            inferred from the minimum and maximum test episode returns.
        :return: A tuple consisting of the score dictionary, the environment steps and the score thresholds.
        """
        if score_thresholds is None:
            score_thresholds = np.linspace(
                np.min(self.test_episode_returns_RE),
                np.max(self.test_episode_returns_RE),
                101,
            )

        if algo_name is None:
            algo_name = os.path.basename(self.exp_dir)

        score_dict = {algo_name: self.test_episode_returns_RE}

        return score_dict, self.env_steps_E, score_thresholds
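
    # Shape convention (sketch): the score dictionary maps the algorithm name to the matrix
    # test_episode_returns_RE with one row per run and one column per evaluation step, which
    # is the layout expected by rliable's aggregation utilities, e.g.
    #
    #     {"my_algo": np.array([[100.0, 150.0],    # run 1
    #                           [ 90.0, 140.0]])}  # run 2
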
    def eval_results(
        self,
        algo_name: str | None = None,
        score_thresholds: np.ndarray | None = None,
        save_figure: bool = False,
    ) -> tuple[plt.Figure, plt.Axes, plt.Figure, plt.Axes]:
        """Evaluate the results of an experiment and create a sample efficiency curve and a performance profile.

        :param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the
            algorithm is set to the base name of the experiment directory.
        :param score_thresholds: The score thresholds for the performance profile. If None, the thresholds are
            inferred from the minimum and maximum test episode returns.
        :param save_figure: If True, the figures are saved to the experiment directory.
        :return: The created figures and axes (sample efficiency curve first, performance profile second).
        """
        score_dict, env_steps, score_thresholds = self._get_rliable_data(
            algo_name,
            score_thresholds,
        )

        # The interquartile mean (IQM): the mean over the scores that remain after trimming the
        # lowest and highest 25% (computed per evaluation step, i.e. along the run axis).
        iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
        iqm_scores, iqm_cis = rly.get_interval_estimates(score_dict, iqm)

        # Plot the IQM sample efficiency curve
        fig1, ax1 = plt.subplots(ncols=1, figsize=(7, 5))
        plot_utils.plot_sample_efficiency_curve(
            env_steps,
            iqm_scores,
            iqm_cis,
            algorithms=None,
            xlabel=r"Number of env steps",
            ylabel="IQM episode return",
            ax=ax1,
        )
        if save_figure:
            plt.savefig(os.path.join(self.exp_dir, "iqm_sample_efficiency_curve.png"))

        # The performance profile is computed from the final test returns of each run.
        final_score_dict = {algo: returns[:, [-1]] for algo, returns in score_dict.items()}
        score_distributions, score_distributions_cis = rly.create_performance_profile(
            final_score_dict,
            score_thresholds,
        )

        # Plot the score distributions
        fig2, ax2 = plt.subplots(ncols=1, figsize=(7, 5))
        plot_utils.plot_performance_profiles(
            score_distributions,
            score_thresholds,
            performance_profile_cis=score_distributions_cis,
            xlabel=r"Episode return $(\tau)$",
            ax=ax2,
        )
        if save_figure:
            plt.savefig(os.path.join(self.exp_dir, "performance_profile.png"))

        return fig1, ax1, fig2, ax2
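

# Example usage (sketch): "log/my_experiment" is a hypothetical path whose sub-directories
# each contain one run of the experiment; adjust the path and algorithm name to your setup.
#
#     result = RLiableExperimentResult.load_from_disk("log/my_experiment")
#     fig1, ax1, fig2, ax2 = result.eval_results(algo_name="my_algo", save_figure=True)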