diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py
index edba36e..890d948 100644
--- a/examples/atari/atari_dqn_hl.py
+++ b/examples/atari/atari_dqn_hl.py
@@ -66,7 +66,13 @@ def main(
         replay_buffer_save_only_last_obs=True,
     )
 
-    env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
+    env_factory = AtariEnvFactory(
+        task,
+        sampling_config.train_seed,
+        sampling_config.test_seed,
+        frames_stack,
+        scale=scale_obs,
+    )
 
     builder = (
         DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py
index 01f4f34..3918550 100644
--- a/examples/atari/atari_iqn_hl.py
+++ b/examples/atari/atari_iqn_hl.py
@@ -65,7 +65,13 @@ def main(
         replay_buffer_save_only_last_obs=True,
     )
 
-    env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
+    env_factory = AtariEnvFactory(
+        task,
+        sampling_config.train_seed,
+        sampling_config.test_seed,
+        frames_stack,
+        scale=scale_obs,
+    )
 
     experiment = (
         IQNExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py
index 9bcaf05..35f2538 100644
--- a/examples/atari/atari_sac_hl.py
+++ b/examples/atari/atari_sac_hl.py
@@ -65,7 +65,13 @@ def main(
         replay_buffer_save_only_last_obs=True,
     )
 
-    env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
+    env_factory = AtariEnvFactory(
+        task,
+        sampling_config.train_seed,
+        sampling_config.test_seed,
+        frames_stack,
+        scale=scale_obs,
+    )
 
     builder = (
         DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py
index 23ace13..ef6ebd1 100644
--- a/examples/mujoco/mujoco_a2c_hl.py
+++ b/examples/mujoco/mujoco_a2c_hl.py
@@ -54,7 +54,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )
 
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )
 
     experiment = (
         A2CExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py
index 1a7abab..6ef3886 100644
--- a/examples/mujoco/mujoco_ddpg_hl.py
+++ b/examples/mujoco/mujoco_ddpg_hl.py
@@ -51,7 +51,12 @@ def main(
         start_timesteps_random=True,
     )
 
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=False,
+    )
 
     experiment = (
         DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py
index 21594d0..c430d86 100644
--- a/examples/mujoco/mujoco_npg_hl.py
+++ b/examples/mujoco/mujoco_npg_hl.py
@@ -56,7 +56,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )
 
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )
 
     experiment = (
         NPGExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py
index 898d448..ccd4b59 100644
--- a/examples/mujoco/mujoco_ppo_hl.py
+++ b/examples/mujoco/mujoco_ppo_hl.py
@@ -61,7 +61,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )
 
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )
 
     experiment = (
         PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py
index 12185a5..34256e7 100644
--- a/examples/mujoco/mujoco_ppo_hl_multi.py
+++ b/examples/mujoco/mujoco_ppo_hl_multi.py
@@ -2,13 +2,12 @@ import os
 
 from collections.abc import Sequence
-from functools import partial
 from typing import Literal
 
 import torch
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from examples.mujoco.tools import eval_results, RLiableExperimentResult
+from examples.mujoco.tools import RLiableExperimentResult, eval_results
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     ExperimentConfig,
@@ -19,7 +18,6 @@ from tianshou.highlevel.params.dist_fn import (
 )
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
 from tianshou.highlevel.params.policy_params import PPOParams
-from tianshou.utils import logging
 from tianshou.utils.logging import datetime_tag
 
@@ -65,7 +63,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )
 
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )
 
     experiments = (
         PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -102,13 +105,13 @@ def main(
 
 
 def eval_experiments(log_dir: str):
-    results = RLiableExperimentResult.load_from_disk(log_dir, 'PPO', None)
+    results = RLiableExperimentResult.load_from_disk(log_dir, "PPO", None)
     eval_results(results)
 
 
 if __name__ == "__main__":
     # logging.run_cli(main)
-    experiment_config = ExperimentConfig(watch=False)
-    log_dir = logging.run_main(partial(main, experiment_config, epoch=2))
-    # log_dir =
+    # experiment_config = ExperimentConfig(watch=False)
+    # log_dir = logging.run_main(partial(main, experiment_config, epoch=2))
+    log_dir = "log/Ant-v4/ppo/20240312-114646"
     eval_experiments(log_dir)
diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py
index 7607b37..8bff905 100644
--- a/examples/mujoco/mujoco_redq_hl.py
+++ b/examples/mujoco/mujoco_redq_hl.py
@@ -57,7 +57,12 @@ def main(
         start_timesteps_random=True,
     )
 
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=False,
+    )
 
     experiment = (
         REDQExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py
index 0ff6537..1de7dec 100644
--- a/examples/mujoco/mujoco_reinforce_hl.py
+++ b/examples/mujoco/mujoco_reinforce_hl.py
@@ -49,7 +49,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )
 
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )
 
     experiment = (
         PGExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py
index 1d71fc7..759e615 100644
--- a/examples/mujoco/mujoco_sac_hl.py
+++ b/examples/mujoco/mujoco_sac_hl.py
@@ -52,7 +52,12 @@ def main(
         start_timesteps_random=True,
     )
 
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=False,
+    )
 
     experiment = (
         SACExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py
index 7ea32b8..326223c 100644
--- a/examples/mujoco/mujoco_td3_hl.py
+++ b/examples/mujoco/mujoco_td3_hl.py
@@ -58,7 +58,12 @@ def main(
         start_timesteps_random=True,
     )
 
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=False,
+    )
 
     experiment = (
         TD3ExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py
index 2e947b7..6231802 100644
--- a/examples/mujoco/mujoco_trpo_hl.py
+++ b/examples/mujoco/mujoco_trpo_hl.py
@@ -58,7 +58,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )
 
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )
 
     experiment = (
         TRPOExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/examples/mujoco/tools.py b/examples/mujoco/tools.py
index 60d1154..c1d7077 100755
--- a/examples/mujoco/tools.py
+++ b/examples/mujoco/tools.py
@@ -5,7 +5,7 @@ import csv
 import os
 import re
 from collections import defaultdict
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass
 
 import numpy as np
 import tqdm
@@ -34,26 +34,33 @@ class RLiableExperimentResult:
         test_episode_returns = []
         for entry in os.scandir(exp_dir):
-            if entry.name.startswith('.'):
+            if entry.name.startswith(".") or not entry.is_dir():
                 continue
 
             exp = Experiment.from_directory(entry.path)
-            logger = exp.logger_factory.create_logger(entry.path, entry.name, None, asdict(exp.config))
+            logger = exp.logger_factory.create_logger(
+                entry.path,
+                entry.name,
+                None,
+                asdict(exp.config),
+            )
             data = logger.restore_logged_data(entry.path)
 
-            test_data = data['test']
+            test_data = data["test"]
 
-            test_episode_returns.append(test_data['returns_stat']['mean'])
-            env_step = test_data['env_step']
+            test_episode_returns.append(test_data["returns_stat"]["mean"])
+            env_step = test_data["env_step"]
 
         if score_thresholds is None:
             score_thresholds = np.linspace(0.0, np.max(test_episode_returns), 101)
 
-        return RLiableExperimentResult(algorithms=[algo_name],
-                                       score_dict={algo_name: np.array(test_episode_returns)},
-                                       env_steps=np.array(env_step),
-                                       score_thresholds=score_thresholds,
-                                       exp_dir=exp_dir)
+        return RLiableExperimentResult(
+            algorithms=[algo_name],
+            score_dict={algo_name: np.array(test_episode_returns)},
+            env_steps=np.array(env_step),
+            score_thresholds=score_thresholds,
+            exp_dir=exp_dir,
+        )
 
 
 def eval_results(results: RLiableExperimentResult):
@@ -64,31 +71,44 @@ def eval_results(results: RLiableExperimentResult):
     from rliable import plot_utils
 
     iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
-    iqm_scores, iqm_cis = rly.get_interval_estimates(
-        results.score_dict, iqm, reps=50000)
+    iqm_scores, iqm_cis = rly.get_interval_estimates(results.score_dict, iqm, reps=50000)
 
     # Plot IQM sample efficiency curve
     fig, ax = plt.subplots(ncols=1, figsize=(7, 5))
     plot_utils.plot_sample_efficiency_curve(
-        results.env_steps, iqm_scores, iqm_cis, algorithms=results.algorithms,
-        xlabel=r'Number of env steps',
-        ylabel='IQM episode return',
-        ax=ax)
-    plt.savefig(os.path.join(results.exp_dir, 'iqm_sample_efficiency_curve.png'))
+        results.env_steps,
+        iqm_scores,
+        iqm_cis,
+        algorithms=results.algorithms,
+        xlabel=r"Number of env steps",
+        ylabel="IQM episode return",
+        ax=ax,
+    )
+    plt.savefig(os.path.join(results.exp_dir, "iqm_sample_efficiency_curve.png"))
 
     final_score_dict = {algo: returns[:, [-1]] for algo, returns in results.score_dict.items()}
     score_distributions, score_distributions_cis = rly.create_performance_profile(
-        final_score_dict, results.score_thresholds)
+        final_score_dict,
+        results.score_thresholds,
+    )
 
     # Plot score distributions
     fig, ax = plt.subplots(ncols=1, figsize=(7, 5))
     plot_utils.plot_performance_profiles(
-        score_distributions, results.score_thresholds,
+        score_distributions,
+        results.score_thresholds,
         performance_profile_cis=score_distributions_cis,
-        colors=dict(zip(results.algorithms, sns.color_palette('colorblind'))),
-        xlabel=r'Episode return $(\tau)$',
-        ax=ax)
-    plt.savefig(os.path.join(results.exp_dir, 'performance_profile.png'))
+        colors=dict(
+            zip(
+                results.algorithms,
+                sns.color_palette("colorblind", n_colors=len(results.algorithms)),
+                strict=True,
+            ),
+        ),
+        xlabel=r"Episode return $(\tau)$",
+        ax=ax,
+    )
+    plt.savefig(os.path.join(results.exp_dir, "performance_profile.png"))
 
 
 def find_all_files(root_dir, pattern):
diff --git a/test/highlevel/env_factory.py b/test/highlevel/env_factory.py
index c526e2d..ddfce7b 100644
--- a/test/highlevel/env_factory.py
+++ b/test/highlevel/env_factory.py
@@ -6,9 +6,19 @@ from tianshou.highlevel.env import (
 
 class DiscreteTestEnvFactory(EnvFactoryRegistered):
     def __init__(self) -> None:
-        super().__init__(task="CartPole-v0", train_seed=42, test_seed=1337, venv_type=VectorEnvType.DUMMY)
+        super().__init__(
+            task="CartPole-v0",
+            train_seed=42,
+            test_seed=1337,
+            venv_type=VectorEnvType.DUMMY,
+        )
 
 
 class ContinuousTestEnvFactory(EnvFactoryRegistered):
     def __init__(self) -> None:
-        super().__init__(task="Pendulum-v1", train_seed=42, test_seed=1337, venv_type=VectorEnvType.DUMMY)
+        super().__init__(
+            task="Pendulum-v1",
+            train_seed=42,
+            test_seed=1337,
+            venv_type=VectorEnvType.DUMMY,
+        )
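The two tools.py entry points above are intended to be used together once a multi-seed run has finished. A minimal usage sketch, assuming `log_dir` contains one subdirectory per seeded experiment; the path and the "PPO" label are taken from the mujoco_ppo_hl_multi.py hunk above:

```python
# Sketch: evaluate a finished multi-seed run with the helpers above.
from examples.mujoco.tools import RLiableExperimentResult, eval_results

log_dir = "log/Ant-v4/ppo/20240312-114646"  # illustrative run directory
# None -> score thresholds are derived automatically via np.linspace(...)
results = RLiableExperimentResult.load_from_disk(log_dir, "PPO", None)
# Writes iqm_sample_efficiency_curve.png and performance_profile.png into log_dir.
eval_results(results)
```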
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 2daa2f3..c5cd7fd 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -3,7 +3,7 @@ import pickle
 from abc import abstractmethod
 from collections.abc import Sequence
 from copy import copy
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass
 from pprint import pformat
 from typing import Self
 
@@ -87,9 +87,7 @@ from tianshou.utils.string import ToStringMixin
 log = logging.getLogger(__name__)
 
 
-def shortener(input_string: str | None = None,
-              length: int = 1
-              ):
+def shortener(input_string: str | None = None, length: int = 1) -> str:
     """Shorten the input string by keeping only the first `length` characters of each word.
 
     If the input string is None or empty, return "default".
@@ -367,19 +365,19 @@ class ExperimentBuilder:
         self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
 
     @property
-    def experiment_config(self):
+    def experiment_config(self) -> ExperimentConfig:
         return self._config
 
     @experiment_config.setter
-    def experiment_config(self, experiment_config: ExperimentConfig):
+    def experiment_config(self, experiment_config: ExperimentConfig) -> None:
         self._config = experiment_config
 
     @property
-    def sampling_config(self):
+    def sampling_config(self) -> SamplingConfig:
         return self._sampling_config
 
     @sampling_config.setter
-    def sampling_config(self, sampling_config: SamplingConfig):
+    def sampling_config(self, sampling_config: SamplingConfig) -> None:
         self._sampling_config = sampling_config
 
     def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
@@ -491,7 +489,6 @@ class ExperimentBuilder:
 
         The keys of the dict are the experiment names, which are derived from the seeds used in the experiments.
         """
-
         configured_experiment_config = copy(self.experiment_config)
         configured_experiment_seed = configured_experiment_config.seed
         configured_sampling_config = copy(self.sampling_config)
@@ -512,9 +509,13 @@ class ExperimentBuilder:
             self.sampling_config = SamplingConfig(**new_sampling_config_dict)
 
             exp = self.build()
-            full_name = ",".join([f"experiment_seed={exp.config.seed}",
-                                  f"train_seed={exp.sampling_config.train_seed}",
-                                  f"test_seed={exp.sampling_config.test_seed}"])
+            full_name = ",".join(
+                [
+                    f"experiment_seed={exp.config.seed}",
+                    f"train_seed={exp.sampling_config.train_seed}",
+                    f"test_seed={exp.sampling_config.test_seed}",
+                ],
+            )
             experiment_name = shortener(full_name, 4)
             seeded_experiments[experiment_name] = exp
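The seeded-experiment naming above feeds `full_name` through `shortener(full_name, 4)`. The word-splitting rule is not visible in this diff, so the following is a hypothetical re-implementation for illustration only, assuming a "word" is a maximal alphanumeric run:

```python
import re

# Hypothetical sketch of shortener's behavior; the real function lives in
# tianshou/highlevel/experiment.py and may split words differently.
def shortener_sketch(input_string: str | None = None, length: int = 1) -> str:
    """Keep only the first `length` characters of each word (assumed splitting rule)."""
    if not input_string:
        return "default"
    return re.sub(r"[A-Za-z0-9]+", lambda m: m.group(0)[:length], input_string)

# Under that assumption:
# shortener_sketch("experiment_seed=42,train_seed=4,test_seed=1337", 4)
# -> "expe_seed=42,trai_seed=4,test_seed=1337"
```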
diff --git a/tianshou/highlevel/logger.py b/tianshou/highlevel/logger.py
index 71f0b7c..214a398 100644
--- a/tianshou/highlevel/logger.py
+++ b/tianshou/highlevel/logger.py
@@ -5,7 +5,6 @@ from typing import Literal, TypeAlias
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger
-from tianshou.utils.logger.base import LoggerManager
 from tianshou.utils.logger.pandas_logger import PandasLogger
 from tianshou.utils.string import ToStringMixin
 
@@ -78,38 +77,3 @@ class LoggerFactoryDefault(LoggerFactory):
                 return TensorboardLogger(writer)
             case _:
                 raise ValueError(f"Unknown logger type '{self.logger_type}'")
-
-
-class LoggerManagerFactory(LoggerFactory):
-    def __init__(
-        self,
-        logger_types: list[Literal["tensorboard", "wandb", "pandas"]] = ["tensorboard", "pandas"],
-        wandb_project: str | None = None,
-    ):
-        self.logger_types = logger_types
-        self.wandb_project = wandb_project
-
-        self.factories = {
-            "pandas": LoggerFactoryDefault(logger_type="pandas"),
-            "wandb": LoggerFactoryDefault(logger_type="wandb", wandb_project=wandb_project),
-            "tensorboard": LoggerFactoryDefault(logger_type="tensorboard"),
-        }
-
-    def create_logger(
-        self,
-        log_dir: str,
-        experiment_name: str,
-        run_id: str | None,
-        config_dict: dict,
-    ) -> TLogger:
-        logger_manager = LoggerManager()
-        for logger_type in self.logger_types:
-            logger_manager.loggers.append(
-                self.factories[logger_type].create_logger(
-                    log_dir,
-                    experiment_name,
-                    run_id,
-                    config_dict,
-                )
-            )
-        return logger_manager
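With `LoggerManagerFactory` gone, a single logging backend is selected through the remaining `LoggerFactoryDefault`. A minimal sketch using only names visible in this diff (the builder call is indicative, not prescriptive):

```python
from tianshou.highlevel.logger import LoggerFactoryDefault

# Pick one backend: "tensorboard", "wandb", or "pandas".
logger_factory = LoggerFactoryDefault(logger_type="pandas")
# Attach it via ExperimentBuilder.with_logger_factory, e.g.:
# builder.with_logger_factory(logger_factory)
```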
+ """ class LazyLogger(BaseLogger): @@ -166,57 +167,5 @@ class LazyLogger(BaseLogger): def restore_data(self) -> tuple[int, int, int]: return 0, 0, 0 - def restore_logged_data(self): - return None - - -class LoggerManager(BaseLogger): - """A container of loggers that holds more than one logger.""" - - def __init__(self) -> None: - super().__init__() - self.loggers = [] - - def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None: - for logger in self.loggers: - data_copy = data.copy() - logger.write(step_type, step, data_copy) - - def log_train_data(self, log_data: dict, step: int) -> None: - for logger in self.loggers: - logger.log_train_data(log_data, step) - - def log_test_data(self, log_data: dict, step: int) -> None: - for logger in self.loggers: - logger.log_test_data(log_data, step) - - def log_update_data(self, log_data: dict, step: int) -> None: - for logger in self.loggers: - logger.log_update_data(log_data, step) - - def log_info_data(self, log_data: dict, step: int) -> None: - for logger in self.loggers: - logger.log_info_data(log_data, step) - - def save_data( - self, - epoch: int, - env_step: int, - gradient_step: int, - save_checkpoint_fn: Callable[[int, int, int], str] | None = None, - ) -> None: - for logger in self.loggers: - logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) - - def restore_data(self) -> tuple[int, int, int]: - for logger in self.loggers: - epoch, env_step, gradient_step = logger.restore_data() - - self.last_save_step = self.last_log_test_step = epoch - self.last_log_update_step = gradient_step - self.last_log_train_step = env_step - - return epoch, env_step, gradient_step - - def restore_logged_data(self, log_path): - return self.loggers[0].restore_logged_data(log_path) + def restore_logged_data(self, log_path: str) -> dict: + return {} diff --git a/tianshou/utils/logger/pandas_logger.py b/tianshou/utils/logger/pandas_logger.py index 882cbf3..f76b5da 100644 --- a/tianshou/utils/logger/pandas_logger.py +++ b/tianshou/utils/logger/pandas_logger.py @@ -1,14 +1,12 @@ import os from collections import defaultdict -from functools import partial -from typing import Callable, Any +from collections.abc import Callable +from typing import Any -import numpy as np import pandas as pd -from matplotlib.figure import Figure from tianshou.utils import BaseLogger, logging -from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE +from tianshou.utils.logger.base import VALID_LOG_VALS, VALID_LOG_VALS_TYPE class PandasLogger(BaseLogger): @@ -21,20 +19,47 @@ class PandasLogger(BaseLogger): info_interval: int = 1, exclude_arrays: bool = True, ) -> None: - super().__init__(train_interval, test_interval, update_interval, info_interval, exclude_arrays) + super().__init__( + train_interval, + test_interval, + update_interval, + info_interval, + exclude_arrays, + ) self.log_path = log_dir self.csv_name = os.path.join(self.log_path, "log.csv") self.pkl_name = os.path.join(self.log_path, "log.pkl") - self.data = defaultdict(list) + self.data: dict[str, list] = defaultdict(list) self.last_save_step = -1 + @staticmethod + def prepare_dict_for_logging(data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]: + """Removes invalid data types from the log data.""" + filtered_dict = data.copy() + + def filter_dict(data_dict: dict[str, Any]) -> None: + """Filter in place.""" + for key, value in data_dict.items(): + if isinstance(value, VALID_LOG_VALS): + filter_dict(value) + else: + filtered_dict.pop(key) + + filter_dict(data) + return 
diff --git a/tianshou/utils/logger/pandas_logger.py b/tianshou/utils/logger/pandas_logger.py
index 882cbf3..f76b5da 100644
--- a/tianshou/utils/logger/pandas_logger.py
+++ b/tianshou/utils/logger/pandas_logger.py
@@ -1,14 +1,12 @@
 import os
 from collections import defaultdict
-from functools import partial
-from typing import Callable, Any
+from collections.abc import Callable
+from typing import Any
 
-import numpy as np
 import pandas as pd
-from matplotlib.figure import Figure
 
 from tianshou.utils import BaseLogger, logging
-from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE
+from tianshou.utils.logger.base import VALID_LOG_VALS, VALID_LOG_VALS_TYPE
 
 
 class PandasLogger(BaseLogger):
@@ -21,20 +19,47 @@ class PandasLogger(BaseLogger):
         info_interval: int = 1,
         exclude_arrays: bool = True,
     ) -> None:
-        super().__init__(train_interval, test_interval, update_interval, info_interval, exclude_arrays)
+        super().__init__(
+            train_interval,
+            test_interval,
+            update_interval,
+            info_interval,
+            exclude_arrays,
+        )
         self.log_path = log_dir
         self.csv_name = os.path.join(self.log_path, "log.csv")
         self.pkl_name = os.path.join(self.log_path, "log.pkl")
-        self.data = defaultdict(list)
+        self.data: dict[str, list] = defaultdict(list)
         self.last_save_step = -1
 
+    @staticmethod
+    def prepare_dict_for_logging(data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
+        """Removes invalid data types from the log data."""
+        filtered_dict = data.copy()
+
+        def filter_dict(data_dict: dict[str, Any]) -> None:
+            """Filter in place: recurse into nested dicts, drop invalid values."""
+            for key, value in data_dict.items():
+                if isinstance(value, dict):
+                    filter_dict(value)
+                elif not isinstance(value, VALID_LOG_VALS):
+                    filtered_dict.pop(key)
+
+        filter_dict(data)
+        return filtered_dict
+
     def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
         scope, step_name = step_type.split("/")
         data[step_name] = step
         self.data[scope].append(data)
 
-    def save_data(self, epoch: int, env_step: int, gradient_step: int,
-                  save_checkpoint_fn: Callable[[int, int, int], str] | None = None) -> None:
+    def save_data(
+        self,
+        epoch: int,
+        env_step: int,
+        gradient_step: int,
+        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
+    ) -> None:
         self.last_save_step = epoch
         # create and dump a dataframe
         for k, v in self.data.items():
@@ -45,7 +70,13 @@ class PandasLogger(BaseLogger):
     def restore_data(self) -> tuple[int, int, int]:
         for scope in ["train", "test", "update", "info"]:
             try:
-                self.data[scope].extend(list(pd.read_pickle(os.path.join(self.log_path, scope + "_log.pkl")).T.to_dict().values()))
+                self.data[scope].extend(
+                    list(
+                        pd.read_pickle(os.path.join(self.log_path, scope + "_log.pkl"))
+                        .T.to_dict()
+                        .values(),
+                    ),
+                )
             except FileNotFoundError:
                 logging.warning(f"Failed to restore {scope} data")
 
@@ -67,12 +98,11 @@ class PandasLogger(BaseLogger):
 
         return epoch, env_step, gradient_step
 
-    @staticmethod
-    def restore_logged_data(log_path):
+    def restore_logged_data(self, log_path: str) -> dict[str, Any]:
        data = {}
 
-        def merge_dicts(list_of_dicts):
-            result = defaultdict(list)
+        def merge_dicts(list_of_dicts: list[dict]) -> dict[str, Any]:
+            result: dict[str, Any] = defaultdict(list)
             for d in list_of_dicts:
                 for key, value in d.items():
                     if isinstance(value, dict):
@@ -85,22 +115,10 @@ class PandasLogger(BaseLogger):
         for scope in ["train", "test", "update", "info"]:
             try:
-                dict_list = list(pd.read_pickle(os.path.join(log_path, scope + "_log.pkl")).T.to_dict().values())
+                dict_list = list(
+                    pd.read_pickle(os.path.join(log_path, scope + "_log.pkl")).T.to_dict().values(),
+                )
                 data[scope] = merge_dicts(dict_list)
             except FileNotFoundError:
                 logging.warning(f"Failed to restore {scope} data")
 
         return data
-
-    def prepare_dict_for_logging(self, data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
-        """Filter out matplotlib figures from the data."""
-        filtered_dict = data.copy()
-
-        def filter_dict(d):
-            for key, value in d.items():
-                if isinstance(value, dict):
-                    filter_dict(value)
-                elif isinstance(value, Figure):
-                    filtered_dict.pop(key)
-
-        filter_dict(data)
-        return filtered_dict
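The nested dict produced by `PandasLogger.restore_logged_data` is exactly what `RLiableExperimentResult.load_from_disk` consumes in tools.py above. A sketch of reading it directly, assuming the logger's first constructor argument is the log directory (as `self.log_path = log_dir` suggests); the path is illustrative:

```python
from tianshou.utils.logger.pandas_logger import PandasLogger

log_path = "log/Ant-v4/ppo/20240312-114646/some-seeded-experiment"  # illustrative
logger = PandasLogger(log_path)
data = logger.restore_logged_data(log_path)
# Access pattern mirrors tools.py above:
mean_test_returns = data["test"]["returns_stat"]["mean"]
env_steps = data["test"]["env_step"]
```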
diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py
index 555a0c7..8b4f33c 100644
--- a/tianshou/utils/logger/tensorboard.py
+++ b/tianshou/utils/logger/tensorboard.py
@@ -1,4 +1,3 @@
-from collections import defaultdict
 from collections.abc import Callable
 from typing import Any
 
@@ -7,7 +6,7 @@ from matplotlib.figure import Figure
 from tensorboard.backend.event_processing import event_accumulator
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE, BaseLogger, VALID_LOG_VALS
+from tianshou.utils.logger.base import VALID_LOG_VALS, VALID_LOG_VALS_TYPE, BaseLogger
 from tianshou.utils.warning import deprecation
 
@@ -86,7 +85,7 @@ class TensorboardLogger(BaseLogger):
         scope, step_name = step_type.split("/")
         self.writer.add_scalar(step_type, step, global_step=step)
         for k, v in data.items():
-            scope_key = '/'.join([scope, k])
+            scope_key = f"{scope}/{k}"
             if isinstance(v, np.ndarray):
                 self.writer.add_histogram(scope_key, v, global_step=step, bins="auto")
             elif isinstance(v, Figure):
@@ -133,20 +132,19 @@ class TensorboardLogger(BaseLogger):
 
         return epoch, env_step, gradient_step
 
-    @staticmethod
-    def restore_logged_data(log_path):
+    def restore_logged_data(self, log_path: str) -> dict[str, Any]:
         ea = event_accumulator.EventAccumulator(log_path)
         ea.Reload()
 
-        def add_to_dict(dictionary, keys, value):
-            current_dict = dictionary
+        def add_to_dict(data_dict: dict[str, Any], keys: list[str], value: Any) -> None:
+            current_dict = data_dict
             for key in keys[:-1]:
                 current_dict = current_dict.setdefault(key, {})
             current_dict[keys[-1]] = value
 
-        data = {}
+        data: dict[str, Any] = {}
         for key in ea.scalars.Keys():
-            split_keys = key.split('/')
+            split_keys = key.split("/")
             add_to_dict(data, split_keys, np.array([s.value for s in ea.scalars.Items(key)]))
 
         return data
diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py
index 53dbf10..1ad012a 100644
--- a/tianshou/utils/logger/wandb.py
+++ b/tianshou/utils/logger/wandb.py
@@ -156,3 +156,7 @@ class WandbLogger(BaseLogger):
         except KeyError:
             env_step = 0
         return epoch, env_step, gradient_step
+
+    def restore_logged_data(self, log_path: str) -> dict:
+        assert self.tensorboard_logger is not None
+        return self.tensorboard_logger.restore_logged_data(log_path)
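`WandbLogger` satisfies the new abstract method by delegating to its wrapped `TensorboardLogger`, which rebuilds a nested dict from event-file scalars by splitting keys on "/"; the assert makes the precondition explicit. A hedged usage sketch (constructor arguments and run setup elided):

```python
from tianshou.utils import WandbLogger

logger = WandbLogger()
# Precondition: the wandb logger must already wrap a TensorboardLogger
# (i.e. logger.tensorboard_logger is set) before restoring.
data = logger.restore_logged_data("log/Ant-v4/ppo/some-run")  # illustrative path
```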