maxhuettenrauch ade85ab32b
Feature/algo eval (#1074)
# Changes

## Dependencies

- New extra "eval"

## API Extension
- `Experiment` and `ExperimentConfig` now have a `name`, which can
be overridden when `Experiment.run()` is called
- When building an `Experiment` from an `ExperimentConfig`, the user has
the option to add info about seeds to the name.
- New method in `ExperimentConfig` called
`build_default_seeded_experiments`
- `SamplingConfig` has an explicit training seed; `test_seed` is
inferred.
- New `evaluation` package for repeating the same experiment with
multiple seeds and aggregating the results (important extension!).
Currently in alpha state.
- Loggers can now restore the logged data into Python by using the new
`restore_logged_data`

## Breaking Changes
- `AtariEnvFactory` (in examples) now receives explicit train and test
seeds
- `EnvFactoryRegistered` now requires an explicit `test_seed`
- `BaseLogger.prepare_dict_for_logging` is now abstract

---------

Co-authored-by: Maximilian Huettenrauch <m.huettenrauch@appliedai.de>
Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
Co-authored-by: Michael Panchenko <35432522+MischaPanch@users.noreply.github.com>
2024-04-20 23:25:33 +00:00

35 lines
1.1 KiB
Python

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from tianshou.data import Collector
    from tianshou.highlevel.env import Environments
    from tianshou.highlevel.logger import TLogger
    from tianshou.policy import BasePolicy
    from tianshou.trainer import BaseTrainer


@dataclass
class World:
    """Container for instances and configuration items that are relevant to an experiment."""

    envs: "Environments"
    policy: "BasePolicy"
    train_collector: "Collector"
    test_collector: "Collector"
    logger: "TLogger"
    persist_directory: str
    restore_directory: str | None
    trainer: Optional["BaseTrainer"] = None

    def persist_path(self, filename: str) -> str:
        return os.path.abspath(os.path.join(self.persist_directory, filename))

    def restore_path(self, filename: str) -> str:
        if self.restore_directory is None:
            raise ValueError(
                "Path cannot be formed because no directory for restoration was provided",
            )
        return os.path.join(self.restore_directory, filename)
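
The two path helpers above behave differently when no restore directory is set, which a small sketch may make easier to see. To stay runnable without tianshou installed, it uses a hypothetical stand-in class, `PathsOnly`, that copies only the two directory fields and the two methods from `World`. The helper `try_restore_path`, the directory names, and the filename `policy.pt` are likewise assumptions made for illustration and are not part of the file.

```python
import os
from dataclasses import dataclass
from typing import Optional


@dataclass
class PathsOnly:
    """Hypothetical stand-in for World that keeps only the two directory fields."""

    persist_directory: str
    restore_directory: Optional[str] = None

    def persist_path(self, filename: str) -> str:
        # Same logic as World.persist_path: always returns an absolute path.
        return os.path.abspath(os.path.join(self.persist_directory, filename))

    def restore_path(self, filename: str) -> str:
        # Same logic as World.restore_path: fails loudly if nothing can be restored.
        if self.restore_directory is None:
            raise ValueError(
                "Path cannot be formed because no directory for restoration was provided",
            )
        return os.path.join(self.restore_directory, filename)


def try_restore_path(paths: PathsOnly, filename: str) -> Optional[str]:
    """Illustrative helper: return the restore path, or None if no directory is set."""
    try:
        return paths.restore_path(filename)
    except ValueError:
        return None


# A fresh run has nowhere to restore from; a resumed run points at an earlier directory.
fresh = PathsOnly(persist_directory="log/run1")
resumed = PathsOnly(persist_directory="log/run2", restore_directory="log/run1")

print(fresh.persist_path("policy.pt"))
print(try_restore_path(fresh, "policy.pt"))
print(try_restore_path(resumed, "policy.pt"))
```

Note that `persist_path` normalizes to an absolute path, while `restore_path` returns the joined path exactly as configured, so a relative `restore_directory` yields a relative result.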