format, type check and small fixes

parent f730782f29
commit d9a612a997
@@ -66,7 +66,13 @@ def main(
         replay_buffer_save_only_last_obs=True,
     )

-    env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
+    env_factory = AtariEnvFactory(
+        task,
+        sampling_config.train_seed,
+        sampling_config.test_seed,
+        frames_stack,
+        scale=scale_obs,
+    )

     builder = (
         DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -65,7 +65,13 @@ def main(
         replay_buffer_save_only_last_obs=True,
     )

-    env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
+    env_factory = AtariEnvFactory(
+        task,
+        sampling_config.train_seed,
+        sampling_config.test_seed,
+        frames_stack,
+        scale=scale_obs,
+    )

     experiment = (
         IQNExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -65,7 +65,13 @@ def main(
         replay_buffer_save_only_last_obs=True,
     )

-    env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
+    env_factory = AtariEnvFactory(
+        task,
+        sampling_config.train_seed,
+        sampling_config.test_seed,
+        frames_stack,
+        scale=scale_obs,
+    )

     builder = (
         DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -54,7 +54,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )

-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )

     experiment = (
         A2CExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -51,7 +51,12 @@ def main(
         start_timesteps_random=True,
     )

-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=False,
+    )

     experiment = (
         DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -56,7 +56,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )

-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )

     experiment = (
         NPGExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -61,7 +61,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )

-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )

     experiment = (
         PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -2,13 +2,12 @@

 import os
 from collections.abc import Sequence
 from functools import partial
 from typing import Literal

 import torch

 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from examples.mujoco.tools import eval_results, RLiableExperimentResult
+from examples.mujoco.tools import RLiableExperimentResult, eval_results
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     ExperimentConfig,
@@ -19,7 +18,6 @@ from tianshou.highlevel.params.dist_fn import (
 )
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
 from tianshou.highlevel.params.policy_params import PPOParams
 from tianshou.utils import logging
 from tianshou.utils.logging import datetime_tag


@@ -65,7 +63,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )

-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )

     experiments = (
         PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -102,13 +105,13 @@ def main(


 def eval_experiments(log_dir: str):
-    results = RLiableExperimentResult.load_from_disk(log_dir, 'PPO', None)
+    results = RLiableExperimentResult.load_from_disk(log_dir, "PPO", None)
     eval_results(results)


 if __name__ == "__main__":
     # logging.run_cli(main)
-    experiment_config = ExperimentConfig(watch=False)
-    log_dir = logging.run_main(partial(main, experiment_config, epoch=2))
-    # log_dir = <path/to/exp>
+    # experiment_config = ExperimentConfig(watch=False)
+    # log_dir = logging.run_main(partial(main, experiment_config, epoch=2))
+    log_dir = "log/Ant-v4/ppo/20240312-114646"
     eval_experiments(log_dir)
@@ -57,7 +57,12 @@ def main(
         start_timesteps_random=True,
     )

-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=False,
+    )

     experiment = (
         REDQExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -49,7 +49,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )

-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )

     experiment = (
         PGExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -52,7 +52,12 @@ def main(
         start_timesteps_random=True,
     )

-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=False,
+    )

     experiment = (
         SACExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -58,7 +58,12 @@ def main(
         start_timesteps_random=True,
     )

-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=False,
+    )

     experiment = (
         TD3ExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -58,7 +58,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )

-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )

     experiment = (
         TRPOExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -5,7 +5,7 @@ import csv
 import os
 import re
 from collections import defaultdict
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass

 import numpy as np
 import tqdm
@@ -34,26 +34,33 @@ class RLiableExperimentResult:
         test_episode_returns = []

         for entry in os.scandir(exp_dir):
-            if entry.name.startswith('.'):
+            if entry.name.startswith(".") or not entry.is_dir():
                 continue

             exp = Experiment.from_directory(entry.path)
-            logger = exp.logger_factory.create_logger(entry.path, entry.name, None, asdict(exp.config))
+            logger = exp.logger_factory.create_logger(
+                entry.path,
+                entry.name,
+                None,
+                asdict(exp.config),
+            )
             data = logger.restore_logged_data(entry.path)

-            test_data = data['test']
+            test_data = data["test"]

-            test_episode_returns.append(test_data['returns_stat']['mean'])
-            env_step = test_data['env_step']
+            test_episode_returns.append(test_data["returns_stat"]["mean"])
+            env_step = test_data["env_step"]

         if score_thresholds is None:
             score_thresholds = np.linspace(0.0, np.max(test_episode_returns), 101)

-        return RLiableExperimentResult(algorithms=[algo_name],
-                                       score_dict={algo_name: np.array(test_episode_returns)},
-                                       env_steps=np.array(env_step),
-                                       score_thresholds=score_thresholds,
-                                       exp_dir=exp_dir)
+        return RLiableExperimentResult(
+            algorithms=[algo_name],
+            score_dict={algo_name: np.array(test_episode_returns)},
+            env_steps=np.array(env_step),
+            score_thresholds=score_thresholds,
+            exp_dir=exp_dir,
+        )


 def eval_results(results: RLiableExperimentResult):
@@ -64,31 +71,44 @@ def eval_results(results: RLiableExperimentResult):
     from rliable import plot_utils

     iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
-    iqm_scores, iqm_cis = rly.get_interval_estimates(
-        results.score_dict, iqm, reps=50000)
+    iqm_scores, iqm_cis = rly.get_interval_estimates(results.score_dict, iqm, reps=50000)

     # Plot IQM sample efficiency curve
     fig, ax = plt.subplots(ncols=1, figsize=(7, 5))
     plot_utils.plot_sample_efficiency_curve(
-        results.env_steps, iqm_scores, iqm_cis, algorithms=results.algorithms,
-        xlabel=r'Number of env steps',
-        ylabel='IQM episode return',
-        ax=ax)
-    plt.savefig(os.path.join(results.exp_dir, 'iqm_sample_efficiency_curve.png'))
+        results.env_steps,
+        iqm_scores,
+        iqm_cis,
+        algorithms=results.algorithms,
+        xlabel=r"Number of env steps",
+        ylabel="IQM episode return",
+        ax=ax,
+    )
+    plt.savefig(os.path.join(results.exp_dir, "iqm_sample_efficiency_curve.png"))

     final_score_dict = {algo: returns[:, [-1]] for algo, returns in results.score_dict.items()}
     score_distributions, score_distributions_cis = rly.create_performance_profile(
-        final_score_dict, results.score_thresholds)
+        final_score_dict,
+        results.score_thresholds,
+    )

     # Plot score distributions
     fig, ax = plt.subplots(ncols=1, figsize=(7, 5))
     plot_utils.plot_performance_profiles(
-        score_distributions, results.score_thresholds,
+        score_distributions,
+        results.score_thresholds,
         performance_profile_cis=score_distributions_cis,
-        colors=dict(zip(results.algorithms, sns.color_palette('colorblind'))),
-        xlabel=r'Episode return $(\tau)$',
-        ax=ax)
-    plt.savefig(os.path.join(results.exp_dir, 'performance_profile.png'))
+        colors=dict(
+            zip(
+                results.algorithms,
+                sns.color_palette("colorblind", n_colors=len(results.algorithms)),
+                strict=True,
+            ),
+        ),
+        xlabel=r"Episode return $(\tau)$",
+        ax=ax,
+    )
+    plt.savefig(os.path.join(results.exp_dir, "performance_profile.png"))


 def find_all_files(root_dir, pattern):
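For orientation, the IQM aggregation used above is a trimmed mean that discards the top and bottom quartiles of runs before averaging. A minimal standalone sketch of the same call, assuming (as the signature suggests) that sst is scipy.stats:

import numpy as np
import scipy.stats as sst

# one row per experiment/seed, one column per evaluation step
scores = np.array(
    [
        [1.0, 2.0, 3.0],
        [2.0, 4.0, 6.0],
        [3.0, 6.0, 9.0],
        [40.0, 80.0, 120.0],  # an outlier run
    ]
)

iqm = sst.trim_mean(scores, proportiontocut=0.25, axis=0)
# With four rows, the lowest and highest value in each column are dropped,
# leaving the mean of the middle two: [2.5, 5.0, 7.5]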
@@ -6,9 +6,19 @@ from tianshou.highlevel.env import (

 class DiscreteTestEnvFactory(EnvFactoryRegistered):
     def __init__(self) -> None:
-        super().__init__(task="CartPole-v0", train_seed=42, test_seed=1337, venv_type=VectorEnvType.DUMMY)
+        super().__init__(
+            task="CartPole-v0",
+            train_seed=42,
+            test_seed=1337,
+            venv_type=VectorEnvType.DUMMY,
+        )


 class ContinuousTestEnvFactory(EnvFactoryRegistered):
     def __init__(self) -> None:
-        super().__init__(task="Pendulum-v1", train_seed=42, test_seed=1337, venv_type=VectorEnvType.DUMMY)
+        super().__init__(
+            task="Pendulum-v1",
+            train_seed=42,
+            test_seed=1337,
+            venv_type=VectorEnvType.DUMMY,
+        )
@@ -3,7 +3,7 @@ import pickle
 from abc import abstractmethod
 from collections.abc import Sequence
 from copy import copy
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass
 from pprint import pformat
 from typing import Self

@@ -87,9 +87,7 @@ from tianshou.utils.string import ToStringMixin
 log = logging.getLogger(__name__)


-def shortener(input_string: str | None = None,
-              length: int = 1
-              ):
+def shortener(input_string: str | None = None, length: int = 1) -> str:
    """Shorten the input string by keeping only the first `length` characters of each word.

    If the input string is None or empty, return "default".
@@ -367,19 +365,19 @@ class ExperimentBuilder:
         self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()

     @property
-    def experiment_config(self):
+    def experiment_config(self) -> ExperimentConfig:
         return self._config

     @experiment_config.setter
-    def experiment_config(self, experiment_config: ExperimentConfig):
+    def experiment_config(self, experiment_config: ExperimentConfig) -> None:
         self._config = experiment_config

     @property
-    def sampling_config(self):
+    def sampling_config(self) -> SamplingConfig:
         return self._sampling_config

     @sampling_config.setter
-    def sampling_config(self, sampling_config: SamplingConfig):
+    def sampling_config(self, sampling_config: SamplingConfig) -> None:
         self._sampling_config = sampling_config

     def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
@@ -491,7 +489,6 @@ class ExperimentBuilder:

         The keys of the dict are the experiment names, which are derived from the seeds used in the experiments.
         """
-
         configured_experiment_config = copy(self.experiment_config)
         configured_experiment_seed = configured_experiment_config.seed
         configured_sampling_config = copy(self.sampling_config)
@@ -512,9 +509,13 @@ class ExperimentBuilder:
             self.sampling_config = SamplingConfig(**new_sampling_config_dict)
             exp = self.build()

-            full_name = ",".join([f"experiment_seed={exp.config.seed}",
-                                  f"train_seed={exp.sampling_config.train_seed}",
-                                  f"test_seed={exp.sampling_config.test_seed}"])
+            full_name = ",".join(
+                [
+                    f"experiment_seed={exp.config.seed}",
+                    f"train_seed={exp.sampling_config.train_seed}",
+                    f"test_seed={exp.sampling_config.test_seed}",
+                ],
+            )
             experiment_name = shortener(full_name, 4)
             seeded_experiments[experiment_name] = exp

@@ -5,7 +5,6 @@ from typing import Literal, TypeAlias
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger
-from tianshou.utils.logger.base import LoggerManager
 from tianshou.utils.logger.pandas_logger import PandasLogger
 from tianshou.utils.string import ToStringMixin

@@ -78,38 +77,3 @@ class LoggerFactoryDefault(LoggerFactory):
                 return TensorboardLogger(writer)
             case _:
                 raise ValueError(f"Unknown logger type '{self.logger_type}'")
-
-
-class LoggerManagerFactory(LoggerFactory):
-    def __init__(
-        self,
-        logger_types: list[Literal["tensorboard", "wandb", "pandas"]] = ["tensorboard", "pandas"],
-        wandb_project: str | None = None,
-    ):
-        self.logger_types = logger_types
-        self.wandb_project = wandb_project
-
-        self.factories = {
-            "pandas": LoggerFactoryDefault(logger_type="pandas"),
-            "wandb": LoggerFactoryDefault(logger_type="wandb", wandb_project=wandb_project),
-            "tensorboard": LoggerFactoryDefault(logger_type="tensorboard"),
-        }
-
-    def create_logger(
-        self,
-        log_dir: str,
-        experiment_name: str,
-        run_id: str | None,
-        config_dict: dict,
-    ) -> TLogger:
-        logger_manager = LoggerManager()
-        for logger_type in self.logger_types:
-            logger_manager.loggers.append(
-                self.factories[logger_type].create_logger(
-                    log_dir,
-                    experiment_name,
-                    run_id,
-                    config_dict,
-                )
-            )
-        return logger_manager
@@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
 from collections.abc import Callable
 from enum import Enum
 from numbers import Number
 from typing import Any

 import numpy as np

@@ -60,7 +59,7 @@ class BaseLogger(ABC):
         """

     @staticmethod
-    def prepare_dict_for_logging(log_data: dict) -> dict:
+    def prepare_dict_for_logging(log_data: dict) -> dict[str, VALID_LOG_VALS_TYPE]:
         return log_data

     def log_train_data(self, log_data: dict, step: int) -> None:
@@ -72,7 +71,7 @@ class BaseLogger(ABC):
         # TODO: move interval check to calling method
         if step - self.last_log_train_step >= self.train_interval:
             log_data = self.prepare_dict_for_logging(log_data)
-            self.write("/".join([DataScope.TRAIN.value, "env_step"]), step, log_data)
+            self.write(f"{DataScope.TRAIN.value}/env_step", step, log_data)
             self.last_log_train_step = step

     def log_test_data(self, log_data: dict, step: int) -> None:
@@ -84,7 +83,7 @@ class BaseLogger(ABC):
         # TODO: move interval check to calling method (stupid because log_test_data is only called from function in utils.py, not from BaseTrainer)
         if step - self.last_log_test_step >= self.test_interval:
             log_data = self.prepare_dict_for_logging(log_data)
-            self.write("/".join([DataScope.TEST.value, "env_step"]), step, log_data)
+            self.write(f"{DataScope.TEST.value}/env_step", step, log_data)
             self.last_log_test_step = step

     def log_update_data(self, log_data: dict, step: int) -> None:
@@ -96,7 +95,7 @@ class BaseLogger(ABC):
         # TODO: move interval check to calling method
         if step - self.last_log_update_step >= self.update_interval:
             log_data = self.prepare_dict_for_logging(log_data)
-            self.write("/".join([DataScope.UPDATE.value, "gradient_step"]), step, log_data)
+            self.write(f"{DataScope.UPDATE.value}/gradient_step", step, log_data)
             self.last_log_update_step = step

     def log_info_data(self, log_data: dict, step: int) -> None:
@@ -109,7 +108,7 @@ class BaseLogger(ABC):
             step - self.last_log_info_step >= self.info_interval
         ):  # TODO: move interval check to calling method
             log_data = self.prepare_dict_for_logging(log_data)
-            self.write("/".join([DataScope.INFO.value, "epoch"]), step, log_data)
+            self.write(f"{DataScope.INFO.value}/epoch", step, log_data)
             self.last_log_info_step = step

     @abstractmethod
@@ -139,10 +138,12 @@ class BaseLogger(ABC):
         :return: epoch, env_step, gradient_step.
         """

-    @staticmethod
     @abstractmethod
-    def restore_logged_data(log_path):
-        """Load the logged data from dist for post-processing."""
+    def restore_logged_data(self, log_path: str) -> dict[str, VALID_LOG_VALS_TYPE]:
+        """Load the logged data from disk for post-processing.
+
+        :return: a dict containing the logged data.
+        """


 class LazyLogger(BaseLogger):
@@ -166,57 +167,5 @@ class LazyLogger(BaseLogger):
     def restore_data(self) -> tuple[int, int, int]:
         return 0, 0, 0

-    def restore_logged_data(self):
-        return None
-
-
-class LoggerManager(BaseLogger):
-    """A container of loggers that holds more than one logger."""
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.loggers = []
-
-    def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None:
-        for logger in self.loggers:
-            data_copy = data.copy()
-            logger.write(step_type, step, data_copy)
-
-    def log_train_data(self, log_data: dict, step: int) -> None:
-        for logger in self.loggers:
-            logger.log_train_data(log_data, step)
-
-    def log_test_data(self, log_data: dict, step: int) -> None:
-        for logger in self.loggers:
-            logger.log_test_data(log_data, step)
-
-    def log_update_data(self, log_data: dict, step: int) -> None:
-        for logger in self.loggers:
-            logger.log_update_data(log_data, step)
-
-    def log_info_data(self, log_data: dict, step: int) -> None:
-        for logger in self.loggers:
-            logger.log_info_data(log_data, step)
-
-    def save_data(
-        self,
-        epoch: int,
-        env_step: int,
-        gradient_step: int,
-        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
-    ) -> None:
-        for logger in self.loggers:
-            logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
-
-    def restore_data(self) -> tuple[int, int, int]:
-        for logger in self.loggers:
-            epoch, env_step, gradient_step = logger.restore_data()
-
-            self.last_save_step = self.last_log_test_step = epoch
-            self.last_log_update_step = gradient_step
-            self.last_log_train_step = env_step
-
-        return epoch, env_step, gradient_step
-
-    def restore_logged_data(self, log_path):
-        return self.loggers[0].restore_logged_data(log_path)
+    def restore_logged_data(self, log_path: str) -> dict:
+        return {}
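For orientation, restore_logged_data is consumed by scope and then by metric name: RLiableExperimentResult.load_from_disk above reads data["test"]["returns_stat"]["mean"] and data["test"]["env_step"]. A hypothetical illustration of the nested shape a concrete logger might return; the metric keys are taken from that usage, everything else (values, array lengths) is illustrative:

import numpy as np

# hypothetical restored structure keyed by scope, then metric name
restored = {
    "test": {
        "env_step": np.array([10000, 20000, 30000]),
        "returns_stat": {"mean": np.array([12.0, 85.5, 140.2])},
    },
}

mean_returns = restored["test"]["returns_stat"]["mean"]
env_steps = restored["test"]["env_step"]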
@@ -1,14 +1,12 @@
 import os
 from collections import defaultdict
 from functools import partial
-from typing import Callable, Any
+from collections.abc import Callable
+from typing import Any

 import numpy as np
 import pandas as pd
 from matplotlib.figure import Figure

 from tianshou.utils import BaseLogger, logging
-from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE
+from tianshou.utils.logger.base import VALID_LOG_VALS, VALID_LOG_VALS_TYPE


 class PandasLogger(BaseLogger):
@@ -21,20 +19,47 @@ class PandasLogger(BaseLogger):
         info_interval: int = 1,
         exclude_arrays: bool = True,
     ) -> None:
-        super().__init__(train_interval, test_interval, update_interval, info_interval, exclude_arrays)
+        super().__init__(
+            train_interval,
+            test_interval,
+            update_interval,
+            info_interval,
+            exclude_arrays,
+        )
         self.log_path = log_dir
         self.csv_name = os.path.join(self.log_path, "log.csv")
         self.pkl_name = os.path.join(self.log_path, "log.pkl")
-        self.data = defaultdict(list)
+        self.data: dict[str, list] = defaultdict(list)
         self.last_save_step = -1

+    @staticmethod
+    def prepare_dict_for_logging(data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
+        """Removes invalid data types from the log data."""
+        filtered_dict = data.copy()
+
+        def filter_dict(data_dict: dict[str, Any]) -> None:
+            """Filter in place."""
+            for key, value in data_dict.items():
+                if isinstance(value, VALID_LOG_VALS):
+                    filter_dict(value)
+                else:
+                    filtered_dict.pop(key)
+
+        filter_dict(data)
+        return filtered_dict
+
     def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
         scope, step_name = step_type.split("/")
         data[step_name] = step
         self.data[scope].append(data)

-    def save_data(self, epoch: int, env_step: int, gradient_step: int,
-                  save_checkpoint_fn: Callable[[int, int, int], str] | None = None) -> None:
+    def save_data(
+        self,
+        epoch: int,
+        env_step: int,
+        gradient_step: int,
+        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
+    ) -> None:
         self.last_save_step = epoch
         # create and dump a dataframe
         for k, v in self.data.items():
@@ -45,7 +70,13 @@ class PandasLogger(BaseLogger):
     def restore_data(self) -> tuple[int, int, int]:
         for scope in ["train", "test", "update", "info"]:
             try:
-                self.data[scope].extend(list(pd.read_pickle(os.path.join(self.log_path, scope + "_log.pkl")).T.to_dict().values()))
+                self.data[scope].extend(
+                    list(
+                        pd.read_pickle(os.path.join(self.log_path, scope + "_log.pkl"))
+                        .T.to_dict()
+                        .values(),
+                    ),
+                )
             except FileNotFoundError:
                 logging.warning(f"Failed to restore {scope} data")

@@ -67,12 +98,11 @@ class PandasLogger(BaseLogger):

         return epoch, env_step, gradient_step

-    @staticmethod
-    def restore_logged_data(log_path):
+    def restore_logged_data(self, log_path: str) -> dict[str, Any]:
         data = {}

-        def merge_dicts(list_of_dicts):
-            result = defaultdict(list)
+        def merge_dicts(list_of_dicts: list[dict]) -> dict[str, Any]:
+            result: dict[str, Any] = defaultdict(list)
             for d in list_of_dicts:
                 for key, value in d.items():
                     if isinstance(value, dict):
@@ -85,22 +115,10 @@ class PandasLogger(BaseLogger):

         for scope in ["train", "test", "update", "info"]:
             try:
-                dict_list = list(pd.read_pickle(os.path.join(log_path, scope + "_log.pkl")).T.to_dict().values())
+                dict_list = list(
+                    pd.read_pickle(os.path.join(log_path, scope + "_log.pkl")).T.to_dict().values(),
+                )
                 data[scope] = merge_dicts(dict_list)
             except FileNotFoundError:
                 logging.warning(f"Failed to restore {scope} data")
         return data
-
-    def prepare_dict_for_logging(self, data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
-        """Filter out matplotlib figures from the data."""
-        filtered_dict = data.copy()
-
-        def filter_dict(d):
-            for key, value in d.items():
-                if isinstance(value, dict):
-                    filter_dict(value)
-                elif isinstance(value, Figure):
-                    filtered_dict.pop(key)
-
-        filter_dict(data)
-        return filtered_dict
@@ -1,4 +1,3 @@
-from collections import defaultdict
 from collections.abc import Callable
 from typing import Any

@@ -7,7 +6,7 @@ from matplotlib.figure import Figure
 from tensorboard.backend.event_processing import event_accumulator
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE, BaseLogger, VALID_LOG_VALS
+from tianshou.utils.logger.base import VALID_LOG_VALS, VALID_LOG_VALS_TYPE, BaseLogger
 from tianshou.utils.warning import deprecation


@@ -86,7 +85,7 @@ class TensorboardLogger(BaseLogger):
         scope, step_name = step_type.split("/")
         self.writer.add_scalar(step_type, step, global_step=step)
         for k, v in data.items():
-            scope_key = '/'.join([scope, k])
+            scope_key = f"{scope}/{k}"
             if isinstance(v, np.ndarray):
                 self.writer.add_histogram(scope_key, v, global_step=step, bins="auto")
             elif isinstance(v, Figure):
@@ -133,20 +132,19 @@ class TensorboardLogger(BaseLogger):

         return epoch, env_step, gradient_step

-    @staticmethod
-    def restore_logged_data(log_path):
+    def restore_logged_data(self, log_path: str) -> dict[str, Any]:
         ea = event_accumulator.EventAccumulator(log_path)
         ea.Reload()

-        def add_to_dict(dictionary, keys, value):
-            current_dict = dictionary
+        def add_to_dict(data_dict: dict[str, Any], keys: list[str], value: Any) -> None:
+            current_dict = data_dict
             for key in keys[:-1]:
                 current_dict = current_dict.setdefault(key, {})
             current_dict[keys[-1]] = value

-        data = {}
+        data: dict[str, Any] = {}
         for key in ea.scalars.Keys():
-            split_keys = key.split('/')
+            split_keys = key.split("/")
             add_to_dict(data, split_keys, np.array([s.value for s in ea.scalars.Items(key)]))

         return data
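To make the nesting concrete: each TensorBoard scalar tag is split on "/" and folded into a nested dict, so a tag like "test/returns_stat/mean" becomes reachable as data["test"]["returns_stat"]["mean"]. A small standalone run of the same helper logic as above; the tag names are only illustrative:

from typing import Any


def add_to_dict(data_dict: dict[str, Any], keys: list[str], value: Any) -> None:
    # walk/create intermediate dicts, then set the leaf value
    current_dict = data_dict
    for key in keys[:-1]:
        current_dict = current_dict.setdefault(key, {})
    current_dict[keys[-1]] = value


data: dict[str, Any] = {}
add_to_dict(data, "test/returns_stat/mean".split("/"), [12.0, 85.5])
add_to_dict(data, "test/env_step".split("/"), [10000, 20000])
# data == {"test": {"returns_stat": {"mean": [12.0, 85.5]}, "env_step": [10000, 20000]}}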
@@ -156,3 +156,7 @@ class WandbLogger(BaseLogger):
         except KeyError:
             env_step = 0
         return epoch, env_step, gradient_step
+
+    def restore_logged_data(self, log_path: str) -> dict:
+        assert self.tensorboard_logger is not None
+        return self.tensorboard_logger.restore_logged_data(log_path)