pleased the mypy gods

Maximilian Huettenrauch 2024-03-27 15:37:19 +01:00
parent ce5fa0dfac
commit 9c645ff4a0
6 changed files with 85 additions and 28 deletions

View File

@@ -7,9 +7,8 @@ from typing import Literal
 import torch
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.env import VectorEnvType
-from tianshou.highlevel.evaluation import RLiableExperimentResult
 from tianshou.highlevel.config import SamplingConfig
+from tianshou.highlevel.evaluation import RLiableExperimentResult
 from tianshou.highlevel.experiment import (
     ExperimentConfig,
     PPOExperimentBuilder,
@@ -65,6 +64,7 @@ def main(
         batch_size=batch_size,
         num_train_envs=training_num,
         num_test_envs=test_num,
+        num_test_episodes=test_num,
         buffer_size=buffer_size,
         step_per_collect=step_per_collect,
         repeat_per_collect=repeat_per_collect,
@@ -75,7 +75,6 @@ def main(
         train_seed=sampling_config.train_seed,
         test_seed=sampling_config.test_seed,
         obs_norm=True,
-        venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
     )
     experiments = (
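
Note: the keyword arguments in the second hunk are consistent with tianshou's high-level SamplingConfig (imported above) — an assumption, since the enclosing call is not shown. A minimal sketch using only the fields visible in this diff, with placeholder values:

from tianshou.highlevel.config import SamplingConfig

# Placeholder values; num_test_episodes is the field newly wired through above.
sampling_config = SamplingConfig(
    batch_size=64,
    num_train_envs=10,
    num_test_envs=10,
    num_test_episodes=10,
    buffer_size=4096,
    step_per_collect=2048,
    repeat_per_collect=10,
)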

View File

@@ -21,7 +21,7 @@ class TestTensorBoardLogger:
         | dict[str, dict[str, dict[str, int]]],
         expected_output: dict[str, int],
     ) -> None:
-        logger = TensorboardLogger(SummaryWriter('log/logger'))
+        logger = TensorboardLogger(SummaryWriter("log/logger"))
         result = logger.prepare_dict_for_logging(input_dict)
         assert result == expected_output
@@ -38,7 +38,7 @@ class TestTensorBoardLogger:
         delimiter: Literal["|", "."],
         expected_output: dict[str, int],
     ) -> None:
-        logger = TensorboardLogger(SummaryWriter('log/logger'))
+        logger = TensorboardLogger(SummaryWriter("log/logger"))
         result = logger.prepare_dict_for_logging(input_dict, delimiter=delimiter)
         assert result == expected_output
@@ -59,7 +59,7 @@ class TestTensorBoardLogger:
         exclude_arrays: bool,
         expected_output: dict[str, np.ndarray],
     ) -> None:
-        logger = TensorboardLogger(SummaryWriter('log/logger'))
+        logger = TensorboardLogger(SummaryWriter("log/logger"))
         result = logger.prepare_dict_for_logging(input_dict, exclude_arrays=exclude_arrays)
         assert result.keys() == expected_output.keys()
         for val1, val2 in zip(result.values(), expected_output.values(), strict=True):
@@ -76,6 +76,6 @@ class TestTensorBoardLogger:
         input_dict: dict[str, tuple[Literal[1]] | dict[str, str | dict[str, int]]],
         expected_output: dict[str, int],
     ) -> None:
-        logger = TensorboardLogger(SummaryWriter('log/logger'))
+        logger = TensorboardLogger(SummaryWriter("log/logger"))
         result = logger.prepare_dict_for_logging(input_dict)
         assert result == expected_output

View File

@@ -1,11 +1,46 @@
 import os
 from dataclasses import asdict, dataclass
 
+import matplotlib.pyplot as plt
 import numpy as np
+import scipy.stats as sst
+from rliable import library as rly
+from rliable import plot_utils
 
 from tianshou.highlevel.experiment import Experiment
 
 
+@dataclass
+class LoggedSummaryData:
+    mean: np.ndarray
+    std: np.ndarray
+    max: np.ndarray
+    min: np.ndarray
+
+
+@dataclass
+class LoggedCollectStats:
+    env_step: np.ndarray
+    n_collected_episodes: np.ndarray
+    n_collected_steps: np.ndarray
+    collect_time: np.ndarray
+    collect_speed: np.ndarray
+    returns_stat: LoggedSummaryData
+    lens_stat: LoggedSummaryData
+
+    @classmethod
+    def from_data_dict(cls, data: dict) -> "LoggedCollectStats":
+        return cls(
+            env_step=np.array(data["env_step"]),
+            n_collected_episodes=np.array(data["n_collected_episodes"]),
+            n_collected_steps=np.array(data["n_collected_steps"]),
+            collect_time=np.array(data["collect_time"]),
+            collect_speed=np.array(data["collect_speed"]),
+            returns_stat=LoggedSummaryData(**data["returns_stat"]),
+            lens_stat=LoggedSummaryData(**data["lens_stat"]),
+        )
+
+
 @dataclass
 class RLiableExperimentResult:
     """The result of an experiment that can be used with the rliable library."""
@@ -26,7 +61,7 @@ class RLiableExperimentResult:
         :param exp_dir: The directory from where the experiment results are restored.
         """
         test_episode_returns = []
-        test_data = None
+        env_step_at_test = None
 
         for entry in os.scandir(exp_dir):
             if entry.name.startswith(".") or not entry.is_dir():
@@ -41,22 +76,27 @@ class RLiableExperimentResult:
             )
             data = logger.restore_logged_data(entry.path)
 
-            test_data = data["test"]
-            test_episode_returns.append(test_data["returns_stat"]["mean"])
+            if "test" not in data or not data["test"]:
+                continue
+            test_data = LoggedCollectStats.from_data_dict(data["test"])
+            test_episode_returns.append(test_data.returns_stat.mean)
+            env_step_at_test = test_data.env_step
 
-        if test_data is None:
+        if not test_episode_returns or env_step_at_test is None:
             raise ValueError(f"No experiment data found in {exp_dir}.")
 
-        env_step = test_data["env_step"]
-
         return cls(
             test_episode_returns_RE=np.array(test_episode_returns),
-            env_steps_E=np.array(env_step),
+            env_steps_E=np.array(env_step_at_test),
             exp_dir=exp_dir,
         )
 
-    def _get_rliable_data(self, algo_name: str | None = None, score_thresholds: np.ndarray = None) -> (dict, np.ndarray, np.ndarray):
+    def _get_rliable_data(
+        self,
+        algo_name: str | None = None,
+        score_thresholds: np.ndarray | None = None,
+    ) -> tuple[dict, np.ndarray, np.ndarray]:
         """Return the data in the format expected by the rliable library.
 
         :param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm
@@ -67,7 +107,11 @@ class RLiableExperimentResult:
         :return: A tuple score_dict, env_steps, and score_thresholds.
         """
         if score_thresholds is None:
-            score_thresholds = np.linspace(np.min(self.test_episode_returns_RE), np.max(self.test_episode_returns_RE), 101)
+            score_thresholds = np.linspace(
+                np.min(self.test_episode_returns_RE),
+                np.max(self.test_episode_returns_RE),
+                101,
+            )
 
         if algo_name is None:
             algo_name = os.path.basename(self.exp_dir)
@@ -76,22 +120,26 @@ class RLiableExperimentResult:
         return score_dict, self.env_steps_E, score_thresholds
 
-    def eval_results(self, algo_name: str | None = None, score_thresholds: np.ndarray = None, save_figure: bool = False):
+    def eval_results(
+        self,
+        algo_name: str | None = None,
+        score_thresholds: np.ndarray | None = None,
+        save_figure: bool = False,
+    ) -> tuple[plt.Figure, plt.Axes, plt.Figure, plt.Axes]:
         """Evaluate the results of an experiment and create a sample efficiency curve and a performance profile.
 
         :param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm
             is set to the experiment dir.
         :param score_thresholds: The score thresholds for the performance profile. If None, the thresholds are inferred
             from the minimum and maximum test episode returns.
+        :param save_figure: If True, the figures are saved to the experiment directory.
         :return: The created figures and axes.
         """
-        import matplotlib.pyplot as plt
-        import scipy.stats as sst
-        from rliable import library as rly
-        from rliable import plot_utils
-
-        score_dict, env_steps, score_thresholds = self._get_rliable_data(algo_name, score_thresholds)
+        score_dict, env_steps, score_thresholds = self._get_rliable_data(
            algo_name,
            score_thresholds,
        )
 
         iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
         iqm_scores, iqm_cis = rly.get_interval_estimates(score_dict, iqm)
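
A minimal sketch of how LoggedCollectStats.from_data_dict (defined in the first hunk) consumes a dict restored by a logger. All values are invented, and the class definitions above are assumed to be in scope:

import numpy as np

stats = LoggedCollectStats.from_data_dict(
    {
        "env_step": [0, 10000, 20000],
        "n_collected_episodes": [10, 10, 10],
        "n_collected_steps": [1000, 1000, 1000],
        "collect_time": [1.2, 1.1, 1.3],
        "collect_speed": [830.0, 900.0, 770.0],
        # the summary sub-dicts are splatted into LoggedSummaryData as-is,
        # so arrays are passed directly here
        "returns_stat": {
            "mean": np.array([10.0, 50.0, 90.0]),
            "std": np.array([2.0, 5.0, 4.0]),
            "max": np.array([14.0, 60.0, 95.0]),
            "min": np.array([6.0, 40.0, 80.0]),
        },
        "lens_stat": {
            "mean": np.array([100.0, 120.0, 140.0]),
            "std": np.array([5.0, 6.0, 7.0]),
            "max": np.array([110.0, 130.0, 150.0]),
            "min": np.array([90.0, 110.0, 130.0]),
        },
    }
)
print(stats.returns_stat.mean)  # attribute access replaces the old dict lookups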

View File

@@ -145,7 +145,10 @@ class BaseLogger(ABC):
         """
 
     @abstractmethod
-    def restore_logged_data(self, log_path: str) -> dict[str, VALID_LOG_VALS_TYPE]:
+    def restore_logged_data(
+        self,
+        log_path: str,
+    ) -> dict[str, dict[str, VALID_LOG_VALS_TYPE | dict[str, VALID_LOG_VALS_TYPE]]]:
         """Load the logged data from disk for post-processing.
 
         :return: a dict containing the logged data.
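
A sketch of the nested structure the new return annotation describes, using the "test" scope and "returns_stat" tag that evaluation.py reads above (keys and values illustrative): scope -> tag -> value, or a further nested tag dict.

restored: dict = {
    "test": {
        "env_step": [0, 10000],
        "returns_stat": {"mean": [10.0, 50.0]},
    },
    "train": {},
}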

View File

@@ -95,7 +95,10 @@ class PandasLogger(BaseLogger):
         return epoch, env_step, gradient_step
 
-    def restore_logged_data(self, log_path: str) -> dict[str, Any]:
+    def restore_logged_data(
+        self,
+        log_path: str,
+    ) -> dict[str, dict[str, VALID_LOG_VALS_TYPE | dict[str, VALID_LOG_VALS_TYPE]]]:
         data = {}
 
         def merge_dicts(list_of_dicts: list[dict]) -> dict[str, Any]:
@@ -118,4 +121,5 @@ class PandasLogger(BaseLogger):
                 data[scope] = merge_dicts(dict_list)
             except FileNotFoundError:
                 logging.warning(f"Failed to restore {scope} data")
+                data[scope] = {}
 
         return data
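
A standalone sketch of the contract the new fallback establishes: a scope whose file could not be read still maps to an empty dict, so callers can truth-test it rather than catch KeyError, mirroring the guard added in evaluation.py above (values here are invented):

data = {"train": {"env_step": [0, 1000]}, "test": {}}  # as after a failed "test" restore
if "test" not in data or not data["test"]:
    print("run has no restorable test data; skipping")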

View File

@@ -132,14 +132,17 @@ class TensorboardLogger(BaseLogger):
         return epoch, env_step, gradient_step
 
-    def restore_logged_data(self, log_path: str) -> dict[str, Any]:
+    def restore_logged_data(
+        self,
+        log_path: str,
+    ) -> dict[str, dict[str, VALID_LOG_VALS_TYPE | dict[str, VALID_LOG_VALS_TYPE]]]:
         ea = event_accumulator.EventAccumulator(log_path)
         ea.Reload()
 
         def add_to_dict(data_dict: dict[str, Any], keys: list[str], value: Any) -> None:
             current_dict = data_dict
-            for key in keys[:-1]:
-                current_dict = current_dict.setdefault(key, {})
+            for k in keys[:-1]:
+                current_dict = current_dict.setdefault(k, {})
             current_dict[keys[-1]] = value
 
         data: dict[str, Any] = {}
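
The add_to_dict helper shown standalone: it turns a list of keys into a chain of nested dicts. The slash-delimited tag is an assumed example of a TensorBoard scalar tag:

from typing import Any

def add_to_dict(data_dict: dict[str, Any], keys: list[str], value: Any) -> None:
    current_dict = data_dict
    for k in keys[:-1]:
        current_dict = current_dict.setdefault(k, {})
    current_dict[keys[-1]] = value

nested: dict[str, Any] = {}
add_to_dict(nested, "test/returns_stat/mean".split("/"), [10.0, 50.0])
print(nested)  # {'test': {'returns_stat': {'mean': [10.0, 50.0]}}}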