From adc324038ad2b2381672efd64bd5e0f2b812df44 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 20 Sep 2023 15:10:19 +0200 Subject: [PATCH] Remove LoggerConfig --- examples/mujoco/mujoco_ppo_hl.py | 5 ++--- examples/mujoco/mujoco_sac_hl.py | 5 ++--- tianshou/highlevel/logger.py | 32 +++++++++++++------------------- 3 files changed, 17 insertions(+), 25 deletions(-) diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index 570dc61..71a0b6d 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -15,7 +15,7 @@ from tianshou.highlevel.experiment import ( RLExperimentConfig, RLSamplingConfig, ) -from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerConfig +from tianshou.highlevel.logger import DefaultLoggerFactory from tianshou.highlevel.module import ( ContinuousActorProbFactory, ContinuousNetCriticFactory, @@ -32,7 +32,6 @@ class NNConfig: def main( experiment_config: RLExperimentConfig, - logger_config: LoggerConfig, sampling_config: RLSamplingConfig, general_config: RLAgentConfig, pg_config: PGConfig, @@ -42,7 +41,7 @@ def main( ): now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") log_name = os.path.join(task, "ppo", str(experiment_config.seed), now) - logger_factory = DefaultLoggerFactory(logger_config) + logger_factory = DefaultLoggerFactory() env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index 0d9ef4b..cb93133 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -13,7 +13,7 @@ from tianshou.highlevel.experiment import ( RLExperimentConfig, RLSamplingConfig, ) -from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerConfig +from tianshou.highlevel.logger import DefaultLoggerFactory from tianshou.highlevel.module import ( ContinuousActorProbFactory, ContinuousNetCriticFactory, @@ -23,7 +23,6 @@ from tianshou.highlevel.optim import AdamOptimizerFactory def main( experiment_config: RLExperimentConfig, - logger_config: LoggerConfig, sampling_config: RLSamplingConfig, sac_config: SACConfig, hidden_sizes: Sequence[int] = (256, 256), @@ -31,7 +30,7 @@ def main( ): now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") log_name = os.path.join(task, "sac", str(experiment_config.seed), now) - logger_factory = DefaultLoggerFactory(logger_config) + logger_factory = DefaultLoggerFactory() env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config) diff --git a/tianshou/highlevel/logger.py b/tianshou/highlevel/logger.py index 06bd195..b9e2fab 100644 --- a/tianshou/highlevel/logger.py +++ b/tianshou/highlevel/logger.py @@ -22,36 +22,30 @@ class LoggerFactory(ABC): pass -@dataclass -class LoggerConfig: - """Logging config.""" - - logdir: str = "log" - logger: Literal["tensorboard", "wandb"] = "tensorboard" - wandb_project: str = "mujoco.benchmark" - """Only used if logger is wandb.""" - - class DefaultLoggerFactory(LoggerFactory): - def __init__(self, config: LoggerConfig): - self.config = config + def __init__(self, log_dir: str = "log", logger_type: Literal["tensorboard", "wandb"] = "tensorboard", wandb_project: str | None = None): + if logger_type == "wandb" and wandb_project is None: + raise ValueError("Must provide 'wand_project'") + self.log_dir = log_dir + self.logger_type = logger_type + self.wandb_project = wandb_project def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger: - writer = SummaryWriter(self.config.logdir) - writer.add_text("args", str(self.config)) - if self.config.logger == "wandb": + writer = SummaryWriter(self.log_dir) + writer.add_text("args", str(dict(log_dir=self.log_dir, logger_type=self.logger_type, wandb_project=self.wandb_project))) + if self.logger_type == "wandb": logger = WandbLogger( save_interval=1, name=log_name.replace(os.path.sep, "__"), run_id=run_id, config=config_dict, - project=self.config.wandb_project, + project=self.wandb_project, ) logger.load(writer) - elif self.config.logger == "tensorboard": + elif self.logger_type == "tensorboard": logger = TensorboardLogger(writer) else: - raise ValueError(f"Unknown logger: {self.config.logger}") - log_path = os.path.join(self.config.logdir, log_name) + raise ValueError(f"Unknown logger type '{self.logger_type}'") + log_path = os.path.join(self.log_dir, log_name) os.makedirs(log_path, exist_ok=True) return Logger(logger=logger, log_path=log_path)