Add experiment result
This commit is contained in:
parent
023b33c917
commit
3bba192633
@ -3,8 +3,8 @@ import pickle
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pprint import pprint
|
from pprint import pformat
|
||||||
from typing import Generic, Self, TypeVar
|
from typing import Any, Generic, Self, TypeVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -114,6 +114,12 @@ class ExperimentConfig:
|
|||||||
"""Whether persistence is enabled, allowing files to be stored"""
|
"""Whether persistence is enabled, allowing files to be stored"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExperimentResult:
|
||||||
|
world: World
|
||||||
|
trainer_result: dict[str, Any] | None
|
||||||
|
|
||||||
|
|
||||||
class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
||||||
"""Represents a reinforcement learning experiment.
|
"""Represents a reinforcement learning experiment.
|
||||||
|
|
||||||
@ -172,7 +178,9 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
|||||||
with open(path, "wb") as f:
|
with open(path, "wb") as f:
|
||||||
pickle.dump(self, f)
|
pickle.dump(self, f)
|
||||||
|
|
||||||
def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None:
|
def run(
|
||||||
|
self, experiment_name: str | None = None, logger_run_id: str | None = None,
|
||||||
|
) -> ExperimentResult:
|
||||||
""":param experiment_name: the experiment name, which corresponds to the directory (within the logging
|
""":param experiment_name: the experiment name, which corresponds to the directory (within the logging
|
||||||
directory) where all results associated with the experiment will be saved.
|
directory) where all results associated with the experiment will be saved.
|
||||||
The name may contain path separators (os.path.sep, used by os.path.join), in which case
|
The name may contain path separators (os.path.sep, used by os.path.join), in which case
|
||||||
@ -253,11 +261,12 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# train policy
|
# train policy
|
||||||
|
trainer_result: dict[str, Any] | None = None
|
||||||
if self.config.train:
|
if self.config.train:
|
||||||
trainer = self.agent_factory.create_trainer(world, policy_persistence)
|
trainer = self.agent_factory.create_trainer(world, policy_persistence)
|
||||||
world.trainer = trainer
|
world.trainer = trainer
|
||||||
trainer_result = trainer.run()
|
trainer_result = trainer.run()
|
||||||
pprint(trainer_result) # TODO logging
|
log.info(f"Trainer result:\n{pformat(trainer_result)}")
|
||||||
|
|
||||||
# watch agent performance
|
# watch agent performance
|
||||||
if self.config.watch:
|
if self.config.watch:
|
||||||
@ -268,7 +277,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
|||||||
self.config.watch_render,
|
self.config.watch_render,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO return result
|
return ExperimentResult(world=world, trainer_result=trainer_result)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _watch_agent(
|
def _watch_agent(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user