restructured and moved RLiableExperimentResult

parent 18d8ffa576
commit 6d9b697efe
@@ -7,7 +7,8 @@ from typing import Literal
 import torch

 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from examples.mujoco.tools import RLiableExperimentResult, eval_results
+from tianshou.highlevel.env import VectorEnvType
+from tianshou.highlevel.evaluation import RLiableExperimentResult
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     ExperimentConfig,
@@ -25,12 +26,12 @@ from tianshou.utils.logging import datetime_tag
 def main(
     experiment_config: ExperimentConfig,
     task: str = "Ant-v4",
-    num_experiments: int = 2,
+    num_experiments: int = 5,
     buffer_size: int = 4096,
     hidden_sizes: Sequence[int] = (64, 64),
     lr: float = 3e-4,
     gamma: float = 0.99,
-    epoch: int = 1,
+    epoch: int = 100,
     step_per_epoch: int = 30000,
     step_per_collect: int = 2048,
     repeat_per_collect: int = 10,
@@ -56,6 +57,7 @@ def main(
     """
     log_name = os.path.join("log", task, "ppo", datetime_tag())
     experiment_config.persistence_base_dir = log_name
+    experiment_config.watch = False

     sampling_config = SamplingConfig(
         num_epochs=epoch,
@@ -73,6 +75,7 @@ def main(
         train_seed=sampling_config.train_seed,
         test_seed=sampling_config.test_seed,
         obs_norm=True,
+        venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
     )

     experiments = (
@@ -110,8 +113,9 @@ def main(


 def eval_experiments(log_dir: str):
-    results = RLiableExperimentResult.load_from_disk(log_dir, "PPO")
-    eval_results(results, save_figure=True)
+    """Evaluate the experiments in the given log directory using the rliable API."""
+    rliable_result = RLiableExperimentResult.load_from_disk(log_dir)
+    rliable_result.eval_results(save_figure=True)


 if __name__ == "__main__":
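For reference, the old and new call patterns side by side; this is an illustrative sketch, with log_dir and the "PPO" label taken from the hunk above:

# Before: free function from examples.mujoco.tools; the algorithm name is fixed at load time.
results = RLiableExperimentResult.load_from_disk(log_dir, "PPO")
eval_results(results, save_figure=True)

# After: method on the dataclass from tianshou.highlevel.evaluation; the algorithm
# name is now optional and defaults to the basename of the experiment directory.
rliable_result = RLiableExperimentResult.load_from_disk(log_dir)
rliable_result.eval_results(save_figure=True)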
examples/mujoco/tools.py
@@ -5,125 +5,11 @@ import csv
 import os
 import re
 from collections import defaultdict
-from dataclasses import asdict, dataclass

 import numpy as np
 import tqdm
 from tensorboard.backend.event_processing import event_accumulator

-from tianshou.highlevel.experiment import Experiment
-
-
-@dataclass
-class RLiableExperimentResult:
-    exp_dir: str
-    algorithms: list[str]
-    score_dict: dict[str, np.ndarray]  # (n_runs x n_epochs + 1)
-    env_steps: np.ndarray  # (n_epochs + 1)
-    score_thresholds: np.ndarray
-
-    @staticmethod
-    def load_from_disk(exp_dir: str, algo_name: str, score_thresholds: np.ndarray | None = None):
-        """Load the experiment result from disk.
-
-        :param exp_dir: The directory from where the experiment results are restored.
-        :param algo_name: The name of the algorithm used in the figure legend.
-        :param score_thresholds: The thresholds used to create the performance profile.
-            If None, it will be created from the test episode returns.
-        """
-        test_episode_returns = []
-
-        for entry in os.scandir(exp_dir):
-            if entry.name.startswith(".") or not entry.is_dir():
-                continue
-
-            exp = Experiment.from_directory(entry.path)
-            logger = exp.logger_factory.create_logger(
-                entry.path,
-                entry.name,
-                None,
-                asdict(exp.config),
-            )
-            data = logger.restore_logged_data(entry.path)
-
-            test_data = data["test"]
-
-            test_episode_returns.append(test_data["returns_stat"]["mean"])
-            env_step = test_data["env_step"]
-
-        if score_thresholds is None:
-            score_thresholds = np.linspace(0.0, np.max(test_episode_returns), 101)
-
-        return RLiableExperimentResult(
-            algorithms=[algo_name],
-            score_dict={algo_name: np.array(test_episode_returns)},
-            env_steps=np.array(env_step),
-            score_thresholds=score_thresholds,
-            exp_dir=exp_dir,
-        )
-
-
-def eval_results(results: RLiableExperimentResult, save_figure=False):
-    """Evaluate the results of an experiment and create the performance profile and sample efficiency curve.
-
-    :param results: The results of the experiment. Needs to be compatible with the rliable API. This can be achieved by
-        calling the method `load_from_disk` from the RLiableExperimentResult class.
-    :param save_figure: Whether to save the figures as png to the experiment directory.
-
-    :return: The axes of the created figures.
-    """
-    import matplotlib.pyplot as plt
-    import scipy.stats as sst
-    import seaborn as sns
-    from rliable import library as rly
-    from rliable import plot_utils
-
-    iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
-    iqm_scores, iqm_cis = rly.get_interval_estimates(results.score_dict, iqm, reps=50000)
-
-    # Plot IQM sample efficiency curve
-    fig, ax1 = plt.subplots(ncols=1, figsize=(7, 5))
-    plot_utils.plot_sample_efficiency_curve(
-        results.env_steps,
-        iqm_scores,
-        iqm_cis,
-        algorithms=results.algorithms,
-        xlabel=r"Number of env steps",
-        ylabel="IQM episode return",
-        ax=ax1,
-    )
-
-    if save_figure:
-        plt.savefig(os.path.join(results.exp_dir, "iqm_sample_efficiency_curve.png"))
-
-    final_score_dict = {algo: returns[:, [-1]] for algo, returns in results.score_dict.items()}
-    score_distributions, score_distributions_cis = rly.create_performance_profile(
-        final_score_dict,
-        results.score_thresholds,
-    )
-
-    # Plot score distributions
-    fig, ax2 = plt.subplots(ncols=1, figsize=(7, 5))
-    plot_utils.plot_performance_profiles(
-        score_distributions,
-        results.score_thresholds,
-        performance_profile_cis=score_distributions_cis,
-        colors=dict(
-            zip(
-                results.algorithms,
-                sns.color_palette("colorblind", n_colors=len(results.algorithms)),
-                strict=True,
-            ),
-        ),
-        xlabel=r"Episode return $(\tau)$",
-        ax=ax2,
-    )
-
-    if save_figure:
-        plt.savefig(os.path.join(results.exp_dir, "performance_profile.png"))
-
-    return ax1, ax2


 def find_all_files(root_dir, pattern):
     """Find all files under root_dir according to relative pattern."""
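Both the removed eval_results and its replacement aggregate runs with the interquartile mean (IQM) via scipy's trim_mean. A minimal, self-contained illustration of that statistic, with invented numbers:

import numpy as np
import scipy.stats as sst

# 4 runs, 1 evaluation checkpoint each; 100.0 is a deliberate outlier run.
scores = np.array([[1.0], [2.0], [3.0], [100.0]])

# Trim the top and bottom 25% of runs, then average the remaining middle 50%.
iqm = sst.trim_mean(scores, proportiontocut=0.25, axis=0)
print(iqm)  # [2.5] -- the outlier no longer dominates, unlike the plain mean (26.5)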
tianshou/highlevel/evaluation.py (new file, 133 lines)
@@ -0,0 +1,133 @@
+import os
+from dataclasses import asdict, dataclass
+
+import numpy as np
+
+from tianshou.highlevel.experiment import Experiment
+
+
+@dataclass
+class RLiableExperimentResult:
+    """The result of an experiment that can be used with the rliable library."""
+
+    exp_dir: str
+    """The base directory where each sub-directory contains the results of one experiment run."""
+
+    test_episode_returns_RE: np.ndarray
+    """The test episode returns for each run of the experiment, where each row corresponds to one run."""
+
+    env_steps_E: np.ndarray
+    """The number of environment steps at which the test episodes were evaluated."""
+
+    @classmethod
+    def load_from_disk(cls, exp_dir: str) -> "RLiableExperimentResult":
+        """Load the experiment result from disk.
+
+        :param exp_dir: The directory from where the experiment results are restored.
+        """
+        test_episode_returns = []
+        test_data = None
+
+        for entry in os.scandir(exp_dir):
+            if entry.name.startswith(".") or not entry.is_dir():
+                continue
+
+            exp = Experiment.from_directory(entry.path)
+            logger = exp.logger_factory.create_logger(
+                entry.path,
+                entry.name,
+                None,
+                asdict(exp.config),
+            )
+            data = logger.restore_logged_data(entry.path)
+
+            test_data = data["test"]
+
+            test_episode_returns.append(test_data["returns_stat"]["mean"])
+
+        if test_data is None:
+            raise ValueError(f"No experiment data found in {exp_dir}.")
+
+        env_step = test_data["env_step"]
+
+        return cls(
+            test_episode_returns_RE=np.array(test_episode_returns),
+            env_steps_E=np.array(env_step),
+            exp_dir=exp_dir,
+        )
+
+    def _get_rliable_data(self, algo_name: str | None = None, score_thresholds: np.ndarray = None) -> (dict, np.ndarray, np.ndarray):
+        """Return the data in the format expected by the rliable library.
+
+        :param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm
+            is set to the experiment dir.
+        :param score_thresholds: The score thresholds for the performance profile. If None, the thresholds are inferred
+            from the minimum and maximum test episode returns.
+
+        :return: A tuple of score_dict, env_steps, and score_thresholds.
+        """
+        if score_thresholds is None:
+            score_thresholds = np.linspace(np.min(self.test_episode_returns_RE), np.max(self.test_episode_returns_RE), 101)
+
+        if algo_name is None:
+            algo_name = os.path.basename(self.exp_dir)
+
+        score_dict = {algo_name: self.test_episode_returns_RE}
+
+        return score_dict, self.env_steps_E, score_thresholds
+
+    def eval_results(self, algo_name: str | None = None, score_thresholds: np.ndarray = None, save_figure: bool = False):
+        """Evaluate the results of an experiment and create a sample efficiency curve and a performance profile.
+
+        :param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm
+            is set to the experiment dir.
+        :param score_thresholds: The score thresholds for the performance profile. If None, the thresholds are inferred
+            from the minimum and maximum test episode returns.
+
+        :return: The created figures and axes.
+        """
+        import matplotlib.pyplot as plt
+        import scipy.stats as sst
+        from rliable import library as rly
+        from rliable import plot_utils
+
+        score_dict, env_steps, score_thresholds = self._get_rliable_data(algo_name, score_thresholds)
+
+        iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
+        iqm_scores, iqm_cis = rly.get_interval_estimates(score_dict, iqm)
+
+        # Plot IQM sample efficiency curve
+        fig1, ax1 = plt.subplots(ncols=1, figsize=(7, 5))
+        plot_utils.plot_sample_efficiency_curve(
+            env_steps,
+            iqm_scores,
+            iqm_cis,
+            algorithms=None,
+            xlabel=r"Number of env steps",
+            ylabel="IQM episode return",
+            ax=ax1,
+        )
+
+        if save_figure:
+            plt.savefig(os.path.join(self.exp_dir, "iqm_sample_efficiency_curve.png"))
+
+        final_score_dict = {algo: returns[:, [-1]] for algo, returns in score_dict.items()}
+        score_distributions, score_distributions_cis = rly.create_performance_profile(
+            final_score_dict,
+            score_thresholds,
+        )
+
+        # Plot score distributions
+        fig2, ax2 = plt.subplots(ncols=1, figsize=(7, 5))
+        plot_utils.plot_performance_profiles(
+            score_distributions,
+            score_thresholds,
+            performance_profile_cis=score_distributions_cis,
+            xlabel=r"Episode return $(\tau)$",
+            ax=ax2,
+        )
+
+        if save_figure:
+            plt.savefig(os.path.join(self.exp_dir, "performance_profile.png"))
+
+        return fig1, ax1, fig2, ax2
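A minimal usage sketch of the new module, assuming a finished multi-run experiment; the log path follows the layout set up in the example script above, with <datetime_tag> standing in for the real run directory:

from tianshou.highlevel.evaluation import RLiableExperimentResult

# The directory contains one sub-directory per experiment run.
result = RLiableExperimentResult.load_from_disk("log/Ant-v4/ppo/<datetime_tag>")
fig1, ax1, fig2, ax2 = result.eval_results(algo_name="PPO", save_figure=True)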