maxhuettenrauch ade85ab32b
Feature/algo eval (#1074)
# Changes

## Dependencies

- New extra "eval"

## Api Extension
- `Experiment` and `ExperimentConfig` now have a `name`, which can be
overridden when `Experiment.run()` is called
- When building an `Experiment` from an `ExperimentConfig`, the user can
optionally add seed information to the name.
- New method in `ExperimentConfig` called
`build_default_seeded_experiments`
- `SamplingConfig` now has an explicit training seed; `test_seed` is
inferred.
- New `evaluation` package for repeating the same experiment with
multiple seeds and aggregating the results (important extension!).
Currently in alpha state.
- Loggers can now restore logged data into Python using the new
`restore_logged_data` method
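
The idea behind the new `evaluation` package is to repeat one experiment under several seeds and aggregate the results. It can be illustrated in plain Python. This is a minimal, hypothetical sketch and not the package's API. `run_once` stands in for a full training run, and the aggregation (mean, sample standard deviation, min and max of the final return) is an assumed choice.

```python
import random
import statistics
from collections.abc import Callable, Sequence


def run_once(seed: int) -> float:
    """Hypothetical stand-in for one training run; returns a final test return."""
    rng = random.Random(seed)  # all randomness of the run is derived from the seed
    return 100.0 + rng.gauss(0.0, 5.0)


def run_seeded(run: Callable[[int], float], seeds: Sequence[int]) -> dict[str, float]:
    """Repeat the same run once per seed and aggregate the final returns."""
    returns = [run(seed) for seed in seeds]
    return {
        "mean": statistics.mean(returns),
        # sample standard deviation needs at least two runs
        "std": statistics.stdev(returns) if len(returns) > 1 else 0.0,
        "min": min(returns),
        "max": max(returns),
    }


summary = run_seeded(run_once, seeds=[0, 1, 2, 3, 4])
print(summary)
```

Because each run depends only on its seed, rerunning with the same seed list reproduces the same summary.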

## Breaking Changes
- `AtariEnvFactory` (in examples) now receives explicit train and test
seeds
- `EnvFactoryRegistered` now requires an explicit `test_seed`
- `BaseLogger.prepare_dict_for_logging` is now abstract
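
Keeping train and test seeds separate ensures that test environments are not seeded identically to training environments. The sketch below assumes that each vectorized sub-environment gets `base_seed + i`. The helper `per_env_seeds` and the rule that places test seeds right after the training seeds are hypothetical. They are not necessarily how `SamplingConfig` infers `test_seed`.

```python
def per_env_seeds(base_seed: int, num_envs: int) -> list[int]:
    """Hypothetical helper: give each sub-environment its own seed (base, base + 1, ...)."""
    return [base_seed + i for i in range(num_envs)]


num_train_envs, num_test_envs = 4, 2
train_seed = 42
# Assumed convention: test seeds start right after the last training seed,
# so the two ranges never overlap.
test_seed = train_seed + num_train_envs

train_seeds = per_env_seeds(train_seed, num_train_envs)
test_seeds = per_env_seeds(test_seed, num_test_envs)
print(train_seeds, test_seeds)  # [42, 43, 44, 45] [46, 47]
```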

---------

Co-authored-by: Maximilian Huettenrauch <m.huettenrauch@appliedai.de>
Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
Co-authored-by: Michael Panchenko <35432522+MischaPanch@users.noreply.github.com>
2024-04-20 23:25:33 +00:00


import os
from abc import ABC, abstractmethod
from typing import Literal, TypeAlias

from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger
from tianshou.utils.string import ToStringMixin

TLogger: TypeAlias = BaseLogger


class LoggerFactory(ToStringMixin, ABC):
    @abstractmethod
    def create_logger(
        self,
        log_dir: str,
        experiment_name: str,
        run_id: str | None,
        config_dict: dict,
    ) -> TLogger:
        """Creates the logger.

        :param log_dir: path to the directory in which log data is to be stored
        :param experiment_name: the name of the job, which may contain `os.path.sep`
        :param run_id: a unique name, which, depending on the logging framework, may be used to identify the logger
        :param config_dict: a dictionary with data that is to be logged
        :return: the logger
        """


class LoggerFactoryDefault(LoggerFactory):
    def __init__(
        self,
        logger_type: Literal["tensorboard", "wandb", "pandas"] = "tensorboard",
        wandb_project: str | None = None,
    ):
        # wandb needs a project name; fail at construction time rather than later
        if logger_type == "wandb" and wandb_project is None:
            raise ValueError("Must provide 'wandb_project'")
        self.logger_type = logger_type
        self.wandb_project = wandb_project

    def create_logger(
        self,
        log_dir: str,
        experiment_name: str,
        run_id: str | None,
        config_dict: dict,
    ) -> TLogger:
        # Both the wandb and the tensorboard logger write through a tensorboard SummaryWriter
        if self.logger_type in ["wandb", "tensorboard"]:
            writer = SummaryWriter(log_dir)
            writer.add_text(
                "args",
                str(
                    dict(
                        log_dir=log_dir,
                        logger_type=self.logger_type,
                        wandb_project=self.wandb_project,
                    ),
                ),
            )
        match self.logger_type:
            case "wandb":
                wandb_logger = WandbLogger(
                    save_interval=1,
                    # experiment names may contain path separators; flatten them for the run name
                    name=experiment_name.replace(os.path.sep, "__"),
                    run_id=run_id,
                    config=config_dict,
                    project=self.wandb_project,
                )
                # pass the SummaryWriter on to the wandb logger
                wandb_logger.load(writer)
                return wandb_logger
            case "tensorboard":
                return TensorboardLogger(writer)
            case _:
                # also reached for "pandas", which the type hint allows but this factory does not handle
                raise ValueError(f"Unknown logger type '{self.logger_type}'")