Improve persistence handling
* Add persistence/restoration of Experiment instance
* Add file logging in experiment
* Allow all persistence/logging to be disabled
* Disable persistence in tests
This commit is contained in:
parent
ba803296cc
commit
76e870207d
examples
test/highlevel
tianshou/highlevel
@ -22,7 +22,7 @@ class TrainEpochCallbackNatureDQNEpsLinearDecay(TrainerEpochCallbackTrain):
|
|||||||
|
|
||||||
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
|
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
|
||||||
policy: DQNPolicy = context.policy
|
policy: DQNPolicy = context.policy
|
||||||
logger = context.logger.logger
|
logger = context.logger
|
||||||
# nature DQN setting, linear decay in the first 1M steps
|
# nature DQN setting, linear decay in the first 1M steps
|
||||||
if env_step <= 1e6:
|
if env_step <= 1e6:
|
||||||
eps = self.eps_train - env_step / 1e6 * (self.eps_train - self.eps_train_final)
|
eps = self.eps_train - env_step / 1e6 * (self.eps_train - self.eps_train_final)
|
||||||
|
@ -81,7 +81,7 @@ def main(
|
|||||||
class TrainEpochCallback(TrainerEpochCallbackTrain):
|
class TrainEpochCallback(TrainerEpochCallbackTrain):
|
||||||
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
|
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
|
||||||
policy: DQNPolicy = context.policy
|
policy: DQNPolicy = context.policy
|
||||||
logger = context.logger.logger
|
logger = context.logger
|
||||||
# nature DQN setting, linear decay in the first 1M steps
|
# nature DQN setting, linear decay in the first 1M steps
|
||||||
if env_step <= 1e6:
|
if env_step <= 1e6:
|
||||||
eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final)
|
eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final)
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
|
import logging
|
||||||
import pickle
|
import pickle
|
||||||
import warnings
|
import warnings
|
||||||
import logging
|
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
|
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
|
||||||
from tianshou.highlevel.persistence import Persistence, RestoreEvent, PersistEvent
|
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
|
||||||
from tianshou.highlevel.world import World
|
from tianshou.highlevel.world import World
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -33,7 +33,7 @@ from tianshou.highlevel.experiment import (
|
|||||||
def test_experiment_builder_continuous_default_params(builder_cls):
|
def test_experiment_builder_continuous_default_params(builder_cls):
|
||||||
env_factory = ContinuousTestEnvFactory()
|
env_factory = ContinuousTestEnvFactory()
|
||||||
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
|
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
|
||||||
experiment_config = ExperimentConfig()
|
experiment_config = ExperimentConfig(persistence_enabled=False)
|
||||||
builder = builder_cls(
|
builder = builder_cls(
|
||||||
experiment_config=experiment_config,
|
experiment_config=experiment_config,
|
||||||
env_factory=env_factory,
|
env_factory=env_factory,
|
||||||
|
@ -11,7 +11,6 @@ from tianshou.highlevel.experiment import (
|
|||||||
IQNExperimentBuilder,
|
IQNExperimentBuilder,
|
||||||
PPOExperimentBuilder,
|
PPOExperimentBuilder,
|
||||||
)
|
)
|
||||||
from tianshou.utils import logging
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -25,11 +24,10 @@ from tianshou.utils import logging
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_experiment_builder_discrete_default_params(builder_cls):
|
def test_experiment_builder_discrete_default_params(builder_cls):
|
||||||
logging.configure()
|
|
||||||
env_factory = DiscreteTestEnvFactory()
|
env_factory = DiscreteTestEnvFactory()
|
||||||
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
|
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
|
||||||
builder = builder_cls(
|
builder = builder_cls(
|
||||||
experiment_config=ExperimentConfig(),
|
experiment_config=ExperimentConfig(persistence_enabled=False),
|
||||||
env_factory=env_factory,
|
env_factory=env_factory,
|
||||||
sampling_config=sampling_config,
|
sampling_config=sampling_config,
|
||||||
)
|
)
|
||||||
|
@ -171,7 +171,7 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
|
|||||||
batch_size=sampling_config.batch_size,
|
batch_size=sampling_config.batch_size,
|
||||||
step_per_collect=sampling_config.step_per_collect,
|
step_per_collect=sampling_config.step_per_collect,
|
||||||
save_best_fn=policy_persistence.get_save_best_fn(world),
|
save_best_fn=policy_persistence.get_save_best_fn(world),
|
||||||
logger=world.logger.logger,
|
logger=world.logger,
|
||||||
test_in_train=False,
|
test_in_train=False,
|
||||||
train_fn=train_fn,
|
train_fn=train_fn,
|
||||||
test_fn=test_fn,
|
test_fn=test_fn,
|
||||||
@ -211,7 +211,7 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
|
|||||||
episode_per_test=sampling_config.num_test_envs,
|
episode_per_test=sampling_config.num_test_envs,
|
||||||
batch_size=sampling_config.batch_size,
|
batch_size=sampling_config.batch_size,
|
||||||
save_best_fn=policy_persistence.get_save_best_fn(world),
|
save_best_fn=policy_persistence.get_save_best_fn(world),
|
||||||
logger=world.logger.logger,
|
logger=world.logger,
|
||||||
update_per_step=sampling_config.update_per_step,
|
update_per_step=sampling_config.update_per_step,
|
||||||
test_in_train=False,
|
test_in_train=False,
|
||||||
train_fn=train_fn,
|
train_fn=train_fn,
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import pickle
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -65,7 +65,11 @@ from tianshou.highlevel.params.policy_params import (
|
|||||||
TRPOParams,
|
TRPOParams,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
|
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
|
||||||
from tianshou.highlevel.persistence import PersistableConfigProtocol, PolicyPersistence, PersistenceGroup
|
from tianshou.highlevel.persistence import (
|
||||||
|
PersistableConfigProtocol,
|
||||||
|
PersistenceGroup,
|
||||||
|
PolicyPersistence,
|
||||||
|
)
|
||||||
from tianshou.highlevel.trainer import (
|
from tianshou.highlevel.trainer import (
|
||||||
TrainerCallbacks,
|
TrainerCallbacks,
|
||||||
TrainerEpochCallbackTest,
|
TrainerEpochCallbackTest,
|
||||||
@ -75,6 +79,7 @@ from tianshou.highlevel.trainer import (
|
|||||||
from tianshou.highlevel.world import World
|
from tianshou.highlevel.world import World
|
||||||
from tianshou.policy import BasePolicy
|
from tianshou.policy import BasePolicy
|
||||||
from tianshou.trainer import BaseTrainer
|
from tianshou.trainer import BaseTrainer
|
||||||
|
from tianshou.utils import LazyLogger, logging
|
||||||
from tianshou.utils.logging import datetime_tag
|
from tianshou.utils.logging import datetime_tag
|
||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
@ -93,7 +98,7 @@ class ExperimentConfig:
|
|||||||
device: TDevice = "cuda" if torch.cuda.is_available() else "cpu"
|
device: TDevice = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
"""The torch device to use"""
|
"""The torch device to use"""
|
||||||
policy_restore_directory: str | None = None
|
policy_restore_directory: str | None = None
|
||||||
"""Directory from which to load the policy neural network parameters (saved in a previous run)"""
|
"""Directory from which to load the policy neural network parameters (persistence directory of a previous run)"""
|
||||||
train: bool = True
|
train: bool = True
|
||||||
"""Whether to perform training"""
|
"""Whether to perform training"""
|
||||||
watch: bool = True
|
watch: bool = True
|
||||||
@ -102,9 +107,24 @@ class ExperimentConfig:
|
|||||||
"""Number of episodes for which to watch performance (if watch is enabled)"""
|
"""Number of episodes for which to watch performance (if watch is enabled)"""
|
||||||
watch_render: float = 0.0
|
watch_render: float = 0.0
|
||||||
"""Milliseconds between rendered frames when watching agent performance (if watch is enabled)"""
|
"""Milliseconds between rendered frames when watching agent performance (if watch is enabled)"""
|
||||||
|
persistence_base_dir: str = "log"
|
||||||
|
"""Base directory in which experiment data is to be stored. Every experiment run will create a subdirectory
|
||||||
|
in this directory based on the run's experiment name"""
|
||||||
|
persistence_enabled: bool = True
|
||||||
|
"""Whether persistence is enabled, allowing files to be stored"""
|
||||||
|
|
||||||
|
|
||||||
class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
||||||
|
"""Represents a reinforcement learning experiment.
|
||||||
|
|
||||||
|
An experiment is composed only of configuration and factory objects, which themselves
|
||||||
|
should be designed to contain only configuration. Therefore, experiments can easily
|
||||||
|
be stored/pickled and later restored without any problems.
|
||||||
|
"""
|
||||||
|
|
||||||
|
LOG_FILENAME = "log.txt"
|
||||||
|
EXPERIMENT_PICKLE_FILENAME = "experiment.pkl"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: ExperimentConfig,
|
config: ExperimentConfig,
|
||||||
@ -121,15 +141,31 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
|||||||
self.logger_factory = logger_factory
|
self.logger_factory = logger_factory
|
||||||
self.env_config = env_config
|
self.env_config = env_config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_directory(cls, directory: str) -> Self:
|
||||||
|
"""Restores an experiment from a previously stored pickle.
|
||||||
|
|
||||||
|
:param directory: persistence directory of a previous run, in which a pickled experiment is found
|
||||||
|
"""
|
||||||
|
with open(os.path.join(directory, cls.EXPERIMENT_PICKLE_FILENAME), "rb") as f:
|
||||||
|
return pickle.load(f)
|
||||||
|
|
||||||
def _set_seed(self) -> None:
|
def _set_seed(self) -> None:
|
||||||
seed = self.config.seed
|
seed = self.config.seed
|
||||||
|
log.info(f"Setting random seed {seed}")
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
def _build_config_dict(self) -> dict:
|
def _build_config_dict(self) -> dict:
|
||||||
return {
|
return {"experiment": self.pprints()}
|
||||||
"experiment": self.pprints()
|
|
||||||
}
|
def save(self, directory: str):
|
||||||
|
path = os.path.join(directory, self.EXPERIMENT_PICKLE_FILENAME)
|
||||||
|
log.info(
|
||||||
|
f"Saving serialized experiment in {path}; can be restored via Experiment.from_directory('{directory}')",
|
||||||
|
)
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
pickle.dump(self, f)
|
||||||
|
|
||||||
def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None:
|
def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None:
|
||||||
""":param experiment_name: the experiment name, which corresponds to the directory (within the logging
|
""":param experiment_name: the experiment name, which corresponds to the directory (within the logging
|
||||||
@ -144,6 +180,17 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
|||||||
if experiment_name is None:
|
if experiment_name is None:
|
||||||
experiment_name = datetime_tag()
|
experiment_name = datetime_tag()
|
||||||
|
|
||||||
|
# initialize persistence directory
|
||||||
|
use_persistence = self.config.persistence_enabled
|
||||||
|
persistence_dir = os.path.join(self.config.persistence_base_dir, experiment_name)
|
||||||
|
if use_persistence:
|
||||||
|
os.makedirs(persistence_dir, exist_ok=True)
|
||||||
|
|
||||||
|
with logging.FileLoggerContext(
|
||||||
|
os.path.join(persistence_dir, self.LOG_FILENAME), enabled=use_persistence,
|
||||||
|
):
|
||||||
|
# log initial information
|
||||||
|
log.info(f"Running experiment (name='{experiment_name}'):\n{self.pprints()}")
|
||||||
log.info(f"Working directory: {os.getcwd()}")
|
log.info(f"Working directory: {os.getcwd()}")
|
||||||
|
|
||||||
self._set_seed()
|
self._set_seed()
|
||||||
@ -153,34 +200,44 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
|||||||
log.info(f"Created {envs}")
|
log.info(f"Created {envs}")
|
||||||
|
|
||||||
# initialize persistence
|
# initialize persistence
|
||||||
additional_persistence = PersistenceGroup(*envs.persistence)
|
additional_persistence = PersistenceGroup(*envs.persistence, enabled=use_persistence)
|
||||||
policy_persistence = PolicyPersistence(additional_persistence)
|
policy_persistence = PolicyPersistence(additional_persistence, enabled=use_persistence)
|
||||||
|
if use_persistence:
|
||||||
|
log.info(f"Persistence directory: {os.path.abspath(persistence_dir)}")
|
||||||
|
self.save(persistence_dir)
|
||||||
|
|
||||||
# initialize logger
|
# initialize logger
|
||||||
full_config = self._build_config_dict()
|
full_config = self._build_config_dict()
|
||||||
full_config.update(envs.info())
|
full_config.update(envs.info())
|
||||||
|
if use_persistence:
|
||||||
logger = self.logger_factory.create_logger(
|
logger = self.logger_factory.create_logger(
|
||||||
log_name=experiment_name,
|
log_dir=persistence_dir,
|
||||||
|
experiment_name=experiment_name,
|
||||||
run_id=logger_run_id,
|
run_id=logger_run_id,
|
||||||
config_dict=full_config,
|
config_dict=full_config,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
logger = LazyLogger()
|
||||||
|
|
||||||
|
# create policy and collectors
|
||||||
policy = self.agent_factory.create_policy(envs, self.config.device)
|
policy = self.agent_factory.create_policy(envs, self.config.device)
|
||||||
|
|
||||||
train_collector, test_collector = self.agent_factory.create_train_test_collector(
|
train_collector, test_collector = self.agent_factory.create_train_test_collector(
|
||||||
policy,
|
policy,
|
||||||
envs,
|
envs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# create context object with all relevant instances (except trainer; added later)
|
||||||
world = World(
|
world = World(
|
||||||
envs=envs,
|
envs=envs,
|
||||||
policy=policy,
|
policy=policy,
|
||||||
train_collector=train_collector,
|
train_collector=train_collector,
|
||||||
test_collector=test_collector,
|
test_collector=test_collector,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
restore_directory=self.config.policy_restore_directory
|
persist_directory=persistence_dir,
|
||||||
|
restore_directory=self.config.policy_restore_directory,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# restore policy parameters if applicable
|
||||||
if self.config.policy_restore_directory:
|
if self.config.policy_restore_directory:
|
||||||
policy_persistence.restore(
|
policy_persistence.restore(
|
||||||
policy,
|
policy,
|
||||||
@ -188,13 +245,14 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
|||||||
self.config.device,
|
self.config.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# train policy
|
||||||
if self.config.train:
|
if self.config.train:
|
||||||
trainer = self.agent_factory.create_trainer(world, policy_persistence)
|
trainer = self.agent_factory.create_trainer(world, policy_persistence)
|
||||||
world.trainer = trainer
|
world.trainer = trainer
|
||||||
|
trainer_result = trainer.run()
|
||||||
|
pprint(trainer_result) # TODO logging
|
||||||
|
|
||||||
result = trainer.run()
|
# watch agent performance
|
||||||
pprint(result) # TODO logging
|
|
||||||
|
|
||||||
if self.config.watch:
|
if self.config.watch:
|
||||||
self._watch_agent(
|
self._watch_agent(
|
||||||
self.config.watch_num_episodes,
|
self.config.watch_num_episodes,
|
||||||
@ -304,7 +362,6 @@ class ExperimentBuilder:
|
|||||||
self._logger_factory,
|
self._logger_factory,
|
||||||
env_config=self._env_config,
|
env_config=self._env_config,
|
||||||
)
|
)
|
||||||
log.info(f"Created experiment:\n{experiment.pprints()}")
|
|
||||||
return experiment
|
return experiment
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Literal, TypeAlias
|
from typing import Literal, TypeAlias
|
||||||
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
@ -11,38 +10,39 @@ from tianshou.utils.string import ToStringMixin
|
|||||||
TLogger: TypeAlias = TensorboardLogger | WandbLogger
|
TLogger: TypeAlias = TensorboardLogger | WandbLogger
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Logger:
|
|
||||||
logger: TLogger
|
|
||||||
log_path: str
|
|
||||||
|
|
||||||
|
|
||||||
class LoggerFactory(ToStringMixin, ABC):
|
class LoggerFactory(ToStringMixin, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
|
def create_logger(
|
||||||
pass
|
self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict,
|
||||||
|
) -> TLogger:
|
||||||
|
""":param log_dir: path to the directory in which log data is to be stored
|
||||||
|
:param experiment_name: the name of the job, which may contain os.path.sep
|
||||||
|
:param run_id: a unique name, which, depending on the logging framework, may be used to identify the logger
|
||||||
|
:param config_dict: a dictionary with data that is to be logged
|
||||||
|
:return: the logger
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class DefaultLoggerFactory(LoggerFactory):
|
class DefaultLoggerFactory(LoggerFactory):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
log_dir: str = "log",
|
|
||||||
logger_type: Literal["tensorboard", "wandb"] = "tensorboard",
|
logger_type: Literal["tensorboard", "wandb"] = "tensorboard",
|
||||||
wandb_project: str | None = None,
|
wandb_project: str | None = None,
|
||||||
):
|
):
|
||||||
if logger_type == "wandb" and wandb_project is None:
|
if logger_type == "wandb" and wandb_project is None:
|
||||||
raise ValueError("Must provide 'wandb_project'")
|
raise ValueError("Must provide 'wandb_project'")
|
||||||
self.log_dir = log_dir
|
|
||||||
self.logger_type = logger_type
|
self.logger_type = logger_type
|
||||||
self.wandb_project = wandb_project
|
self.wandb_project = wandb_project
|
||||||
|
|
||||||
def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
|
def create_logger(
|
||||||
writer = SummaryWriter(self.log_dir)
|
self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict,
|
||||||
|
) -> TLogger:
|
||||||
|
writer = SummaryWriter(log_dir)
|
||||||
writer.add_text(
|
writer.add_text(
|
||||||
"args",
|
"args",
|
||||||
str(
|
str(
|
||||||
dict(
|
dict(
|
||||||
log_dir=self.log_dir,
|
log_dir=log_dir,
|
||||||
logger_type=self.logger_type,
|
logger_type=self.logger_type,
|
||||||
wandb_project=self.wandb_project,
|
wandb_project=self.wandb_project,
|
||||||
),
|
),
|
||||||
@ -52,7 +52,7 @@ class DefaultLoggerFactory(LoggerFactory):
|
|||||||
if self.logger_type == "wandb":
|
if self.logger_type == "wandb":
|
||||||
logger = WandbLogger(
|
logger = WandbLogger(
|
||||||
save_interval=1,
|
save_interval=1,
|
||||||
name=log_name.replace(os.path.sep, "__"),
|
name=experiment_name.replace(os.path.sep, "__"),
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
config=config_dict,
|
config=config_dict,
|
||||||
project=self.wandb_project,
|
project=self.wandb_project,
|
||||||
@ -62,6 +62,4 @@ class DefaultLoggerFactory(LoggerFactory):
|
|||||||
logger = TensorboardLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown logger type '{self.logger_type}'")
|
raise ValueError(f"Unknown logger type '{self.logger_type}'")
|
||||||
log_path = os.path.join(self.log_dir, log_name)
|
return logger
|
||||||
os.makedirs(log_path, exist_ok=True)
|
|
||||||
return Logger(logger=logger, log_path=log_path)
|
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Callable
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable, Callable
|
from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -25,11 +26,15 @@ class PersistableConfigProtocol(Protocol):
|
|||||||
|
|
||||||
|
|
||||||
class PersistEvent(Enum):
|
class PersistEvent(Enum):
|
||||||
|
"""Enumeration of persistence events that Persistence objects can react to."""
|
||||||
|
|
||||||
PERSIST_POLICY = "persist_policy"
|
PERSIST_POLICY = "persist_policy"
|
||||||
"""Policy neural network is persisted (new best found)"""
|
"""Policy neural network is persisted (new best found)"""
|
||||||
|
|
||||||
|
|
||||||
class RestoreEvent(Enum):
|
class RestoreEvent(Enum):
|
||||||
|
"""Enumeration of restoration events that Persistence objects can react to."""
|
||||||
|
|
||||||
RESTORE_POLICY = "restore_policy"
|
RESTORE_POLICY = "restore_policy"
|
||||||
"""Policy neural network parameters are restored"""
|
"""Policy neural network parameters are restored"""
|
||||||
|
|
||||||
@ -45,10 +50,13 @@ class Persistence(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class PersistenceGroup(Persistence):
|
class PersistenceGroup(Persistence):
|
||||||
def __init__(self, *p: Persistence):
|
def __init__(self, *p: Persistence, enabled=True):
|
||||||
self.items = p
|
self.items = p
|
||||||
|
self.enabled = enabled
|
||||||
|
|
||||||
def persist(self, event: PersistEvent, world: World) -> None:
|
def persist(self, event: PersistEvent, world: World) -> None:
|
||||||
|
if not self.enabled:
|
||||||
|
return
|
||||||
for item in self.items:
|
for item in self.items:
|
||||||
item.persist(event, world)
|
item.persist(event, world)
|
||||||
|
|
||||||
@ -60,10 +68,17 @@ class PersistenceGroup(Persistence):
|
|||||||
class PolicyPersistence:
|
class PolicyPersistence:
|
||||||
FILENAME = "policy.dat"
|
FILENAME = "policy.dat"
|
||||||
|
|
||||||
def __init__(self, additional_persistence: Persistence | None = None):
|
def __init__(self, additional_persistence: Persistence | None = None, enabled=True):
|
||||||
|
""":param additional_persistence: a persistence instance which is to be invoked whenever
|
||||||
|
this object is used to persist/restore data
|
||||||
|
:param enabled: whether persistence is enabled (restoration is always enabled)
|
||||||
|
"""
|
||||||
self.additional_persistence = additional_persistence
|
self.additional_persistence = additional_persistence
|
||||||
|
self.enabled = enabled
|
||||||
|
|
||||||
def persist(self, policy: torch.nn.Module, world: World) -> None:
|
def persist(self, policy: torch.nn.Module, world: World) -> None:
|
||||||
|
if not self.enabled:
|
||||||
|
return
|
||||||
path = world.persist_path(self.FILENAME)
|
path = world.persist_path(self.FILENAME)
|
||||||
log.info(f"Saving policy in {path}")
|
log.info(f"Saving policy in {path}")
|
||||||
torch.save(policy.state_dict(), path)
|
torch.save(policy.state_dict(), path)
|
||||||
|
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
|||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
from tianshou.highlevel.env import Environments
|
from tianshou.highlevel.env import Environments
|
||||||
from tianshou.highlevel.logger import Logger
|
from tianshou.highlevel.logger import TLogger
|
||||||
from tianshou.policy import BasePolicy
|
from tianshou.policy import BasePolicy
|
||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
@ -12,7 +12,7 @@ TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
|||||||
|
|
||||||
|
|
||||||
class TrainingContext:
|
class TrainingContext:
|
||||||
def __init__(self, policy: TPolicy, envs: Environments, logger: Logger):
|
def __init__(self, policy: TPolicy, envs: Environments, logger: TLogger):
|
||||||
self.policy = policy
|
self.policy = policy
|
||||||
self.envs = envs
|
self.envs = envs
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tianshou.data import Collector
|
from tianshou.data import Collector
|
||||||
from tianshou.highlevel.env import Environments
|
from tianshou.highlevel.env import Environments
|
||||||
from tianshou.highlevel.logger import Logger
|
from tianshou.highlevel.logger import TLogger
|
||||||
from tianshou.policy import BasePolicy
|
from tianshou.policy import BasePolicy
|
||||||
from tianshou.trainer import BaseTrainer
|
from tianshou.trainer import BaseTrainer
|
||||||
|
|
||||||
@ -16,15 +16,11 @@ class World:
|
|||||||
policy: "BasePolicy"
|
policy: "BasePolicy"
|
||||||
train_collector: "Collector"
|
train_collector: "Collector"
|
||||||
test_collector: "Collector"
|
test_collector: "Collector"
|
||||||
logger: "Logger"
|
logger: "TLogger"
|
||||||
|
persist_directory: str
|
||||||
restore_directory: str
|
restore_directory: str
|
||||||
trainer: Optional["BaseTrainer"] = None
|
trainer: Optional["BaseTrainer"] = None
|
||||||
|
|
||||||
|
|
||||||
@property
|
|
||||||
def persist_directory(self):
|
|
||||||
return self.logger.log_path
|
|
||||||
|
|
||||||
def persist_path(self, filename: str) -> str:
|
def persist_path(self, filename: str) -> str:
|
||||||
return os.path.join(self.persist_directory, filename)
|
return os.path.join(self.persist_directory, filename)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user