pleased the mypy gods
commit 9c645ff4a0
parent ce5fa0dfac
@@ -7,9 +7,8 @@ from typing import Literal
 import torch
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.env import VectorEnvType
-from tianshou.highlevel.evaluation import RLiableExperimentResult
 from tianshou.highlevel.config import SamplingConfig
+from tianshou.highlevel.evaluation import RLiableExperimentResult
 from tianshou.highlevel.experiment import (
     ExperimentConfig,
     PPOExperimentBuilder,
@@ -65,6 +64,7 @@ def main(
         batch_size=batch_size,
         num_train_envs=training_num,
         num_test_envs=test_num,
+        num_test_episodes=test_num,
         buffer_size=buffer_size,
         step_per_collect=step_per_collect,
         repeat_per_collect=repeat_per_collect,
@@ -75,7 +75,6 @@ def main(
         train_seed=sampling_config.train_seed,
         test_seed=sampling_config.test_seed,
         obs_norm=True,
-        venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
     )
 
     experiments = (
@@ -21,7 +21,7 @@ class TestTensorBoardLogger:
         | dict[str, dict[str, dict[str, int]]],
         expected_output: dict[str, int],
     ) -> None:
-        logger = TensorboardLogger(SummaryWriter('log/logger'))
+        logger = TensorboardLogger(SummaryWriter("log/logger"))
         result = logger.prepare_dict_for_logging(input_dict)
         assert result == expected_output
 
@@ -38,7 +38,7 @@ class TestTensorBoardLogger:
         delimiter: Literal["|", "."],
         expected_output: dict[str, int],
     ) -> None:
-        logger = TensorboardLogger(SummaryWriter('log/logger'))
+        logger = TensorboardLogger(SummaryWriter("log/logger"))
         result = logger.prepare_dict_for_logging(input_dict, delimiter=delimiter)
         assert result == expected_output
 
@@ -59,7 +59,7 @@ class TestTensorBoardLogger:
         exclude_arrays: bool,
         expected_output: dict[str, np.ndarray],
     ) -> None:
-        logger = TensorboardLogger(SummaryWriter('log/logger'))
+        logger = TensorboardLogger(SummaryWriter("log/logger"))
         result = logger.prepare_dict_for_logging(input_dict, exclude_arrays=exclude_arrays)
         assert result.keys() == expected_output.keys()
         for val1, val2 in zip(result.values(), expected_output.values(), strict=True):
@@ -76,6 +76,6 @@ class TestTensorBoardLogger:
         input_dict: dict[str, tuple[Literal[1]] | dict[str, str | dict[str, int]]],
         expected_output: dict[str, int],
     ) -> None:
-        logger = TensorboardLogger(SummaryWriter('log/logger'))
+        logger = TensorboardLogger(SummaryWriter("log/logger"))
         result = logger.prepare_dict_for_logging(input_dict)
         assert result == expected_output
@@ -1,11 +1,46 @@
 import os
 from dataclasses import asdict, dataclass
 
+import matplotlib.pyplot as plt
 import numpy as np
+import scipy.stats as sst
+from rliable import library as rly
+from rliable import plot_utils
 
 from tianshou.highlevel.experiment import Experiment
 
 
+@dataclass
+class LoggedSummaryData:
+    mean: np.ndarray
+    std: np.ndarray
+    max: np.ndarray
+    min: np.ndarray
+
+
+@dataclass
+class LoggedCollectStats:
+    env_step: np.ndarray
+    n_collected_episodes: np.ndarray
+    n_collected_steps: np.ndarray
+    collect_time: np.ndarray
+    collect_speed: np.ndarray
+    returns_stat: LoggedSummaryData
+    lens_stat: LoggedSummaryData
+
+    @classmethod
+    def from_data_dict(cls, data: dict) -> "LoggedCollectStats":
+        return cls(
+            env_step=np.array(data["env_step"]),
+            n_collected_episodes=np.array(data["n_collected_episodes"]),
+            n_collected_steps=np.array(data["n_collected_steps"]),
+            collect_time=np.array(data["collect_time"]),
+            collect_speed=np.array(data["collect_speed"]),
+            returns_stat=LoggedSummaryData(**data["returns_stat"]),
+            lens_stat=LoggedSummaryData(**data["lens_stat"]),
+        )
+
+
 @dataclass
 class RLiableExperimentResult:
     """The result of an experiment that can be used with the rliable library."""
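
For orientation, a minimal sketch of how the new LoggedCollectStats.from_data_dict is meant to be consumed. The payload below is invented for illustration; in this commit the real input comes from a logger's restore_logged_data:

    import numpy as np

    # Invented "test"-scope payload: plain lists, as restored from disk.
    raw = {
        "env_step": [10000, 20000],
        "n_collected_episodes": [10, 10],
        "n_collected_steps": [2500, 2500],
        "collect_time": [1.2, 1.1],
        "collect_speed": [2083.3, 2272.7],
        "returns_stat": {"mean": [55.0, 120.5], "std": [12.1, 30.4], "max": [80.0, 180.0], "min": [30.0, 60.0]},
        "lens_stat": {"mean": [200.0, 350.0], "std": [20.0, 40.0], "max": [250.0, 420.0], "min": [150.0, 280.0]},
    }

    stats = LoggedCollectStats.from_data_dict(raw)
    stats.returns_stat.mean  # attribute access replaces test_data["returns_stat"]["mean"] below
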
@@ -26,7 +61,7 @@ class RLiableExperimentResult:
         :param exp_dir: The directory from where the experiment results are restored.
         """
         test_episode_returns = []
-        test_data = None
+        env_step_at_test = None
 
         for entry in os.scandir(exp_dir):
             if entry.name.startswith(".") or not entry.is_dir():
@@ -41,22 +76,27 @@ class RLiableExperimentResult:
             )
             data = logger.restore_logged_data(entry.path)
 
-            test_data = data["test"]
+            if "test" not in data or not data["test"]:
+                continue
+            test_data = LoggedCollectStats.from_data_dict(data["test"])
 
-            test_episode_returns.append(test_data["returns_stat"]["mean"])
+            test_episode_returns.append(test_data.returns_stat.mean)
+            env_step_at_test = test_data.env_step
 
-        if test_data is None:
+        if not test_episode_returns or env_step_at_test is None:
             raise ValueError(f"No experiment data found in {exp_dir}.")
 
-        env_step = test_data["env_step"]
-
         return cls(
             test_episode_returns_RE=np.array(test_episode_returns),
-            env_steps_E=np.array(env_step),
+            env_steps_E=np.array(env_step_at_test),
             exp_dir=exp_dir,
         )
 
-    def _get_rliable_data(self, algo_name: str | None = None, score_thresholds: np.ndarray = None) -> (dict, np.ndarray, np.ndarray):
+    def _get_rliable_data(
+        self,
+        algo_name: str | None = None,
+        score_thresholds: np.ndarray | None = None,
+    ) -> tuple[dict, np.ndarray, np.ndarray]:
         """Return the data in the format expected by the rliable library.
 
         :param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm
@@ -67,7 +107,11 @@ class RLiableExperimentResult:
         :return: A tuple score_dict, env_steps, and score_thresholds.
         """
         if score_thresholds is None:
-            score_thresholds = np.linspace(np.min(self.test_episode_returns_RE), np.max(self.test_episode_returns_RE), 101)
+            score_thresholds = np.linspace(
+                np.min(self.test_episode_returns_RE),
+                np.max(self.test_episode_returns_RE),
+                101,
+            )
 
         if algo_name is None:
             algo_name = os.path.basename(self.exp_dir)
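
A quick worked check of the default threshold grid (the returns below are invented):

    import numpy as np

    test_episode_returns_RE = np.array([[10.0, 50.0, 90.0], [20.0, 60.0, 100.0]])  # (runs, eval points)
    score_thresholds = np.linspace(np.min(test_episode_returns_RE), np.max(test_episode_returns_RE), 101)
    # 101 evenly spaced thresholds from 10.0 to 100.0, i.e. a step of 0.9
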
@@ -76,22 +120,26 @@ class RLiableExperimentResult:
 
         return score_dict, self.env_steps_E, score_thresholds
 
-    def eval_results(self, algo_name: str | None = None, score_thresholds: np.ndarray = None, save_figure: bool = False):
+    def eval_results(
+        self,
+        algo_name: str | None = None,
+        score_thresholds: np.ndarray | None = None,
+        save_figure: bool = False,
+    ) -> tuple[plt.Figure, plt.Axes, plt.Figure, plt.Axes]:
         """Evaluate the results of an experiment and create a sample efficiency curve and a performance profile.
 
         :param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm
             is set to the experiment dir.
         :param score_thresholds: The score thresholds for the performance profile. If None, the thresholds are inferred
             from the minimum and maximum test episode returns.
         :param save_figure: If True, the figures are saved to the experiment directory.
 
         :return: The created figures and axes.
         """
-        import matplotlib.pyplot as plt
-        import scipy.stats as sst
-        from rliable import library as rly
-        from rliable import plot_utils
 
-        score_dict, env_steps, score_thresholds = self._get_rliable_data(algo_name, score_thresholds)
+        score_dict, env_steps, score_thresholds = self._get_rliable_data(
+            algo_name,
+            score_thresholds,
+        )
 
         iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
         iqm_scores, iqm_cis = rly.get_interval_estimates(score_dict, iqm)
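
The iqm lambda computes the interquartile mean: trim the lowest and highest 25% of scores, then average the middle half. A standalone check with invented scores:

    import numpy as np
    import scipy.stats as sst

    scores = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 100.0])
    sst.trim_mean(scores, proportiontocut=0.25, axis=0)  # mean of [3, 4, 5, 6] -> 4.5
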
@@ -145,7 +145,10 @@ class BaseLogger(ABC):
         """
 
     @abstractmethod
-    def restore_logged_data(self, log_path: str) -> dict[str, VALID_LOG_VALS_TYPE]:
+    def restore_logged_data(
+        self,
+        log_path: str,
+    ) -> dict[str, dict[str, VALID_LOG_VALS_TYPE | dict[str, VALID_LOG_VALS_TYPE]]]:
         """Load the logged data from disk for post-processing.
 
         :return: a dict containing the logged data.
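
The widened return type encodes a scope -> tag -> value layout. Roughly, assuming the scopes used elsewhere in this commit ("train", "test"), restored data looks like:

    data = {
        "test": {
            "env_step": [10000, 20000],               # VALID_LOG_VALS_TYPE
            "returns_stat": {"mean": [55.0, 120.5]},  # one extra level of nesting
        },
        "train": {},
    }
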
@@ -95,7 +95,10 @@ class PandasLogger(BaseLogger):
 
         return epoch, env_step, gradient_step
 
-    def restore_logged_data(self, log_path: str) -> dict[str, Any]:
+    def restore_logged_data(
+        self,
+        log_path: str,
+    ) -> dict[str, dict[str, VALID_LOG_VALS_TYPE | dict[str, VALID_LOG_VALS_TYPE]]]:
         data = {}
 
         def merge_dicts(list_of_dicts: list[dict]) -> dict[str, Any]:
@@ -118,4 +121,5 @@ class PandasLogger(BaseLogger):
                 data[scope] = merge_dicts(dict_list)
             except FileNotFoundError:
                 logging.warning(f"Failed to restore {scope} data")
+                data[scope] = {}
         return data
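
The new data[scope] = {} fallback pairs with the guard added in evaluation.py above: a scope whose file is missing now shows up as an empty dict rather than being absent, so consumers can test for it uniformly. A sketch (the restore call is illustrative):

    data = logger.restore_logged_data(log_path)
    if "test" not in data or not data["test"]:  # same guard as in RLiableExperimentResult
        ...  # skip this run instead of raising a KeyError
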
@@ -132,14 +132,17 @@ class TensorboardLogger(BaseLogger):
 
         return epoch, env_step, gradient_step
 
-    def restore_logged_data(self, log_path: str) -> dict[str, Any]:
+    def restore_logged_data(
+        self,
+        log_path: str,
+    ) -> dict[str, dict[str, VALID_LOG_VALS_TYPE | dict[str, VALID_LOG_VALS_TYPE]]]:
         ea = event_accumulator.EventAccumulator(log_path)
         ea.Reload()
 
         def add_to_dict(data_dict: dict[str, Any], keys: list[str], value: Any) -> None:
             current_dict = data_dict
-            for key in keys[:-1]:
-                current_dict = current_dict.setdefault(key, {})
+            for k in keys[:-1]:
+                current_dict = current_dict.setdefault(k, {})
             current_dict[keys[-1]] = value
 
         data: dict[str, Any] = {}
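
A self-contained demo of the add_to_dict helper above; the split of tensorboard tags on "/" happens outside this hunk and is assumed here:

    from typing import Any

    def add_to_dict(data_dict: dict[str, Any], keys: list[str], value: Any) -> None:
        # Walk/create nested dicts for all but the last key, then assign the value.
        current_dict = data_dict
        for k in keys[:-1]:
            current_dict = current_dict.setdefault(k, {})
        current_dict[keys[-1]] = value

    data: dict[str, Any] = {}
    add_to_dict(data, "test/returns_stat/mean".split("/"), [55.0, 120.5])
    # data == {"test": {"returns_stat": {"mean": [55.0, 120.5]}}}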