format, type check and small fixes

Maximilian Huettenrauch 2024-03-12 15:01:50 +01:00
parent f730782f29
commit d9a612a997
21 changed files with 225 additions and 195 deletions

View File

@@ -66,7 +66,13 @@ def main(
         replay_buffer_save_only_last_obs=True,
     )
-    env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
+    env_factory = AtariEnvFactory(
+        task,
+        sampling_config.train_seed,
+        sampling_config.test_seed,
+        frames_stack,
+        scale=scale_obs,
+    )
     builder = (
         DQNExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -65,7 +65,13 @@ def main(
         replay_buffer_save_only_last_obs=True,
     )
-    env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
+    env_factory = AtariEnvFactory(
+        task,
+        sampling_config.train_seed,
+        sampling_config.test_seed,
+        frames_stack,
+        scale=scale_obs,
+    )
     experiment = (
         IQNExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -65,7 +65,13 @@ def main(
         replay_buffer_save_only_last_obs=True,
     )
-    env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
+    env_factory = AtariEnvFactory(
+        task,
+        sampling_config.train_seed,
+        sampling_config.test_seed,
+        frames_stack,
+        scale=scale_obs,
+    )
     builder = (
         DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -54,7 +54,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )
     experiment = (
         A2CExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -51,7 +51,12 @@ def main(
         start_timesteps_random=True,
     )
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=False,
+    )
     experiment = (
         DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -56,7 +56,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )
     experiment = (
         NPGExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -61,7 +61,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )
     experiment = (
         PPOExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -2,13 +2,12 @@
 import os
 from collections.abc import Sequence
-from functools import partial
 from typing import Literal

 import torch

 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from examples.mujoco.tools import eval_results, RLiableExperimentResult
+from examples.mujoco.tools import RLiableExperimentResult, eval_results
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     ExperimentConfig,
@@ -19,7 +18,6 @@ from tianshou.highlevel.params.dist_fn import (
 )
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
 from tianshou.highlevel.params.policy_params import PPOParams
-from tianshou.utils import logging
 from tianshou.utils.logging import datetime_tag
@@ -65,7 +63,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )
     experiments = (
         PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -102,13 +105,13 @@ def main(
 def eval_experiments(log_dir: str):
-    results = RLiableExperimentResult.load_from_disk(log_dir, 'PPO', None)
+    results = RLiableExperimentResult.load_from_disk(log_dir, "PPO", None)
     eval_results(results)


 if __name__ == "__main__":
     # logging.run_cli(main)
-    experiment_config = ExperimentConfig(watch=False)
-    log_dir = logging.run_main(partial(main, experiment_config, epoch=2))
-    # log_dir = <path/to/exp>
+    # experiment_config = ExperimentConfig(watch=False)
+    # log_dir = logging.run_main(partial(main, experiment_config, epoch=2))
+    log_dir = "log/Ant-v4/ppo/20240312-114646"
     eval_experiments(log_dir)

View File

@@ -57,7 +57,12 @@ def main(
         start_timesteps_random=True,
     )
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=False,
+    )
     experiment = (
         REDQExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -49,7 +49,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )
     experiment = (
         PGExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -52,7 +52,12 @@ def main(
         start_timesteps_random=True,
     )
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=False,
+    )
     experiment = (
         SACExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -58,7 +58,12 @@ def main(
         start_timesteps_random=True,
     )
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=False,
+    )
     experiment = (
         TD3ExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -58,7 +58,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )
     experiment = (
         TRPOExperimentBuilder(env_factory, experiment_config, sampling_config)

View File

@@ -5,7 +5,7 @@ import csv
 import os
 import re
 from collections import defaultdict
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass

 import numpy as np
 import tqdm
@@ -34,26 +34,33 @@ class RLiableExperimentResult:
         test_episode_returns = []
         for entry in os.scandir(exp_dir):
-            if entry.name.startswith('.'):
+            if entry.name.startswith(".") or not entry.is_dir():
                 continue
             exp = Experiment.from_directory(entry.path)
-            logger = exp.logger_factory.create_logger(entry.path, entry.name, None, asdict(exp.config))
+            logger = exp.logger_factory.create_logger(
+                entry.path,
+                entry.name,
+                None,
+                asdict(exp.config),
+            )
             data = logger.restore_logged_data(entry.path)
-            test_data = data['test']
-            test_episode_returns.append(test_data['returns_stat']['mean'])
-            env_step = test_data['env_step']
+            test_data = data["test"]
+            test_episode_returns.append(test_data["returns_stat"]["mean"])
+            env_step = test_data["env_step"]

         if score_thresholds is None:
             score_thresholds = np.linspace(0.0, np.max(test_episode_returns), 101)

-        return RLiableExperimentResult(algorithms=[algo_name],
-                                       score_dict={algo_name: np.array(test_episode_returns)},
-                                       env_steps=np.array(env_step),
-                                       score_thresholds=score_thresholds,
-                                       exp_dir=exp_dir)
+        return RLiableExperimentResult(
+            algorithms=[algo_name],
+            score_dict={algo_name: np.array(test_episode_returns)},
+            env_steps=np.array(env_step),
+            score_thresholds=score_thresholds,
+            exp_dir=exp_dir,
+        )


 def eval_results(results: RLiableExperimentResult):
@@ -64,31 +71,44 @@ def eval_results(results: RLiableExperimentResult):
     from rliable import plot_utils

     iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
-    iqm_scores, iqm_cis = rly.get_interval_estimates(
-        results.score_dict, iqm, reps=50000)
+    iqm_scores, iqm_cis = rly.get_interval_estimates(results.score_dict, iqm, reps=50000)

     # Plot IQM sample efficiency curve
     fig, ax = plt.subplots(ncols=1, figsize=(7, 5))
     plot_utils.plot_sample_efficiency_curve(
-        results.env_steps, iqm_scores, iqm_cis, algorithms=results.algorithms,
-        xlabel=r'Number of env steps',
-        ylabel='IQM episode return',
-        ax=ax)
-    plt.savefig(os.path.join(results.exp_dir, 'iqm_sample_efficiency_curve.png'))
+        results.env_steps,
+        iqm_scores,
+        iqm_cis,
+        algorithms=results.algorithms,
+        xlabel=r"Number of env steps",
+        ylabel="IQM episode return",
+        ax=ax,
+    )
+    plt.savefig(os.path.join(results.exp_dir, "iqm_sample_efficiency_curve.png"))

     final_score_dict = {algo: returns[:, [-1]] for algo, returns in results.score_dict.items()}
     score_distributions, score_distributions_cis = rly.create_performance_profile(
-        final_score_dict, results.score_thresholds)
+        final_score_dict,
+        results.score_thresholds,
+    )

     # Plot score distributions
     fig, ax = plt.subplots(ncols=1, figsize=(7, 5))
     plot_utils.plot_performance_profiles(
-        score_distributions, results.score_thresholds,
+        score_distributions,
+        results.score_thresholds,
         performance_profile_cis=score_distributions_cis,
-        colors=dict(zip(results.algorithms, sns.color_palette('colorblind'))),
-        xlabel=r'Episode return $(\tau)$',
-        ax=ax)
-    plt.savefig(os.path.join(results.exp_dir, 'performance_profile.png'))
+        colors=dict(
+            zip(
+                results.algorithms,
+                sns.color_palette("colorblind", n_colors=len(results.algorithms)),
+                strict=True,
+            ),
+        ),
+        xlabel=r"Episode return $(\tau)$",
+        ax=ax,
+    )
+    plt.savefig(os.path.join(results.exp_dir, "performance_profile.png"))


 def find_all_files(root_dir, pattern):

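Note: the IQM aggregation used in eval_results above is scipy's trimmed mean with the top and bottom 25% of runs cut off. A minimal sanity-check sketch with made-up scores (not from this commit):

    import numpy as np
    import scipy.stats as sst

    # One column of final scores across four runs; values are illustrative.
    scores = np.array([[0.0], [1.0], [2.0], [100.0]])

    # Trimming 25% from each end drops the 0.0 and 100.0 outliers,
    # leaving mean(1.0, 2.0) == 1.5.
    iqm = sst.trim_mean(scores, proportiontocut=0.25, axis=0)
    print(iqm)  # [1.5]
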
View File

@@ -6,9 +6,19 @@ from tianshou.highlevel.env import (

 class DiscreteTestEnvFactory(EnvFactoryRegistered):
     def __init__(self) -> None:
-        super().__init__(task="CartPole-v0", train_seed=42, test_seed=1337, venv_type=VectorEnvType.DUMMY)
+        super().__init__(
+            task="CartPole-v0",
+            train_seed=42,
+            test_seed=1337,
+            venv_type=VectorEnvType.DUMMY,
+        )


 class ContinuousTestEnvFactory(EnvFactoryRegistered):
     def __init__(self) -> None:
-        super().__init__(task="Pendulum-v1", train_seed=42, test_seed=1337, venv_type=VectorEnvType.DUMMY)
+        super().__init__(
+            task="Pendulum-v1",
+            train_seed=42,
+            test_seed=1337,
+            venv_type=VectorEnvType.DUMMY,
+        )

View File

@@ -3,7 +3,7 @@ import pickle
 from abc import abstractmethod
 from collections.abc import Sequence
 from copy import copy
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass
 from pprint import pformat
 from typing import Self
@@ -87,9 +87,7 @@ from tianshou.utils.string import ToStringMixin
 log = logging.getLogger(__name__)


-def shortener(input_string: str | None = None,
-              length: int = 1
-              ):
+def shortener(input_string: str | None = None, length: int = 1) -> str:
     """Shorten the input string by keeping only the first `length` characters of each word.

     If the input string is None or empty, return "default".
@@ -367,19 +365,19 @@ class ExperimentBuilder:
         self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()

     @property
-    def experiment_config(self):
+    def experiment_config(self) -> ExperimentConfig:
         return self._config

     @experiment_config.setter
-    def experiment_config(self, experiment_config: ExperimentConfig):
+    def experiment_config(self, experiment_config: ExperimentConfig) -> None:
         self._config = experiment_config

     @property
-    def sampling_config(self):
+    def sampling_config(self) -> SamplingConfig:
         return self._sampling_config

     @sampling_config.setter
-    def sampling_config(self, sampling_config: SamplingConfig):
+    def sampling_config(self, sampling_config: SamplingConfig) -> None:
         self._sampling_config = sampling_config

     def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
@@ -491,7 +489,6 @@ class ExperimentBuilder:
         The keys of the dict are the experiment names, which are derived from the seeds used in the experiments.
         """
         configured_experiment_config = copy(self.experiment_config)
-
         configured_experiment_seed = configured_experiment_config.seed
         configured_sampling_config = copy(self.sampling_config)
@@ -512,9 +509,13 @@ class ExperimentBuilder:
             self.sampling_config = SamplingConfig(**new_sampling_config_dict)
             exp = self.build()
-            full_name = ",".join([f"experiment_seed={exp.config.seed}",
-                                  f"train_seed={exp.sampling_config.train_seed}",
-                                  f"test_seed={exp.sampling_config.test_seed}"])
+            full_name = ",".join(
+                [
+                    f"experiment_seed={exp.config.seed}",
+                    f"train_seed={exp.sampling_config.train_seed}",
+                    f"test_seed={exp.sampling_config.test_seed}",
+                ],
+            )
             experiment_name = shortener(full_name, 4)
             seeded_experiments[experiment_name] = exp

View File

@@ -5,7 +5,6 @@ from typing import Literal, TypeAlias
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger
-from tianshou.utils.logger.base import LoggerManager
 from tianshou.utils.logger.pandas_logger import PandasLogger
 from tianshou.utils.string import ToStringMixin
@@ -78,38 +77,3 @@ class LoggerFactoryDefault(LoggerFactory):
                 return TensorboardLogger(writer)
             case _:
                 raise ValueError(f"Unknown logger type '{self.logger_type}'")
-
-
-class LoggerManagerFactory(LoggerFactory):
-    def __init__(
-        self,
-        logger_types: list[Literal["tensorboard", "wandb", "pandas"]] = ["tensorboard", "pandas"],
-        wandb_project: str | None = None,
-    ):
-        self.logger_types = logger_types
-        self.wandb_project = wandb_project
-        self.factories = {
-            "pandas": LoggerFactoryDefault(logger_type="pandas"),
-            "wandb": LoggerFactoryDefault(logger_type="wandb", wandb_project=wandb_project),
-            "tensorboard": LoggerFactoryDefault(logger_type="tensorboard"),
-        }
-
-    def create_logger(
-        self,
-        log_dir: str,
-        experiment_name: str,
-        run_id: str | None,
-        config_dict: dict,
-    ) -> TLogger:
-        logger_manager = LoggerManager()
-        for logger_type in self.logger_types:
-            logger_manager.loggers.append(
-                self.factories[logger_type].create_logger(
-                    log_dir,
-                    experiment_name,
-                    run_id,
-                    config_dict,
-                ),
-            )
-        return logger_manager

View File

@@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
 from collections.abc import Callable
 from enum import Enum
 from numbers import Number
-from typing import Any

 import numpy as np
@@ -60,7 +59,7 @@ class BaseLogger(ABC):
         """

     @staticmethod
-    def prepare_dict_for_logging(log_data: dict) -> dict:
+    def prepare_dict_for_logging(log_data: dict) -> dict[str, VALID_LOG_VALS_TYPE]:
         return log_data

     def log_train_data(self, log_data: dict, step: int) -> None:
@@ -72,7 +71,7 @@ class BaseLogger(ABC):
         # TODO: move interval check to calling method
         if step - self.last_log_train_step >= self.train_interval:
             log_data = self.prepare_dict_for_logging(log_data)
-            self.write("/".join([DataScope.TRAIN.value, "env_step"]), step, log_data)
+            self.write(f"{DataScope.TRAIN.value}/env_step", step, log_data)
             self.last_log_train_step = step

     def log_test_data(self, log_data: dict, step: int) -> None:
@@ -84,7 +83,7 @@ class BaseLogger(ABC):
         # TODO: move interval check to calling method (stupid because log_test_data is only called from function in utils.py, not from BaseTrainer)
         if step - self.last_log_test_step >= self.test_interval:
             log_data = self.prepare_dict_for_logging(log_data)
-            self.write("/".join([DataScope.TEST.value, "env_step"]), step, log_data)
+            self.write(f"{DataScope.TEST.value}/env_step", step, log_data)
             self.last_log_test_step = step

     def log_update_data(self, log_data: dict, step: int) -> None:
@@ -96,7 +95,7 @@ class BaseLogger(ABC):
         # TODO: move interval check to calling method
         if step - self.last_log_update_step >= self.update_interval:
             log_data = self.prepare_dict_for_logging(log_data)
-            self.write("/".join([DataScope.UPDATE.value, "gradient_step"]), step, log_data)
+            self.write(f"{DataScope.UPDATE.value}/gradient_step", step, log_data)
             self.last_log_update_step = step

     def log_info_data(self, log_data: dict, step: int) -> None:
@@ -109,7 +108,7 @@ class BaseLogger(ABC):
             step - self.last_log_info_step >= self.info_interval
         ):  # TODO: move interval check to calling method
             log_data = self.prepare_dict_for_logging(log_data)
-            self.write("/".join([DataScope.INFO.value, "epoch"]), step, log_data)
+            self.write(f"{DataScope.INFO.value}/epoch", step, log_data)
             self.last_log_info_step = step

     @abstractmethod
@@ -139,10 +138,12 @@ class BaseLogger(ABC):
         :return: epoch, env_step, gradient_step.
         """

-    @staticmethod
     @abstractmethod
-    def restore_logged_data(log_path):
-        """Load the logged data from dist for post-processing."""
+    def restore_logged_data(self, log_path: str) -> dict[str, VALID_LOG_VALS_TYPE]:
+        """Load the logged data from disk for post-processing.
+
+        :return: a dict containing the logged data.
+        """


 class LazyLogger(BaseLogger):
@@ -166,57 +167,5 @@ class LazyLogger(BaseLogger):
     def restore_data(self) -> tuple[int, int, int]:
         return 0, 0, 0

-    def restore_logged_data(self):
-        return None
-
-
-class LoggerManager(BaseLogger):
-    """A container of loggers that holds more than one logger."""
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.loggers = []
-
-    def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None:
-        for logger in self.loggers:
-            data_copy = data.copy()
-            logger.write(step_type, step, data_copy)
-
-    def log_train_data(self, log_data: dict, step: int) -> None:
-        for logger in self.loggers:
-            logger.log_train_data(log_data, step)
-
-    def log_test_data(self, log_data: dict, step: int) -> None:
-        for logger in self.loggers:
-            logger.log_test_data(log_data, step)
-
-    def log_update_data(self, log_data: dict, step: int) -> None:
-        for logger in self.loggers:
-            logger.log_update_data(log_data, step)
-
-    def log_info_data(self, log_data: dict, step: int) -> None:
-        for logger in self.loggers:
-            logger.log_info_data(log_data, step)
-
-    def save_data(
-        self,
-        epoch: int,
-        env_step: int,
-        gradient_step: int,
-        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
-    ) -> None:
-        for logger in self.loggers:
-            logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
-
-    def restore_data(self) -> tuple[int, int, int]:
-        for logger in self.loggers:
-            epoch, env_step, gradient_step = logger.restore_data()
-
-        self.last_save_step = self.last_log_test_step = epoch
-        self.last_log_update_step = gradient_step
-        self.last_log_train_step = env_step
-
-        return epoch, env_step, gradient_step
-
-    def restore_logged_data(self, log_path):
-        return self.loggers[0].restore_logged_data(log_path)
+    def restore_logged_data(self, log_path: str) -> dict:
+        return {}

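Note: restore_logged_data is now an abstract instance method on BaseLogger, so any custom logger must implement it alongside write, save_data and restore_data. A minimal sketch of a conforming subclass (the class name and in-memory storage are purely illustrative, not part of this commit):

    from collections.abc import Callable
    from typing import Any

    from tianshou.utils import BaseLogger


    class InMemoryLogger(BaseLogger):
        """Toy logger that keeps everything in memory; illustrative only."""

        def __init__(self) -> None:
            super().__init__()
            self.logged: dict[str, list] = {}

        def write(self, step_type: str, step: int, data: dict) -> None:
            # Record each write under its step type, e.g. "test/env_step".
            self.logged.setdefault(step_type, []).append((step, data))

        def save_data(
            self,
            epoch: int,
            env_step: int,
            gradient_step: int,
            save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
        ) -> None:
            pass  # nothing to persist for this toy logger

        def restore_data(self) -> tuple[int, int, int]:
            return 0, 0, 0

        def restore_logged_data(self, log_path: str) -> dict[str, Any]:
            # log_path is ignored here; disk-backed loggers read from it.
            return self.logged
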
View File

@@ -1,14 +1,12 @@
 import os
 from collections import defaultdict
-from functools import partial
-from typing import Callable, Any
+from collections.abc import Callable
+from typing import Any

-import numpy as np
 import pandas as pd
-from matplotlib.figure import Figure

 from tianshou.utils import BaseLogger, logging
-from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE
+from tianshou.utils.logger.base import VALID_LOG_VALS, VALID_LOG_VALS_TYPE


 class PandasLogger(BaseLogger):
@@ -21,20 +19,47 @@ class PandasLogger(BaseLogger):
         info_interval: int = 1,
         exclude_arrays: bool = True,
     ) -> None:
-        super().__init__(train_interval, test_interval, update_interval, info_interval, exclude_arrays)
+        super().__init__(
+            train_interval,
+            test_interval,
+            update_interval,
+            info_interval,
+            exclude_arrays,
+        )
         self.log_path = log_dir
         self.csv_name = os.path.join(self.log_path, "log.csv")
         self.pkl_name = os.path.join(self.log_path, "log.pkl")
-        self.data = defaultdict(list)
+        self.data: dict[str, list] = defaultdict(list)
         self.last_save_step = -1

+    @staticmethod
+    def prepare_dict_for_logging(data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
+        """Removes invalid data types from the log data."""
+        filtered_dict = data.copy()
+
+        def filter_dict(data_dict: dict[str, Any]) -> None:
+            """Filter in place."""
+            for key, value in data_dict.items():
+                if isinstance(value, VALID_LOG_VALS):
+                    filter_dict(value)
+                else:
+                    filtered_dict.pop(key)
+
+        filter_dict(data)
+        return filtered_dict
+
     def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
         scope, step_name = step_type.split("/")
         data[step_name] = step
         self.data[scope].append(data)

-    def save_data(self, epoch: int, env_step: int, gradient_step: int,
-                  save_checkpoint_fn: Callable[[int, int, int], str] | None = None) -> None:
+    def save_data(
+        self,
+        epoch: int,
+        env_step: int,
+        gradient_step: int,
+        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
+    ) -> None:
         self.last_save_step = epoch
         # create and dump a dataframe
         for k, v in self.data.items():
@@ -45,7 +70,13 @@ class PandasLogger(BaseLogger):
     def restore_data(self) -> tuple[int, int, int]:
         for scope in ["train", "test", "update", "info"]:
             try:
-                self.data[scope].extend(list(pd.read_pickle(os.path.join(self.log_path, scope + "_log.pkl")).T.to_dict().values()))
+                self.data[scope].extend(
+                    list(
+                        pd.read_pickle(os.path.join(self.log_path, scope + "_log.pkl"))
+                        .T.to_dict()
+                        .values(),
+                    ),
+                )
             except FileNotFoundError:
                 logging.warning(f"Failed to restore {scope} data")
@@ -67,12 +98,11 @@ class PandasLogger(BaseLogger):
         return epoch, env_step, gradient_step

-    @staticmethod
-    def restore_logged_data(log_path):
+    def restore_logged_data(self, log_path: str) -> dict[str, Any]:
         data = {}

-        def merge_dicts(list_of_dicts):
-            result = defaultdict(list)
+        def merge_dicts(list_of_dicts: list[dict]) -> dict[str, Any]:
+            result: dict[str, Any] = defaultdict(list)
             for d in list_of_dicts:
                 for key, value in d.items():
                     if isinstance(value, dict):
@@ -85,22 +115,10 @@ class PandasLogger(BaseLogger):
         for scope in ["train", "test", "update", "info"]:
             try:
-                dict_list = list(pd.read_pickle(os.path.join(log_path, scope + "_log.pkl")).T.to_dict().values())
+                dict_list = list(
+                    pd.read_pickle(os.path.join(log_path, scope + "_log.pkl")).T.to_dict().values(),
+                )
                 data[scope] = merge_dicts(dict_list)
             except FileNotFoundError:
                 logging.warning(f"Failed to restore {scope} data")
         return data
-
-    def prepare_dict_for_logging(self, data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
-        """Filter out matplotlib figures from the data."""
-        filtered_dict = data.copy()
-
-        def filter_dict(d):
-            for key, value in d.items():
-                if isinstance(value, dict):
-                    filter_dict(value)
-                elif isinstance(value, Figure):
-                    filtered_dict.pop(key)
-
-        filter_dict(data)
-        return filtered_dict

View File

@@ -1,4 +1,3 @@
-from collections import defaultdict
 from collections.abc import Callable
 from typing import Any
@@ -7,7 +6,7 @@ from matplotlib.figure import Figure
 from tensorboard.backend.event_processing import event_accumulator
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE, BaseLogger, VALID_LOG_VALS
+from tianshou.utils.logger.base import VALID_LOG_VALS, VALID_LOG_VALS_TYPE, BaseLogger
 from tianshou.utils.warning import deprecation
@@ -86,7 +85,7 @@ class TensorboardLogger(BaseLogger):
         scope, step_name = step_type.split("/")
         self.writer.add_scalar(step_type, step, global_step=step)
         for k, v in data.items():
-            scope_key = '/'.join([scope, k])
+            scope_key = f"{scope}/{k}"
             if isinstance(v, np.ndarray):
                 self.writer.add_histogram(scope_key, v, global_step=step, bins="auto")
             elif isinstance(v, Figure):
@@ -133,20 +132,19 @@ class TensorboardLogger(BaseLogger):
         return epoch, env_step, gradient_step

-    @staticmethod
-    def restore_logged_data(log_path):
+    def restore_logged_data(self, log_path: str) -> dict[str, Any]:
         ea = event_accumulator.EventAccumulator(log_path)
         ea.Reload()

-        def add_to_dict(dictionary, keys, value):
-            current_dict = dictionary
+        def add_to_dict(data_dict: dict[str, Any], keys: list[str], value: Any) -> None:
+            current_dict = data_dict
             for key in keys[:-1]:
                 current_dict = current_dict.setdefault(key, {})
             current_dict[keys[-1]] = value

-        data = {}
+        data: dict[str, Any] = {}
         for key in ea.scalars.Keys():
-            split_keys = key.split('/')
+            split_keys = key.split("/")
             add_to_dict(data, split_keys, np.array([s.value for s in ea.scalars.Items(key)]))
         return data

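Note: TensorboardLogger.restore_logged_data rebuilds a nested dict from the "/"-separated scalar tags, which is exactly the shape the tools.py change above consumes via data["test"]["returns_stat"]["mean"]. A small illustration of that nesting with a made-up tag:

    import numpy as np

    data: dict = {}
    keys = "test/returns_stat/mean".split("/")  # illustrative tag name

    # Same walk as add_to_dict: create nested dicts, assign at the leaf.
    current = data
    for key in keys[:-1]:
        current = current.setdefault(key, {})
    current[keys[-1]] = np.array([10.0, 42.0])

    assert data["test"]["returns_stat"]["mean"][1] == 42.0
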
View File

@@ -156,3 +156,7 @@ class WandbLogger(BaseLogger):
         except KeyError:
             env_step = 0
         return epoch, env_step, gradient_step
+
+    def restore_logged_data(self, log_path: str) -> dict:
+        assert self.tensorboard_logger is not None
+        return self.tensorboard_logger.restore_logged_data(log_path)
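
Note: after this commit every concrete logger exposes restore_logged_data(log_path) as an instance method, which is what RLiableExperimentResult.load_from_disk builds on. A minimal sketch of consuming it from a finished experiment directory (the path and experiment name are illustrative; the calls mirror the tools.py hunk above):

    from dataclasses import asdict

    from tianshou.highlevel.experiment import Experiment

    exp_dir = "log/Ant-v4/ppo/20240312-114646/some_experiment"  # illustrative

    # Rebuild the experiment and its logger, then restore the persisted data.
    exp = Experiment.from_directory(exp_dir)
    logger = exp.logger_factory.create_logger(exp_dir, "some_experiment", None, asdict(exp.config))
    data = logger.restore_logged_data(exp_dir)

    # Test-time statistics live under the "test" scope.
    mean_returns = data["test"]["returns_stat"]["mean"]
    env_steps = data["test"]["env_step"]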