format, type check and small fixes

Maximilian Huettenrauch 2024-03-12 15:01:50 +01:00
parent f730782f29
commit d9a612a997
21 changed files with 225 additions and 195 deletions

View File

@@ -66,7 +66,13 @@ def main(
replay_buffer_save_only_last_obs=True,
)
env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
env_factory = AtariEnvFactory(
task,
sampling_config.train_seed,
sampling_config.test_seed,
frames_stack,
scale=scale_obs,
)
builder = (
DQNExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -65,7 +65,13 @@ def main(
replay_buffer_save_only_last_obs=True,
)
env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
env_factory = AtariEnvFactory(
task,
sampling_config.train_seed,
sampling_config.test_seed,
frames_stack,
scale=scale_obs,
)
experiment = (
IQNExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -65,7 +65,13 @@ def main(
replay_buffer_save_only_last_obs=True,
)
env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
env_factory = AtariEnvFactory(
task,
sampling_config.train_seed,
sampling_config.test_seed,
frames_stack,
scale=scale_obs,
)
builder = (
DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -54,7 +54,12 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
env_factory = MujocoEnvFactory(
task,
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=True,
)
experiment = (
A2CExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -51,7 +51,12 @@ def main(
start_timesteps_random=True,
)
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
env_factory = MujocoEnvFactory(
task,
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=False,
)
experiment = (
DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -56,7 +56,12 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
env_factory = MujocoEnvFactory(
task,
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=True,
)
experiment = (
NPGExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -61,7 +61,12 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
env_factory = MujocoEnvFactory(
task,
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=True,
)
experiment = (
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -2,13 +2,12 @@
import os
from collections.abc import Sequence
from functools import partial
from typing import Literal
import torch
from examples.mujoco.mujoco_env import MujocoEnvFactory
from examples.mujoco.tools import eval_results, RLiableExperimentResult
from examples.mujoco.tools import RLiableExperimentResult, eval_results
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
ExperimentConfig,
@@ -19,7 +18,6 @@ from tianshou.highlevel.params.dist_fn import (
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag
@@ -65,7 +63,12 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
env_factory = MujocoEnvFactory(
task,
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=True,
)
experiments = (
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -102,13 +105,13 @@ def main(
def eval_experiments(log_dir: str):
results = RLiableExperimentResult.load_from_disk(log_dir, 'PPO', None)
results = RLiableExperimentResult.load_from_disk(log_dir, "PPO", None)
eval_results(results)
if __name__ == "__main__":
# logging.run_cli(main)
experiment_config = ExperimentConfig(watch=False)
log_dir = logging.run_main(partial(main, experiment_config, epoch=2))
# log_dir = <path/to/exp>
# experiment_config = ExperimentConfig(watch=False)
# log_dir = logging.run_main(partial(main, experiment_config, epoch=2))
log_dir = "log/Ant-v4/ppo/20240312-114646"
eval_experiments(log_dir)

View File

@@ -57,7 +57,12 @@ def main(
start_timesteps_random=True,
)
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
env_factory = MujocoEnvFactory(
task,
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=False,
)
experiment = (
REDQExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -49,7 +49,12 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
env_factory = MujocoEnvFactory(
task,
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=True,
)
experiment = (
PGExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -52,7 +52,12 @@ def main(
start_timesteps_random=True,
)
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
env_factory = MujocoEnvFactory(
task,
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=False,
)
experiment = (
SACExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -58,7 +58,12 @@ def main(
start_timesteps_random=True,
)
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
env_factory = MujocoEnvFactory(
task,
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=False,
)
experiment = (
TD3ExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -58,7 +58,12 @@ def main(
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
env_factory = MujocoEnvFactory(
task,
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=True,
)
experiment = (
TRPOExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -5,7 +5,7 @@ import csv
import os
import re
from collections import defaultdict
from dataclasses import dataclass, asdict
from dataclasses import asdict, dataclass
import numpy as np
import tqdm
@@ -34,26 +34,33 @@ class RLiableExperimentResult:
test_episode_returns = []
for entry in os.scandir(exp_dir):
if entry.name.startswith('.'):
if entry.name.startswith(".") or not entry.is_dir():
continue
exp = Experiment.from_directory(entry.path)
logger = exp.logger_factory.create_logger(entry.path, entry.name, None, asdict(exp.config))
logger = exp.logger_factory.create_logger(
entry.path,
entry.name,
None,
asdict(exp.config),
)
data = logger.restore_logged_data(entry.path)
test_data = data['test']
test_data = data["test"]
test_episode_returns.append(test_data['returns_stat']['mean'])
env_step = test_data['env_step']
test_episode_returns.append(test_data["returns_stat"]["mean"])
env_step = test_data["env_step"]
if score_thresholds is None:
score_thresholds = np.linspace(0.0, np.max(test_episode_returns), 101)
return RLiableExperimentResult(algorithms=[algo_name],
score_dict={algo_name: np.array(test_episode_returns)},
env_steps=np.array(env_step),
score_thresholds=score_thresholds,
exp_dir=exp_dir)
return RLiableExperimentResult(
algorithms=[algo_name],
score_dict={algo_name: np.array(test_episode_returns)},
env_steps=np.array(env_step),
score_thresholds=score_thresholds,
exp_dir=exp_dir,
)
def eval_results(results: RLiableExperimentResult):
@@ -64,31 +71,44 @@ def eval_results(results: RLiableExperimentResult):
from rliable import plot_utils
iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
iqm_scores, iqm_cis = rly.get_interval_estimates(
results.score_dict, iqm, reps=50000)
iqm_scores, iqm_cis = rly.get_interval_estimates(results.score_dict, iqm, reps=50000)
# Plot IQM sample efficiency curve
fig, ax = plt.subplots(ncols=1, figsize=(7, 5))
plot_utils.plot_sample_efficiency_curve(
results.env_steps, iqm_scores, iqm_cis, algorithms=results.algorithms,
xlabel=r'Number of env steps',
ylabel='IQM episode return',
ax=ax)
plt.savefig(os.path.join(results.exp_dir, 'iqm_sample_efficiency_curve.png'))
results.env_steps,
iqm_scores,
iqm_cis,
algorithms=results.algorithms,
xlabel=r"Number of env steps",
ylabel="IQM episode return",
ax=ax,
)
plt.savefig(os.path.join(results.exp_dir, "iqm_sample_efficiency_curve.png"))
final_score_dict = {algo: returns[:, [-1]] for algo, returns in results.score_dict.items()}
score_distributions, score_distributions_cis = rly.create_performance_profile(
final_score_dict, results.score_thresholds)
final_score_dict,
results.score_thresholds,
)
# Plot score distributions
fig, ax = plt.subplots(ncols=1, figsize=(7, 5))
plot_utils.plot_performance_profiles(
score_distributions, results.score_thresholds,
score_distributions,
results.score_thresholds,
performance_profile_cis=score_distributions_cis,
colors=dict(zip(results.algorithms, sns.color_palette('colorblind'))),
xlabel=r'Episode return $(\tau)$',
ax=ax)
plt.savefig(os.path.join(results.exp_dir, 'performance_profile.png'))
colors=dict(
zip(
results.algorithms,
sns.color_palette("colorblind", n_colors=len(results.algorithms)),
strict=True,
),
),
xlabel=r"Episode return $(\tau)$",
ax=ax,
)
plt.savefig(os.path.join(results.exp_dir, "performance_profile.png"))
def find_all_files(root_dir, pattern):
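Taken together, the helpers refactored above are used as in the following minimal sketch (mirroring the call sites in the PPO example earlier in this commit; the log directory is the one shown there and stands in for any directory of seeded runs):

from examples.mujoco.tools import RLiableExperimentResult, eval_results

# Load all seeded runs of one algorithm from an experiment directory
results = RLiableExperimentResult.load_from_disk("log/Ant-v4/ppo/20240312-114646", "PPO", None)
# Saves iqm_sample_efficiency_curve.png and performance_profile.png into that directory
eval_results(results)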

View File

@@ -6,9 +6,19 @@ from tianshou.highlevel.env import (
class DiscreteTestEnvFactory(EnvFactoryRegistered):
def __init__(self) -> None:
super().__init__(task="CartPole-v0", train_seed=42, test_seed=1337, venv_type=VectorEnvType.DUMMY)
super().__init__(
task="CartPole-v0",
train_seed=42,
test_seed=1337,
venv_type=VectorEnvType.DUMMY,
)
class ContinuousTestEnvFactory(EnvFactoryRegistered):
def __init__(self) -> None:
super().__init__(task="Pendulum-v1", train_seed=42, test_seed=1337, venv_type=VectorEnvType.DUMMY)
super().__init__(
task="Pendulum-v1",
train_seed=42,
test_seed=1337,
venv_type=VectorEnvType.DUMMY,
)

View File

@@ -3,7 +3,7 @@ import pickle
from abc import abstractmethod
from collections.abc import Sequence
from copy import copy
from dataclasses import dataclass, asdict
from dataclasses import asdict, dataclass
from pprint import pformat
from typing import Self
@@ -87,9 +87,7 @@ from tianshou.utils.string import ToStringMixin
log = logging.getLogger(__name__)
def shortener(input_string: str | None = None,
length: int = 1
):
def shortener(input_string: str | None = None, length: int = 1) -> str:
"""Shorten the input string by keeping only the first `length` characters of each word.
If the input string is None or empty, return "default".
@@ -367,19 +365,19 @@ class ExperimentBuilder:
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
@property
def experiment_config(self):
def experiment_config(self) -> ExperimentConfig:
return self._config
@experiment_config.setter
def experiment_config(self, experiment_config: ExperimentConfig):
def experiment_config(self, experiment_config: ExperimentConfig) -> None:
self._config = experiment_config
@property
def sampling_config(self):
def sampling_config(self) -> SamplingConfig:
return self._sampling_config
@sampling_config.setter
def sampling_config(self, sampling_config: SamplingConfig):
def sampling_config(self, sampling_config: SamplingConfig) -> None:
self._sampling_config = sampling_config
def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
@@ -491,7 +489,6 @@ class ExperimentBuilder:
The keys of the dict are the experiment names, which are derived from the seeds used in the experiments.
"""
configured_experiment_config = copy(self.experiment_config)
configured_experiment_seed = configured_experiment_config.seed
configured_sampling_config = copy(self.sampling_config)
@@ -512,9 +509,13 @@ class ExperimentBuilder:
self.sampling_config = SamplingConfig(**new_sampling_config_dict)
exp = self.build()
full_name = ",".join([f"experiment_seed={exp.config.seed}",
f"train_seed={exp.sampling_config.train_seed}",
f"test_seed={exp.sampling_config.test_seed}"])
full_name = ",".join(
[
f"experiment_seed={exp.config.seed}",
f"train_seed={exp.sampling_config.train_seed}",
f"test_seed={exp.sampling_config.test_seed}",
],
)
experiment_name = shortener(full_name, 4)
seeded_experiments[experiment_name] = exp

View File

@@ -5,7 +5,6 @@ from typing import Literal, TypeAlias
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger
from tianshou.utils.logger.base import LoggerManager
from tianshou.utils.logger.pandas_logger import PandasLogger
from tianshou.utils.string import ToStringMixin
@@ -78,38 +77,3 @@ class LoggerFactoryDefault(LoggerFactory):
return TensorboardLogger(writer)
case _:
raise ValueError(f"Unknown logger type '{self.logger_type}'")
class LoggerManagerFactory(LoggerFactory):
def __init__(
self,
logger_types: list[Literal["tensorboard", "wandb", "pandas"]] = ["tensorboard", "pandas"],
wandb_project: str | None = None,
):
self.logger_types = logger_types
self.wandb_project = wandb_project
self.factories = {
"pandas": LoggerFactoryDefault(logger_type="pandas"),
"wandb": LoggerFactoryDefault(logger_type="wandb", wandb_project=wandb_project),
"tensorboard": LoggerFactoryDefault(logger_type="tensorboard"),
}
def create_logger(
self,
log_dir: str,
experiment_name: str,
run_id: str | None,
config_dict: dict,
) -> TLogger:
logger_manager = LoggerManager()
for logger_type in self.logger_types:
logger_manager.loggers.append(
self.factories[logger_type].create_logger(
log_dir,
experiment_name,
run_id,
config_dict,
)
)
return logger_manager

View File

@@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
from collections.abc import Callable
from enum import Enum
from numbers import Number
from typing import Any
import numpy as np
@@ -60,7 +59,7 @@ class BaseLogger(ABC):
"""
@staticmethod
def prepare_dict_for_logging(log_data: dict) -> dict:
def prepare_dict_for_logging(log_data: dict) -> dict[str, VALID_LOG_VALS_TYPE]:
return log_data
def log_train_data(self, log_data: dict, step: int) -> None:
@@ -72,7 +71,7 @@
# TODO: move interval check to calling method
if step - self.last_log_train_step >= self.train_interval:
log_data = self.prepare_dict_for_logging(log_data)
self.write("/".join([DataScope.TRAIN.value, "env_step"]), step, log_data)
self.write(f"{DataScope.TRAIN.value}/env_step", step, log_data)
self.last_log_train_step = step
def log_test_data(self, log_data: dict, step: int) -> None:
@@ -84,7 +83,7 @@
# TODO: move interval check to calling method (stupid because log_test_data is only called from function in utils.py, not from BaseTrainer)
if step - self.last_log_test_step >= self.test_interval:
log_data = self.prepare_dict_for_logging(log_data)
self.write("/".join([DataScope.TEST.value, "env_step"]), step, log_data)
self.write(f"{DataScope.TEST.value}/env_step", step, log_data)
self.last_log_test_step = step
def log_update_data(self, log_data: dict, step: int) -> None:
@@ -96,7 +95,7 @@
# TODO: move interval check to calling method
if step - self.last_log_update_step >= self.update_interval:
log_data = self.prepare_dict_for_logging(log_data)
self.write("/".join([DataScope.UPDATE.value, "gradient_step"]), step, log_data)
self.write(f"{DataScope.UPDATE.value}/gradient_step", step, log_data)
self.last_log_update_step = step
def log_info_data(self, log_data: dict, step: int) -> None:
@@ -109,7 +108,7 @@
step - self.last_log_info_step >= self.info_interval
): # TODO: move interval check to calling method
log_data = self.prepare_dict_for_logging(log_data)
self.write("/".join([DataScope.INFO.value, "epoch"]), step, log_data)
self.write(f"{DataScope.INFO.value}/epoch", step, log_data)
self.last_log_info_step = step
@abstractmethod
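The switch from "/".join(...) to f-strings in the four logging methods above is behavior-preserving. A minimal self-contained check (the enum values here are assumed to match the "train"/"test"/"update"/"info" scopes used elsewhere in this commit):

from enum import Enum

class DataScope(Enum):  # assumed values, mirroring the scopes above
    TRAIN = "train"
    TEST = "test"
    UPDATE = "update"
    INFO = "info"

# Both forms produce the same write key, e.g. "train/env_step"
assert f"{DataScope.TRAIN.value}/env_step" == "/".join([DataScope.TRAIN.value, "env_step"])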
@@ -139,10 +138,12 @@
:return: epoch, env_step, gradient_step.
"""
@staticmethod
@abstractmethod
def restore_logged_data(log_path):
"""Load the logged data from dist for post-processing."""
def restore_logged_data(self, log_path: str) -> dict[str, VALID_LOG_VALS_TYPE]:
"""Load the logged data from disk for post-processing.
:return: a dict containing the logged data.
"""
class LazyLogger(BaseLogger):
@@ -166,57 +167,5 @@ class LazyLogger(BaseLogger):
def restore_data(self) -> tuple[int, int, int]:
return 0, 0, 0
def restore_logged_data(self):
return None
class LoggerManager(BaseLogger):
"""A container of loggers that holds more than one logger."""
def __init__(self) -> None:
super().__init__()
self.loggers = []
def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None:
for logger in self.loggers:
data_copy = data.copy()
logger.write(step_type, step, data_copy)
def log_train_data(self, log_data: dict, step: int) -> None:
for logger in self.loggers:
logger.log_train_data(log_data, step)
def log_test_data(self, log_data: dict, step: int) -> None:
for logger in self.loggers:
logger.log_test_data(log_data, step)
def log_update_data(self, log_data: dict, step: int) -> None:
for logger in self.loggers:
logger.log_update_data(log_data, step)
def log_info_data(self, log_data: dict, step: int) -> None:
for logger in self.loggers:
logger.log_info_data(log_data, step)
def save_data(
self,
epoch: int,
env_step: int,
gradient_step: int,
save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
) -> None:
for logger in self.loggers:
logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
def restore_data(self) -> tuple[int, int, int]:
for logger in self.loggers:
epoch, env_step, gradient_step = logger.restore_data()
self.last_save_step = self.last_log_test_step = epoch
self.last_log_update_step = gradient_step
self.last_log_train_step = env_step
return epoch, env_step, gradient_step
def restore_logged_data(self, log_path):
return self.loggers[0].restore_logged_data(log_path)
def restore_logged_data(self, log_path: str) -> dict:
return {}

View File

@@ -1,14 +1,12 @@
import os
from collections import defaultdict
from functools import partial
from typing import Callable, Any
from collections.abc import Callable
from typing import Any
import numpy as np
import pandas as pd
from matplotlib.figure import Figure
from tianshou.utils import BaseLogger, logging
from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE
from tianshou.utils.logger.base import VALID_LOG_VALS, VALID_LOG_VALS_TYPE
class PandasLogger(BaseLogger):
@@ -21,20 +19,47 @@ class PandasLogger(BaseLogger):
info_interval: int = 1,
exclude_arrays: bool = True,
) -> None:
super().__init__(train_interval, test_interval, update_interval, info_interval, exclude_arrays)
super().__init__(
train_interval,
test_interval,
update_interval,
info_interval,
exclude_arrays,
)
self.log_path = log_dir
self.csv_name = os.path.join(self.log_path, "log.csv")
self.pkl_name = os.path.join(self.log_path, "log.pkl")
self.data = defaultdict(list)
self.data: dict[str, list] = defaultdict(list)
self.last_save_step = -1
@staticmethod
def prepare_dict_for_logging(data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
"""Removes invalid data types from the log data."""
filtered_dict = data.copy()
def filter_dict(data_dict: dict[str, Any]) -> None:
"""Filter in place."""
for key, value in data_dict.items():
if isinstance(value, dict):
filter_dict(value)
elif not isinstance(value, VALID_LOG_VALS):
filtered_dict.pop(key, None)
filter_dict(data)
return filtered_dict
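A self-contained illustration of this filtering on a nested log dict (VALID_TYPES is an illustrative stand-in for tianshou's VALID_LOG_VALS tuple):

import numpy as np

VALID_TYPES = (int, float, np.number, np.ndarray)  # assumed stand-in for VALID_LOG_VALS

def filter_log_dict(data: dict) -> dict:
    """Return a copy keeping loggable leaves; recurse into nested dicts."""
    out: dict = {}
    for key, value in data.items():
        if isinstance(value, dict):
            out[key] = filter_log_dict(value)
        elif isinstance(value, VALID_TYPES):
            out[key] = value
        # anything else (e.g. matplotlib figures, strings) is dropped
    return out

filter_log_dict({"returns_stat": {"mean": 1.0, "fig": object()}, "note": "drop me"})
# -> {"returns_stat": {"mean": 1.0}}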
def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
scope, step_name = step_type.split("/")
data[step_name] = step
self.data[scope].append(data)
def save_data(self, epoch: int, env_step: int, gradient_step: int,
save_checkpoint_fn: Callable[[int, int, int], str] | None = None) -> None:
def save_data(
self,
epoch: int,
env_step: int,
gradient_step: int,
save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
) -> None:
self.last_save_step = epoch
# create and dump a dataframe
for k, v in self.data.items():
@@ -45,7 +70,13 @@ class PandasLogger(BaseLogger):
def restore_data(self) -> tuple[int, int, int]:
for scope in ["train", "test", "update", "info"]:
try:
self.data[scope].extend(list(pd.read_pickle(os.path.join(self.log_path, scope + "_log.pkl")).T.to_dict().values()))
self.data[scope].extend(
list(
pd.read_pickle(os.path.join(self.log_path, scope + "_log.pkl"))
.T.to_dict()
.values(),
),
)
except FileNotFoundError:
logging.warning(f"Failed to restore {scope} data")
@@ -67,12 +98,11 @@ class PandasLogger(BaseLogger):
return epoch, env_step, gradient_step
@staticmethod
def restore_logged_data(log_path):
def restore_logged_data(self, log_path: str) -> dict[str, Any]:
data = {}
def merge_dicts(list_of_dicts):
result = defaultdict(list)
def merge_dicts(list_of_dicts: list[dict]) -> dict[str, Any]:
result: dict[str, Any] = defaultdict(list)
for d in list_of_dicts:
for key, value in d.items():
if isinstance(value, dict):
@@ -85,22 +115,10 @@
for scope in ["train", "test", "update", "info"]:
try:
dict_list = list(pd.read_pickle(os.path.join(log_path, scope + "_log.pkl")).T.to_dict().values())
dict_list = list(
pd.read_pickle(os.path.join(log_path, scope + "_log.pkl")).T.to_dict().values(),
)
data[scope] = merge_dicts(dict_list)
except FileNotFoundError:
logging.warning(f"Failed to restore {scope} data")
return data
def prepare_dict_for_logging(self, data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
"""Filter out matplotlib figures from the data."""
filtered_dict = data.copy()
def filter_dict(d):
for key, value in d.items():
if isinstance(value, dict):
filter_dict(value)
elif isinstance(value, Figure):
filtered_dict.pop(key)
filter_dict(data)
return filtered_dict
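For reference, a self-contained sketch of the merging step in restore_logged_data above, turning per-step row dicts back into column lists (function name and data are illustrative):

from collections import defaultdict
from typing import Any

def merge_rows(rows: list[dict]) -> dict[str, Any]:
    """Group values key-wise; recurse when the values are themselves dicts."""
    grouped: dict[str, list] = defaultdict(list)
    for row in rows:
        for key, value in row.items():
            grouped[key].append(value)
    return {
        key: merge_rows(values) if isinstance(values[0], dict) else values
        for key, values in grouped.items()
    }

merge_rows(
    [
        {"env_step": 0, "returns_stat": {"mean": 1.0}},
        {"env_step": 100, "returns_stat": {"mean": 2.5}},
    ],
)
# -> {"env_step": [0, 100], "returns_stat": {"mean": [1.0, 2.5]}}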

View File

@@ -1,4 +1,3 @@
from collections import defaultdict
from collections.abc import Callable
from typing import Any
@@ -7,7 +6,7 @@ from matplotlib.figure import Figure
from tensorboard.backend.event_processing import event_accumulator
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE, BaseLogger, VALID_LOG_VALS
from tianshou.utils.logger.base import VALID_LOG_VALS, VALID_LOG_VALS_TYPE, BaseLogger
from tianshou.utils.warning import deprecation
@@ -86,7 +85,7 @@ class TensorboardLogger(BaseLogger):
scope, step_name = step_type.split("/")
self.writer.add_scalar(step_type, step, global_step=step)
for k, v in data.items():
scope_key = '/'.join([scope, k])
scope_key = f"{scope}/{k}"
if isinstance(v, np.ndarray):
self.writer.add_histogram(scope_key, v, global_step=step, bins="auto")
elif isinstance(v, Figure):
@@ -133,20 +132,19 @@ class TensorboardLogger(BaseLogger):
return epoch, env_step, gradient_step
@staticmethod
def restore_logged_data(log_path):
def restore_logged_data(self, log_path: str) -> dict[str, Any]:
ea = event_accumulator.EventAccumulator(log_path)
ea.Reload()
def add_to_dict(dictionary, keys, value):
current_dict = dictionary
def add_to_dict(data_dict: dict[str, Any], keys: list[str], value: Any) -> None:
current_dict = data_dict
for key in keys[:-1]:
current_dict = current_dict.setdefault(key, {})
current_dict[keys[-1]] = value
data = {}
data: dict[str, Any] = {}
for key in ea.scalars.Keys():
split_keys = key.split('/')
split_keys = key.split("/")
add_to_dict(data, split_keys, np.array([s.value for s in ea.scalars.Items(key)]))
return data
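add_to_dict rebuilds the nesting from TensorBoard's slash-separated scalar tags; for example (hypothetical tag and values):

import numpy as np
from typing import Any

def add_to_dict(data_dict: dict[str, Any], keys: list[str], value: Any) -> None:
    # Walk/create intermediate dicts, then assign the value at the leaf key
    current = data_dict
    for key in keys[:-1]:
        current = current.setdefault(key, {})
    current[keys[-1]] = value

data: dict[str, Any] = {}
add_to_dict(data, "test/returns_stat/mean".split("/"), np.array([1.0, 2.5]))
# data == {"test": {"returns_stat": {"mean": np.array([1.0, 2.5])}}}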

View File

@@ -156,3 +156,7 @@ class WandbLogger(BaseLogger):
except KeyError:
env_step = 0
return epoch, env_step, gradient_step
def restore_logged_data(self, log_path: str) -> dict:
assert self.tensorboard_logger is not None
return self.tensorboard_logger.restore_logged_data(log_path)