From 9c645ff4a0fe323996df300cefa5d7daf47917c3 Mon Sep 17 00:00:00 2001 From: Maximilian Huettenrauch Date: Wed, 27 Mar 2024 15:37:19 +0100 Subject: [PATCH] pleased the mypy gods --- examples/mujoco/mujoco_ppo_hl_multi.py | 5 +- test/base/test_logger.py | 8 +-- tianshou/highlevel/evaluation.py | 80 ++++++++++++++++++++------ tianshou/utils/logger/base.py | 5 +- tianshou/utils/logger/pandas_logger.py | 6 +- tianshou/utils/logger/tensorboard.py | 9 ++- 6 files changed, 85 insertions(+), 28 deletions(-) diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py index f25317f..703cf68 100644 --- a/examples/mujoco/mujoco_ppo_hl_multi.py +++ b/examples/mujoco/mujoco_ppo_hl_multi.py @@ -7,9 +7,8 @@ from typing import Literal import torch from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.env import VectorEnvType -from tianshou.highlevel.evaluation import RLiableExperimentResult from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.evaluation import RLiableExperimentResult from tianshou.highlevel.experiment import ( ExperimentConfig, PPOExperimentBuilder, @@ -65,6 +64,7 @@ def main( batch_size=batch_size, num_train_envs=training_num, num_test_envs=test_num, + num_test_episodes=test_num, buffer_size=buffer_size, step_per_collect=step_per_collect, repeat_per_collect=repeat_per_collect, @@ -75,7 +75,6 @@ def main( train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True, - venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT ) experiments = ( diff --git a/test/base/test_logger.py b/test/base/test_logger.py index 40dfbbb..8fbbb6b 100644 --- a/test/base/test_logger.py +++ b/test/base/test_logger.py @@ -21,7 +21,7 @@ class TestTensorBoardLogger: | dict[str, dict[str, dict[str, int]]], expected_output: dict[str, int], ) -> None: - logger = TensorboardLogger(SummaryWriter('log/logger')) + logger = TensorboardLogger(SummaryWriter("log/logger")) result = logger.prepare_dict_for_logging(input_dict) assert result == expected_output @@ -38,7 +38,7 @@ class TestTensorBoardLogger: delimiter: Literal["|", "."], expected_output: dict[str, int], ) -> None: - logger = TensorboardLogger(SummaryWriter('log/logger')) + logger = TensorboardLogger(SummaryWriter("log/logger")) result = logger.prepare_dict_for_logging(input_dict, delimiter=delimiter) assert result == expected_output @@ -59,7 +59,7 @@ class TestTensorBoardLogger: exclude_arrays: bool, expected_output: dict[str, np.ndarray], ) -> None: - logger = TensorboardLogger(SummaryWriter('log/logger')) + logger = TensorboardLogger(SummaryWriter("log/logger")) result = logger.prepare_dict_for_logging(input_dict, exclude_arrays=exclude_arrays) assert result.keys() == expected_output.keys() for val1, val2 in zip(result.values(), expected_output.values(), strict=True): @@ -76,6 +76,6 @@ class TestTensorBoardLogger: input_dict: dict[str, tuple[Literal[1]] | dict[str, str | dict[str, int]]], expected_output: dict[str, int], ) -> None: - logger = TensorboardLogger(SummaryWriter('log/logger')) + logger = TensorboardLogger(SummaryWriter("log/logger")) result = logger.prepare_dict_for_logging(input_dict) assert result == expected_output diff --git a/tianshou/highlevel/evaluation.py b/tianshou/highlevel/evaluation.py index cb0a1e7..4d4e5c5 100644 --- a/tianshou/highlevel/evaluation.py +++ b/tianshou/highlevel/evaluation.py @@ -1,11 +1,46 @@ import os from dataclasses import asdict, dataclass +import matplotlib.pyplot as plt import numpy as np +import scipy.stats as sst +from rliable import library as rly +from rliable import plot_utils from tianshou.highlevel.experiment import Experiment +@dataclass +class LoggedSummaryData: + mean: np.ndarray + std: np.ndarray + max: np.ndarray + min: np.ndarray + + +@dataclass +class LoggedCollectStats: + env_step: np.ndarray + n_collected_episodes: np.ndarray + n_collected_steps: np.ndarray + collect_time: np.ndarray + collect_speed: np.ndarray + returns_stat: LoggedSummaryData + lens_stat: LoggedSummaryData + + @classmethod + def from_data_dict(cls, data: dict) -> "LoggedCollectStats": + return cls( + env_step=np.array(data["env_step"]), + n_collected_episodes=np.array(data["n_collected_episodes"]), + n_collected_steps=np.array(data["n_collected_steps"]), + collect_time=np.array(data["collect_time"]), + collect_speed=np.array(data["collect_speed"]), + returns_stat=LoggedSummaryData(**data["returns_stat"]), + lens_stat=LoggedSummaryData(**data["lens_stat"]), + ) + + @dataclass class RLiableExperimentResult: """The result of an experiment that can be used with the rliable library.""" @@ -26,7 +61,7 @@ class RLiableExperimentResult: :param exp_dir: The directory from where the experiment results are restored. """ test_episode_returns = [] - test_data = None + env_step_at_test = None for entry in os.scandir(exp_dir): if entry.name.startswith(".") or not entry.is_dir(): @@ -41,22 +76,27 @@ class RLiableExperimentResult: ) data = logger.restore_logged_data(entry.path) - test_data = data["test"] + if "test" not in data or not data["test"]: + continue + test_data = LoggedCollectStats.from_data_dict(data["test"]) - test_episode_returns.append(test_data["returns_stat"]["mean"]) + test_episode_returns.append(test_data.returns_stat.mean) + env_step_at_test = test_data.env_step - if test_data is None: + if not test_episode_returns or env_step_at_test is None: raise ValueError(f"No experiment data found in {exp_dir}.") - env_step = test_data["env_step"] - return cls( test_episode_returns_RE=np.array(test_episode_returns), - env_steps_E=np.array(env_step), + env_steps_E=np.array(env_step_at_test), exp_dir=exp_dir, ) - def _get_rliable_data(self, algo_name: str | None = None, score_thresholds: np.ndarray = None) -> (dict, np.ndarray, np.ndarray): + def _get_rliable_data( + self, + algo_name: str | None = None, + score_thresholds: np.ndarray | None = None, + ) -> tuple[dict, np.ndarray, np.ndarray]: """Return the data in the format expected by the rliable library. :param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm @@ -67,7 +107,11 @@ class RLiableExperimentResult: :return: A tuple score_dict, env_steps, and score_thresholds. """ if score_thresholds is None: - score_thresholds = np.linspace(np.min(self.test_episode_returns_RE), np.max(self.test_episode_returns_RE), 101) + score_thresholds = np.linspace( + np.min(self.test_episode_returns_RE), + np.max(self.test_episode_returns_RE), + 101, + ) if algo_name is None: algo_name = os.path.basename(self.exp_dir) @@ -76,22 +120,26 @@ class RLiableExperimentResult: return score_dict, self.env_steps_E, score_thresholds - def eval_results(self, algo_name: str | None = None, score_thresholds: np.ndarray = None, save_figure: bool = False): + def eval_results( + self, + algo_name: str | None = None, + score_thresholds: np.ndarray | None = None, + save_figure: bool = False, + ) -> tuple[plt.Figure, plt.Axes, plt.Figure, plt.Axes]: """Evaluate the results of an experiment and create a sample efficiency curve and a performance profile. :param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm is set to the experiment dir. :param score_thresholds: The score thresholds for the performance profile. If None, the thresholds are inferred from the minimum and maximum test episode returns. + :param save_figure: If True, the figures are saved to the experiment directory. :return: The created figures and axes. """ - import matplotlib.pyplot as plt - import scipy.stats as sst - from rliable import library as rly - from rliable import plot_utils - - score_dict, env_steps, score_thresholds = self._get_rliable_data(algo_name, score_thresholds) + score_dict, env_steps, score_thresholds = self._get_rliable_data( + algo_name, + score_thresholds, + ) iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0) iqm_scores, iqm_cis = rly.get_interval_estimates(score_dict, iqm) diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py index 36d321f..be145fd 100644 --- a/tianshou/utils/logger/base.py +++ b/tianshou/utils/logger/base.py @@ -145,7 +145,10 @@ class BaseLogger(ABC): """ @abstractmethod - def restore_logged_data(self, log_path: str) -> dict[str, VALID_LOG_VALS_TYPE]: + def restore_logged_data( + self, + log_path: str, + ) -> dict[str, dict[str, VALID_LOG_VALS_TYPE | dict[str, VALID_LOG_VALS_TYPE]]]: """Load the logged data from disk for post-processing. :return: a dict containing the logged data. diff --git a/tianshou/utils/logger/pandas_logger.py b/tianshou/utils/logger/pandas_logger.py index 25feaab..b1d9c72 100644 --- a/tianshou/utils/logger/pandas_logger.py +++ b/tianshou/utils/logger/pandas_logger.py @@ -95,7 +95,10 @@ class PandasLogger(BaseLogger): return epoch, env_step, gradient_step - def restore_logged_data(self, log_path: str) -> dict[str, Any]: + def restore_logged_data( + self, + log_path: str, + ) -> dict[str, dict[str, VALID_LOG_VALS_TYPE | dict[str, VALID_LOG_VALS_TYPE]]]: data = {} def merge_dicts(list_of_dicts: list[dict]) -> dict[str, Any]: @@ -118,4 +121,5 @@ class PandasLogger(BaseLogger): data[scope] = merge_dicts(dict_list) except FileNotFoundError: logging.warning(f"Failed to restore {scope} data") + data[scope] = {} return data diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index dcf0563..098c02d 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -132,14 +132,17 @@ class TensorboardLogger(BaseLogger): return epoch, env_step, gradient_step - def restore_logged_data(self, log_path: str) -> dict[str, Any]: + def restore_logged_data( + self, + log_path: str, + ) -> dict[str, dict[str, VALID_LOG_VALS_TYPE | dict[str, VALID_LOG_VALS_TYPE]]]: ea = event_accumulator.EventAccumulator(log_path) ea.Reload() def add_to_dict(data_dict: dict[str, Any], keys: list[str], value: Any) -> None: current_dict = data_dict - for key in keys[:-1]: - current_dict = current_dict.setdefault(key, {}) + for k in keys[:-1]: + current_dict = current_dict.setdefault(k, {}) current_dict[keys[-1]] = value data: dict[str, Any] = {}