From 3bba19263354bebd477a1a63bb46054a068a88e4 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 13 Oct 2023 16:01:11 +0200 Subject: [PATCH] Add experiment result --- tianshou/highlevel/experiment.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index a5c72f0..ed2575d 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -3,8 +3,8 @@ import pickle from abc import abstractmethod from collections.abc import Callable, Sequence from dataclasses import dataclass -from pprint import pprint -from typing import Generic, Self, TypeVar +from pprint import pformat +from typing import Any, Generic, Self, TypeVar import numpy as np import torch @@ -114,6 +114,12 @@ class ExperimentConfig: """Whether persistence is enabled, allowing files to be stored""" +@dataclass +class ExperimentResult: + world: World + trainer_result: dict[str, Any] | None + + class Experiment(Generic[TPolicy, TTrainer], ToStringMixin): """Represents a reinforcement learning experiment. @@ -172,7 +178,9 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin): with open(path, "wb") as f: pickle.dump(self, f) - def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None: + def run( + self, experiment_name: str | None = None, logger_run_id: str | None = None, + ) -> ExperimentResult: """:param experiment_name: the experiment name, which corresponds to the directory (within the logging directory) where all results associated with the experiment will be saved. The name may contain path separators (os.path.sep, used by os.path.join), in which case @@ -253,11 +261,12 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin): ) # train policy + trainer_result: dict[str, Any] | None = None if self.config.train: trainer = self.agent_factory.create_trainer(world, policy_persistence) world.trainer = trainer trainer_result = trainer.run() - pprint(trainer_result) # TODO logging + log.info(f"Trainer result:\n{pformat(trainer_result)}") # watch agent performance if self.config.watch: @@ -268,7 +277,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin): self.config.watch_render, ) - # TODO return result + return ExperimentResult(world=world, trainer_result=trainer_result) @staticmethod def _watch_agent(