Add experiment result
This commit is contained in:
parent
023b33c917
commit
3bba192633
@ -3,8 +3,8 @@ import pickle
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from pprint import pprint
|
||||
from typing import Generic, Self, TypeVar
|
||||
from pprint import pformat
|
||||
from typing import Any, Generic, Self, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -114,6 +114,12 @@ class ExperimentConfig:
|
||||
"""Whether persistence is enabled, allowing files to be stored"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExperimentResult:
|
||||
world: World
|
||||
trainer_result: dict[str, Any] | None
|
||||
|
||||
|
||||
class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
||||
"""Represents a reinforcement learning experiment.
|
||||
|
||||
@ -172,7 +178,9 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
||||
with open(path, "wb") as f:
|
||||
pickle.dump(self, f)
|
||||
|
||||
def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None:
|
||||
def run(
|
||||
self, experiment_name: str | None = None, logger_run_id: str | None = None,
|
||||
) -> ExperimentResult:
|
||||
""":param experiment_name: the experiment name, which corresponds to the directory (within the logging
|
||||
directory) where all results associated with the experiment will be saved.
|
||||
The name may contain path separators (os.path.sep, used by os.path.join), in which case
|
||||
@ -253,11 +261,12 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
||||
)
|
||||
|
||||
# train policy
|
||||
trainer_result: dict[str, Any] | None = None
|
||||
if self.config.train:
|
||||
trainer = self.agent_factory.create_trainer(world, policy_persistence)
|
||||
world.trainer = trainer
|
||||
trainer_result = trainer.run()
|
||||
pprint(trainer_result) # TODO logging
|
||||
log.info(f"Trainer result:\n{pformat(trainer_result)}")
|
||||
|
||||
# watch agent performance
|
||||
if self.config.watch:
|
||||
@ -268,7 +277,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
||||
self.config.watch_render,
|
||||
)
|
||||
|
||||
# TODO return result
|
||||
return ExperimentResult(world=world, trainer_result=trainer_result)
|
||||
|
||||
@staticmethod
|
||||
def _watch_agent(
|
||||
|
Loading…
x
Reference in New Issue
Block a user