# Changes ## Dependencies - New extra "eval" ## API Extensions - `Experiment` and `ExperimentConfig` now have a `name`, which can, however, be overridden when `Experiment.run()` is called - When building an `Experiment` from an `ExperimentConfig`, the user has the option to add info about seeds to the name. - New method in `ExperimentConfig` called `build_default_seeded_experiments` - `SamplingConfig` has an explicit training seed, `test_seed` is inferred. - New `evaluation` package for repeating the same experiment with multiple seeds and aggregating the results (important extension!). Currently in alpha state. - Loggers can now restore the logged data into Python by using the new `restore_logged_data` method ## Breaking Changes - `AtariEnvFactory` (in examples) now receives explicit train and test seeds - `EnvFactoryRegistered` now requires an explicit `test_seed` - `BaseLogger.prepare_dict_for_logging` is now abstract --------- Co-authored-by: Maximilian Huettenrauch <m.huettenrauch@appliedai.de> Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de> Co-authored-by: Michael Panchenko <35432522+MischaPanch@users.noreply.github.com>
77 lines
2.5 KiB
Python
77 lines
2.5 KiB
Python
import os
|
|
from abc import ABC, abstractmethod
|
|
from typing import Literal, TypeAlias
|
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger
|
|
from tianshou.utils.string import ToStringMixin
|
|
|
|
# Type of the logger objects returned by the logger factories in this module
# (currently an alias for `BaseLogger`).
TLogger: TypeAlias = BaseLogger
|
|
|
|
|
|
class LoggerFactory(ToStringMixin, ABC):
    """Abstract base class for factories which create the logger used by an experiment."""

    @abstractmethod
    def create_logger(
        self,
        log_dir: str,
        experiment_name: str,
        run_id: str | None,
        config_dict: dict,
    ) -> TLogger:
        """Create the logger; must be implemented by concrete factories.

        :param log_dir: path to the directory in which the log data shall be stored
        :param experiment_name: the name of the job; it may contain `os.path.sep`
        :param run_id: a unique name which, depending on the logging framework,
            may be used to identify the logger; may be None
        :param config_dict: a dictionary containing the data that is to be logged
        :return: the logger
        """
|
|
|
|
|
|
class LoggerFactoryDefault(LoggerFactory):
    """Logger factory which creates a TensorBoard or a W&B logger, selected by name."""

    def __init__(
        self,
        logger_type: Literal["tensorboard", "wandb", "pandas"] = "tensorboard",
        wandb_project: str | None = None,
    ):
        """
        :param logger_type: the type of logger to create. Note that while "pandas" is
            accepted here, `create_logger` only supports "tensorboard" and "wandb" and
            raises a `ValueError` for any other value.
        :param wandb_project: the W&B project name; required if `logger_type` is "wandb"
        :raises ValueError: if `logger_type` is "wandb" and no `wandb_project` is given
        """
        if logger_type == "wandb" and wandb_project is None:
            raise ValueError("Must provide 'wandb_project'")
        self.logger_type = logger_type
        self.wandb_project = wandb_project

    def create_logger(
        self,
        log_dir: str,
        experiment_name: str,
        run_id: str | None,
        config_dict: dict,
    ) -> TLogger:
        """Create the logger selected by `self.logger_type`.

        :param log_dir: path to the directory in which the log data is to be stored;
            it is passed to the `SummaryWriter`
        :param experiment_name: the name of the job, which may contain `os.path.sep`;
            for W&B, path separators are replaced by "__" to form the run name
        :param run_id: identifier passed on to the W&B logger (unused for TensorBoard)
        :param config_dict: configuration data passed on to the W&B logger
            (unused for TensorBoard)
        :return: the logger
        :raises ValueError: if `self.logger_type` is not a supported logger type
        """
        # Fail fast on unsupported types (e.g. "pandas") rather than silently
        # returning None, and do so before a SummaryWriter is created for log_dir.
        if self.logger_type not in ("wandb", "tensorboard"):
            raise ValueError(f"Unknown logger type '{self.logger_type}'")

        # Both supported logger types are backed by a TensorBoard summary writer,
        # to which the factory's own settings are written as text.
        writer = SummaryWriter(log_dir)
        writer.add_text(
            "args",
            str(
                {
                    "log_dir": log_dir,
                    "logger_type": self.logger_type,
                    "wandb_project": self.wandb_project,
                },
            ),
        )

        if self.logger_type == "wandb":
            wandb_logger = WandbLogger(
                save_interval=1,
                name=experiment_name.replace(os.path.sep, "__"),
                run_id=run_id,
                config=config_dict,
                project=self.wandb_project,
            )
            # Hand the summary writer over to the W&B logger.
            wandb_logger.load(writer)
            return wandb_logger

        return TensorboardLogger(writer)
|