Add experiment result

This commit is contained in:
Dominik Jain 2023-10-13 16:01:11 +02:00
parent 023b33c917
commit 3bba192633

View File

@ -3,8 +3,8 @@ import pickle
from abc import abstractmethod
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from pprint import pprint
from typing import Generic, Self, TypeVar
from pprint import pformat
from typing import Any, Generic, Self, TypeVar
import numpy as np
import torch
@ -114,6 +114,12 @@ class ExperimentConfig:
"""Whether persistence is enabled, allowing files to be stored"""
@dataclass
class ExperimentResult:
    """Contains the outcome of a single call to ``Experiment.run``."""

    world: World
    """The world instance that was used during the run; if training was performed,
    the trainer has been attached to it (``world.trainer``)."""
    trainer_result: dict[str, Any] | None
    """The dictionary returned by the trainer's ``run`` method, or None if no training
    was performed (i.e. training was disabled in the experiment configuration)."""
class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
"""Represents a reinforcement learning experiment.
@ -172,7 +178,9 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
with open(path, "wb") as f:
pickle.dump(self, f)
def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None:
def run(
self, experiment_name: str | None = None, logger_run_id: str | None = None,
) -> ExperimentResult:
""":param experiment_name: the experiment name, which corresponds to the directory (within the logging
directory) where all results associated with the experiment will be saved.
The name may contain path separators (os.path.sep, used by os.path.join), in which case
@ -253,11 +261,12 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
)
# train policy
trainer_result: dict[str, Any] | None = None
if self.config.train:
trainer = self.agent_factory.create_trainer(world, policy_persistence)
world.trainer = trainer
trainer_result = trainer.run()
pprint(trainer_result) # TODO logging
log.info(f"Trainer result:\n{pformat(trainer_result)}")
# watch agent performance
if self.config.watch:
@ -268,7 +277,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
self.config.watch_render,
)
# TODO return result
return ExperimentResult(world=world, trainer_result=trainer_result)
@staticmethod
def _watch_agent(