pleased the mypy gods

This commit is contained in:
Maximilian Huettenrauch 2024-03-27 15:37:19 +01:00
parent ce5fa0dfac
commit 9c645ff4a0
6 changed files with 85 additions and 28 deletions

View File

@ -7,9 +7,8 @@ from typing import Literal
import torch
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.env import VectorEnvType
from tianshou.highlevel.evaluation import RLiableExperimentResult
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.evaluation import RLiableExperimentResult
from tianshou.highlevel.experiment import (
ExperimentConfig,
PPOExperimentBuilder,
@ -65,6 +64,7 @@ def main(
batch_size=batch_size,
num_train_envs=training_num,
num_test_envs=test_num,
num_test_episodes=test_num,
buffer_size=buffer_size,
step_per_collect=step_per_collect,
repeat_per_collect=repeat_per_collect,
@ -75,7 +75,6 @@ def main(
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=True,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
)
experiments = (

View File

@ -21,7 +21,7 @@ class TestTensorBoardLogger:
| dict[str, dict[str, dict[str, int]]],
expected_output: dict[str, int],
) -> None:
logger = TensorboardLogger(SummaryWriter('log/logger'))
logger = TensorboardLogger(SummaryWriter("log/logger"))
result = logger.prepare_dict_for_logging(input_dict)
assert result == expected_output
@ -38,7 +38,7 @@ class TestTensorBoardLogger:
delimiter: Literal["|", "."],
expected_output: dict[str, int],
) -> None:
logger = TensorboardLogger(SummaryWriter('log/logger'))
logger = TensorboardLogger(SummaryWriter("log/logger"))
result = logger.prepare_dict_for_logging(input_dict, delimiter=delimiter)
assert result == expected_output
@ -59,7 +59,7 @@ class TestTensorBoardLogger:
exclude_arrays: bool,
expected_output: dict[str, np.ndarray],
) -> None:
logger = TensorboardLogger(SummaryWriter('log/logger'))
logger = TensorboardLogger(SummaryWriter("log/logger"))
result = logger.prepare_dict_for_logging(input_dict, exclude_arrays=exclude_arrays)
assert result.keys() == expected_output.keys()
for val1, val2 in zip(result.values(), expected_output.values(), strict=True):
@ -76,6 +76,6 @@ class TestTensorBoardLogger:
input_dict: dict[str, tuple[Literal[1]] | dict[str, str | dict[str, int]]],
expected_output: dict[str, int],
) -> None:
logger = TensorboardLogger(SummaryWriter('log/logger'))
logger = TensorboardLogger(SummaryWriter("log/logger"))
result = logger.prepare_dict_for_logging(input_dict)
assert result == expected_output

View File

@ -1,11 +1,46 @@
import os
from dataclasses import asdict, dataclass
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as sst
from rliable import library as rly
from rliable import plot_utils
from tianshou.highlevel.experiment import Experiment
@dataclass
class LoggedSummaryData:
mean: np.ndarray
std: np.ndarray
max: np.ndarray
min: np.ndarray
@dataclass
class LoggedCollectStats:
env_step: np.ndarray
n_collected_episodes: np.ndarray
n_collected_steps: np.ndarray
collect_time: np.ndarray
collect_speed: np.ndarray
returns_stat: LoggedSummaryData
lens_stat: LoggedSummaryData
@classmethod
def from_data_dict(cls, data: dict) -> "LoggedCollectStats":
return cls(
env_step=np.array(data["env_step"]),
n_collected_episodes=np.array(data["n_collected_episodes"]),
n_collected_steps=np.array(data["n_collected_steps"]),
collect_time=np.array(data["collect_time"]),
collect_speed=np.array(data["collect_speed"]),
returns_stat=LoggedSummaryData(**data["returns_stat"]),
lens_stat=LoggedSummaryData(**data["lens_stat"]),
)
@dataclass
class RLiableExperimentResult:
"""The result of an experiment that can be used with the rliable library."""
@ -26,7 +61,7 @@ class RLiableExperimentResult:
:param exp_dir: The directory from where the experiment results are restored.
"""
test_episode_returns = []
test_data = None
env_step_at_test = None
for entry in os.scandir(exp_dir):
if entry.name.startswith(".") or not entry.is_dir():
@ -41,22 +76,27 @@ class RLiableExperimentResult:
)
data = logger.restore_logged_data(entry.path)
test_data = data["test"]
if "test" not in data or not data["test"]:
continue
test_data = LoggedCollectStats.from_data_dict(data["test"])
test_episode_returns.append(test_data["returns_stat"]["mean"])
test_episode_returns.append(test_data.returns_stat.mean)
env_step_at_test = test_data.env_step
if test_data is None:
if not test_episode_returns or env_step_at_test is None:
raise ValueError(f"No experiment data found in {exp_dir}.")
env_step = test_data["env_step"]
return cls(
test_episode_returns_RE=np.array(test_episode_returns),
env_steps_E=np.array(env_step),
env_steps_E=np.array(env_step_at_test),
exp_dir=exp_dir,
)
def _get_rliable_data(self, algo_name: str | None = None, score_thresholds: np.ndarray = None) -> (dict, np.ndarray, np.ndarray):
def _get_rliable_data(
self,
algo_name: str | None = None,
score_thresholds: np.ndarray | None = None,
) -> tuple[dict, np.ndarray, np.ndarray]:
"""Return the data in the format expected by the rliable library.
:param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm
@ -67,7 +107,11 @@ class RLiableExperimentResult:
:return: A tuple score_dict, env_steps, and score_thresholds.
"""
if score_thresholds is None:
score_thresholds = np.linspace(np.min(self.test_episode_returns_RE), np.max(self.test_episode_returns_RE), 101)
score_thresholds = np.linspace(
np.min(self.test_episode_returns_RE),
np.max(self.test_episode_returns_RE),
101,
)
if algo_name is None:
algo_name = os.path.basename(self.exp_dir)
@ -76,22 +120,26 @@ class RLiableExperimentResult:
return score_dict, self.env_steps_E, score_thresholds
def eval_results(self, algo_name: str | None = None, score_thresholds: np.ndarray = None, save_figure: bool = False):
def eval_results(
self,
algo_name: str | None = None,
score_thresholds: np.ndarray | None = None,
save_figure: bool = False,
) -> tuple[plt.Figure, plt.Axes, plt.Figure, plt.Axes]:
"""Evaluate the results of an experiment and create a sample efficiency curve and a performance profile.
:param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm
is set to the experiment dir.
:param score_thresholds: The score thresholds for the performance profile. If None, the thresholds are inferred
from the minimum and maximum test episode returns.
:param save_figure: If True, the figures are saved to the experiment directory.
:return: The created figures and axes.
"""
import matplotlib.pyplot as plt
import scipy.stats as sst
from rliable import library as rly
from rliable import plot_utils
score_dict, env_steps, score_thresholds = self._get_rliable_data(algo_name, score_thresholds)
score_dict, env_steps, score_thresholds = self._get_rliable_data(
algo_name,
score_thresholds,
)
iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
iqm_scores, iqm_cis = rly.get_interval_estimates(score_dict, iqm)

View File

@ -145,7 +145,10 @@ class BaseLogger(ABC):
"""
@abstractmethod
def restore_logged_data(self, log_path: str) -> dict[str, VALID_LOG_VALS_TYPE]:
def restore_logged_data(
self,
log_path: str,
) -> dict[str, dict[str, VALID_LOG_VALS_TYPE | dict[str, VALID_LOG_VALS_TYPE]]]:
"""Load the logged data from disk for post-processing.
:return: a dict containing the logged data.

View File

@ -95,7 +95,10 @@ class PandasLogger(BaseLogger):
return epoch, env_step, gradient_step
def restore_logged_data(self, log_path: str) -> dict[str, Any]:
def restore_logged_data(
self,
log_path: str,
) -> dict[str, dict[str, VALID_LOG_VALS_TYPE | dict[str, VALID_LOG_VALS_TYPE]]]:
data = {}
def merge_dicts(list_of_dicts: list[dict]) -> dict[str, Any]:
@ -118,4 +121,5 @@ class PandasLogger(BaseLogger):
data[scope] = merge_dicts(dict_list)
except FileNotFoundError:
logging.warning(f"Failed to restore {scope} data")
data[scope] = {}
return data

View File

@ -132,14 +132,17 @@ class TensorboardLogger(BaseLogger):
return epoch, env_step, gradient_step
def restore_logged_data(self, log_path: str) -> dict[str, Any]:
def restore_logged_data(
self,
log_path: str,
) -> dict[str, dict[str, VALID_LOG_VALS_TYPE | dict[str, VALID_LOG_VALS_TYPE]]]:
ea = event_accumulator.EventAccumulator(log_path)
ea.Reload()
def add_to_dict(data_dict: dict[str, Any], keys: list[str], value: Any) -> None:
current_dict = data_dict
for key in keys[:-1]:
current_dict = current_dict.setdefault(key, {})
for k in keys[:-1]:
current_dict = current_dict.setdefault(k, {})
current_dict[keys[-1]] = value
data: dict[str, Any] = {}