pleased the mypy gods
This commit is contained in:
parent
ce5fa0dfac
commit
9c645ff4a0
@ -7,9 +7,8 @@ from typing import Literal
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
from tianshou.highlevel.env import VectorEnvType
|
|
||||||
from tianshou.highlevel.evaluation import RLiableExperimentResult
|
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
|
from tianshou.highlevel.evaluation import RLiableExperimentResult
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
ExperimentConfig,
|
ExperimentConfig,
|
||||||
PPOExperimentBuilder,
|
PPOExperimentBuilder,
|
||||||
@ -65,6 +64,7 @@ def main(
|
|||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_train_envs=training_num,
|
num_train_envs=training_num,
|
||||||
num_test_envs=test_num,
|
num_test_envs=test_num,
|
||||||
|
num_test_episodes=test_num,
|
||||||
buffer_size=buffer_size,
|
buffer_size=buffer_size,
|
||||||
step_per_collect=step_per_collect,
|
step_per_collect=step_per_collect,
|
||||||
repeat_per_collect=repeat_per_collect,
|
repeat_per_collect=repeat_per_collect,
|
||||||
@ -75,7 +75,6 @@ def main(
|
|||||||
train_seed=sampling_config.train_seed,
|
train_seed=sampling_config.train_seed,
|
||||||
test_seed=sampling_config.test_seed,
|
test_seed=sampling_config.test_seed,
|
||||||
obs_norm=True,
|
obs_norm=True,
|
||||||
venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
|
|
||||||
)
|
)
|
||||||
|
|
||||||
experiments = (
|
experiments = (
|
||||||
|
@ -21,7 +21,7 @@ class TestTensorBoardLogger:
|
|||||||
| dict[str, dict[str, dict[str, int]]],
|
| dict[str, dict[str, dict[str, int]]],
|
||||||
expected_output: dict[str, int],
|
expected_output: dict[str, int],
|
||||||
) -> None:
|
) -> None:
|
||||||
logger = TensorboardLogger(SummaryWriter('log/logger'))
|
logger = TensorboardLogger(SummaryWriter("log/logger"))
|
||||||
result = logger.prepare_dict_for_logging(input_dict)
|
result = logger.prepare_dict_for_logging(input_dict)
|
||||||
assert result == expected_output
|
assert result == expected_output
|
||||||
|
|
||||||
@ -38,7 +38,7 @@ class TestTensorBoardLogger:
|
|||||||
delimiter: Literal["|", "."],
|
delimiter: Literal["|", "."],
|
||||||
expected_output: dict[str, int],
|
expected_output: dict[str, int],
|
||||||
) -> None:
|
) -> None:
|
||||||
logger = TensorboardLogger(SummaryWriter('log/logger'))
|
logger = TensorboardLogger(SummaryWriter("log/logger"))
|
||||||
result = logger.prepare_dict_for_logging(input_dict, delimiter=delimiter)
|
result = logger.prepare_dict_for_logging(input_dict, delimiter=delimiter)
|
||||||
assert result == expected_output
|
assert result == expected_output
|
||||||
|
|
||||||
@ -59,7 +59,7 @@ class TestTensorBoardLogger:
|
|||||||
exclude_arrays: bool,
|
exclude_arrays: bool,
|
||||||
expected_output: dict[str, np.ndarray],
|
expected_output: dict[str, np.ndarray],
|
||||||
) -> None:
|
) -> None:
|
||||||
logger = TensorboardLogger(SummaryWriter('log/logger'))
|
logger = TensorboardLogger(SummaryWriter("log/logger"))
|
||||||
result = logger.prepare_dict_for_logging(input_dict, exclude_arrays=exclude_arrays)
|
result = logger.prepare_dict_for_logging(input_dict, exclude_arrays=exclude_arrays)
|
||||||
assert result.keys() == expected_output.keys()
|
assert result.keys() == expected_output.keys()
|
||||||
for val1, val2 in zip(result.values(), expected_output.values(), strict=True):
|
for val1, val2 in zip(result.values(), expected_output.values(), strict=True):
|
||||||
@ -76,6 +76,6 @@ class TestTensorBoardLogger:
|
|||||||
input_dict: dict[str, tuple[Literal[1]] | dict[str, str | dict[str, int]]],
|
input_dict: dict[str, tuple[Literal[1]] | dict[str, str | dict[str, int]]],
|
||||||
expected_output: dict[str, int],
|
expected_output: dict[str, int],
|
||||||
) -> None:
|
) -> None:
|
||||||
logger = TensorboardLogger(SummaryWriter('log/logger'))
|
logger = TensorboardLogger(SummaryWriter("log/logger"))
|
||||||
result = logger.prepare_dict_for_logging(input_dict)
|
result = logger.prepare_dict_for_logging(input_dict)
|
||||||
assert result == expected_output
|
assert result == expected_output
|
||||||
|
@ -1,11 +1,46 @@
|
|||||||
import os
|
import os
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import scipy.stats as sst
|
||||||
|
from rliable import library as rly
|
||||||
|
from rliable import plot_utils
|
||||||
|
|
||||||
from tianshou.highlevel.experiment import Experiment
|
from tianshou.highlevel.experiment import Experiment
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoggedSummaryData:
|
||||||
|
mean: np.ndarray
|
||||||
|
std: np.ndarray
|
||||||
|
max: np.ndarray
|
||||||
|
min: np.ndarray
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoggedCollectStats:
|
||||||
|
env_step: np.ndarray
|
||||||
|
n_collected_episodes: np.ndarray
|
||||||
|
n_collected_steps: np.ndarray
|
||||||
|
collect_time: np.ndarray
|
||||||
|
collect_speed: np.ndarray
|
||||||
|
returns_stat: LoggedSummaryData
|
||||||
|
lens_stat: LoggedSummaryData
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_data_dict(cls, data: dict) -> "LoggedCollectStats":
|
||||||
|
return cls(
|
||||||
|
env_step=np.array(data["env_step"]),
|
||||||
|
n_collected_episodes=np.array(data["n_collected_episodes"]),
|
||||||
|
n_collected_steps=np.array(data["n_collected_steps"]),
|
||||||
|
collect_time=np.array(data["collect_time"]),
|
||||||
|
collect_speed=np.array(data["collect_speed"]),
|
||||||
|
returns_stat=LoggedSummaryData(**data["returns_stat"]),
|
||||||
|
lens_stat=LoggedSummaryData(**data["lens_stat"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RLiableExperimentResult:
|
class RLiableExperimentResult:
|
||||||
"""The result of an experiment that can be used with the rliable library."""
|
"""The result of an experiment that can be used with the rliable library."""
|
||||||
@ -26,7 +61,7 @@ class RLiableExperimentResult:
|
|||||||
:param exp_dir: The directory from where the experiment results are restored.
|
:param exp_dir: The directory from where the experiment results are restored.
|
||||||
"""
|
"""
|
||||||
test_episode_returns = []
|
test_episode_returns = []
|
||||||
test_data = None
|
env_step_at_test = None
|
||||||
|
|
||||||
for entry in os.scandir(exp_dir):
|
for entry in os.scandir(exp_dir):
|
||||||
if entry.name.startswith(".") or not entry.is_dir():
|
if entry.name.startswith(".") or not entry.is_dir():
|
||||||
@ -41,22 +76,27 @@ class RLiableExperimentResult:
|
|||||||
)
|
)
|
||||||
data = logger.restore_logged_data(entry.path)
|
data = logger.restore_logged_data(entry.path)
|
||||||
|
|
||||||
test_data = data["test"]
|
if "test" not in data or not data["test"]:
|
||||||
|
continue
|
||||||
|
test_data = LoggedCollectStats.from_data_dict(data["test"])
|
||||||
|
|
||||||
test_episode_returns.append(test_data["returns_stat"]["mean"])
|
test_episode_returns.append(test_data.returns_stat.mean)
|
||||||
|
env_step_at_test = test_data.env_step
|
||||||
|
|
||||||
if test_data is None:
|
if not test_episode_returns or env_step_at_test is None:
|
||||||
raise ValueError(f"No experiment data found in {exp_dir}.")
|
raise ValueError(f"No experiment data found in {exp_dir}.")
|
||||||
|
|
||||||
env_step = test_data["env_step"]
|
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
test_episode_returns_RE=np.array(test_episode_returns),
|
test_episode_returns_RE=np.array(test_episode_returns),
|
||||||
env_steps_E=np.array(env_step),
|
env_steps_E=np.array(env_step_at_test),
|
||||||
exp_dir=exp_dir,
|
exp_dir=exp_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_rliable_data(self, algo_name: str | None = None, score_thresholds: np.ndarray = None) -> (dict, np.ndarray, np.ndarray):
|
def _get_rliable_data(
|
||||||
|
self,
|
||||||
|
algo_name: str | None = None,
|
||||||
|
score_thresholds: np.ndarray | None = None,
|
||||||
|
) -> tuple[dict, np.ndarray, np.ndarray]:
|
||||||
"""Return the data in the format expected by the rliable library.
|
"""Return the data in the format expected by the rliable library.
|
||||||
|
|
||||||
:param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm
|
:param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm
|
||||||
@ -67,7 +107,11 @@ class RLiableExperimentResult:
|
|||||||
:return: A tuple score_dict, env_steps, and score_thresholds.
|
:return: A tuple score_dict, env_steps, and score_thresholds.
|
||||||
"""
|
"""
|
||||||
if score_thresholds is None:
|
if score_thresholds is None:
|
||||||
score_thresholds = np.linspace(np.min(self.test_episode_returns_RE), np.max(self.test_episode_returns_RE), 101)
|
score_thresholds = np.linspace(
|
||||||
|
np.min(self.test_episode_returns_RE),
|
||||||
|
np.max(self.test_episode_returns_RE),
|
||||||
|
101,
|
||||||
|
)
|
||||||
|
|
||||||
if algo_name is None:
|
if algo_name is None:
|
||||||
algo_name = os.path.basename(self.exp_dir)
|
algo_name = os.path.basename(self.exp_dir)
|
||||||
@ -76,22 +120,26 @@ class RLiableExperimentResult:
|
|||||||
|
|
||||||
return score_dict, self.env_steps_E, score_thresholds
|
return score_dict, self.env_steps_E, score_thresholds
|
||||||
|
|
||||||
def eval_results(self, algo_name: str | None = None, score_thresholds: np.ndarray = None, save_figure: bool = False):
|
def eval_results(
|
||||||
|
self,
|
||||||
|
algo_name: str | None = None,
|
||||||
|
score_thresholds: np.ndarray | None = None,
|
||||||
|
save_figure: bool = False,
|
||||||
|
) -> tuple[plt.Figure, plt.Axes, plt.Figure, plt.Axes]:
|
||||||
"""Evaluate the results of an experiment and create a sample efficiency curve and a performance profile.
|
"""Evaluate the results of an experiment and create a sample efficiency curve and a performance profile.
|
||||||
|
|
||||||
:param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm
|
:param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm
|
||||||
is set to the experiment dir.
|
is set to the experiment dir.
|
||||||
:param score_thresholds: The score thresholds for the performance profile. If None, the thresholds are inferred
|
:param score_thresholds: The score thresholds for the performance profile. If None, the thresholds are inferred
|
||||||
from the minimum and maximum test episode returns.
|
from the minimum and maximum test episode returns.
|
||||||
|
:param save_figure: If True, the figures are saved to the experiment directory.
|
||||||
|
|
||||||
:return: The created figures and axes.
|
:return: The created figures and axes.
|
||||||
"""
|
"""
|
||||||
import matplotlib.pyplot as plt
|
score_dict, env_steps, score_thresholds = self._get_rliable_data(
|
||||||
import scipy.stats as sst
|
algo_name,
|
||||||
from rliable import library as rly
|
score_thresholds,
|
||||||
from rliable import plot_utils
|
)
|
||||||
|
|
||||||
score_dict, env_steps, score_thresholds = self._get_rliable_data(algo_name, score_thresholds)
|
|
||||||
|
|
||||||
iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
|
iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
|
||||||
iqm_scores, iqm_cis = rly.get_interval_estimates(score_dict, iqm)
|
iqm_scores, iqm_cis = rly.get_interval_estimates(score_dict, iqm)
|
||||||
|
@ -145,7 +145,10 @@ class BaseLogger(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def restore_logged_data(self, log_path: str) -> dict[str, VALID_LOG_VALS_TYPE]:
|
def restore_logged_data(
|
||||||
|
self,
|
||||||
|
log_path: str,
|
||||||
|
) -> dict[str, dict[str, VALID_LOG_VALS_TYPE | dict[str, VALID_LOG_VALS_TYPE]]]:
|
||||||
"""Load the logged data from disk for post-processing.
|
"""Load the logged data from disk for post-processing.
|
||||||
|
|
||||||
:return: a dict containing the logged data.
|
:return: a dict containing the logged data.
|
||||||
|
@ -95,7 +95,10 @@ class PandasLogger(BaseLogger):
|
|||||||
|
|
||||||
return epoch, env_step, gradient_step
|
return epoch, env_step, gradient_step
|
||||||
|
|
||||||
def restore_logged_data(self, log_path: str) -> dict[str, Any]:
|
def restore_logged_data(
|
||||||
|
self,
|
||||||
|
log_path: str,
|
||||||
|
) -> dict[str, dict[str, VALID_LOG_VALS_TYPE | dict[str, VALID_LOG_VALS_TYPE]]]:
|
||||||
data = {}
|
data = {}
|
||||||
|
|
||||||
def merge_dicts(list_of_dicts: list[dict]) -> dict[str, Any]:
|
def merge_dicts(list_of_dicts: list[dict]) -> dict[str, Any]:
|
||||||
@ -118,4 +121,5 @@ class PandasLogger(BaseLogger):
|
|||||||
data[scope] = merge_dicts(dict_list)
|
data[scope] = merge_dicts(dict_list)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logging.warning(f"Failed to restore {scope} data")
|
logging.warning(f"Failed to restore {scope} data")
|
||||||
|
data[scope] = {}
|
||||||
return data
|
return data
|
||||||
|
@ -132,14 +132,17 @@ class TensorboardLogger(BaseLogger):
|
|||||||
|
|
||||||
return epoch, env_step, gradient_step
|
return epoch, env_step, gradient_step
|
||||||
|
|
||||||
def restore_logged_data(self, log_path: str) -> dict[str, Any]:
|
def restore_logged_data(
|
||||||
|
self,
|
||||||
|
log_path: str,
|
||||||
|
) -> dict[str, dict[str, VALID_LOG_VALS_TYPE | dict[str, VALID_LOG_VALS_TYPE]]]:
|
||||||
ea = event_accumulator.EventAccumulator(log_path)
|
ea = event_accumulator.EventAccumulator(log_path)
|
||||||
ea.Reload()
|
ea.Reload()
|
||||||
|
|
||||||
def add_to_dict(data_dict: dict[str, Any], keys: list[str], value: Any) -> None:
|
def add_to_dict(data_dict: dict[str, Any], keys: list[str], value: Any) -> None:
|
||||||
current_dict = data_dict
|
current_dict = data_dict
|
||||||
for key in keys[:-1]:
|
for k in keys[:-1]:
|
||||||
current_dict = current_dict.setdefault(key, {})
|
current_dict = current_dict.setdefault(k, {})
|
||||||
current_dict[keys[-1]] = value
|
current_dict[keys[-1]] = value
|
||||||
|
|
||||||
data: dict[str, Any] = {}
|
data: dict[str, Any] = {}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user