Add experiment result

This commit is contained in:
Dominik Jain 2023-10-13 16:01:11 +02:00
parent 023b33c917
commit 3bba192633

View File

@@ -3,8 +3,8 @@ import pickle
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from pprint import pprint from pprint import pformat
from typing import Generic, Self, TypeVar from typing import Any, Generic, Self, TypeVar
import numpy as np import numpy as np
import torch import torch
@@ -114,6 +114,12 @@ class ExperimentConfig:
"""Whether persistence is enabled, allowing files to be stored""" """Whether persistence is enabled, allowing files to be stored"""
@dataclass
class ExperimentResult:
world: World
trainer_result: dict[str, Any] | None
class Experiment(Generic[TPolicy, TTrainer], ToStringMixin): class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
"""Represents a reinforcement learning experiment. """Represents a reinforcement learning experiment.
@@ -172,7 +178,9 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
with open(path, "wb") as f: with open(path, "wb") as f:
pickle.dump(self, f) pickle.dump(self, f)
def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None: def run(
self, experiment_name: str | None = None, logger_run_id: str | None = None,
) -> ExperimentResult:
""":param experiment_name: the experiment name, which corresponds to the directory (within the logging """:param experiment_name: the experiment name, which corresponds to the directory (within the logging
directory) where all results associated with the experiment will be saved. directory) where all results associated with the experiment will be saved.
The name may contain path separators (os.path.sep, used by os.path.join), in which case The name may contain path separators (os.path.sep, used by os.path.join), in which case
@@ -253,11 +261,12 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
) )
# train policy # train policy
trainer_result: dict[str, Any] | None = None
if self.config.train: if self.config.train:
trainer = self.agent_factory.create_trainer(world, policy_persistence) trainer = self.agent_factory.create_trainer(world, policy_persistence)
world.trainer = trainer world.trainer = trainer
trainer_result = trainer.run() trainer_result = trainer.run()
pprint(trainer_result) # TODO logging log.info(f"Trainer result:\n{pformat(trainer_result)}")
# watch agent performance # watch agent performance
if self.config.watch: if self.config.watch:
@@ -268,7 +277,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
self.config.watch_render, self.config.watch_render,
) )
# TODO return result return ExperimentResult(world=world, trainer_result=trainer_result)
@staticmethod @staticmethod
def _watch_agent( def _watch_agent(