Improve persistence handling

* Add persistence/restoration of Experiment instance
* Add file logging in experiment
* Allow all persistence/logging to be disabled
* Disable persistence in tests
This commit is contained in:
Dominik Jain 2023-10-12 17:40:16 +02:00
parent ba803296cc
commit 76e870207d
11 changed files with 160 additions and 96 deletions

View File

@ -22,7 +22,7 @@ class TrainEpochCallbackNatureDQNEpsLinearDecay(TrainerEpochCallbackTrain):
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
policy: DQNPolicy = context.policy
logger = context.logger.logger
logger = context.logger
# nature DQN setting, linear decay in the first 1M steps
if env_step <= 1e6:
eps = self.eps_train - env_step / 1e6 * (self.eps_train - self.eps_train_final)

View File

@ -81,7 +81,7 @@ def main(
class TrainEpochCallback(TrainerEpochCallbackTrain):
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
policy: DQNPolicy = context.policy
logger = context.logger.logger
logger = context.logger
# nature DQN setting, linear decay in the first 1M steps
if env_step <= 1e6:
eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final)

View File

@ -1,13 +1,13 @@
import logging
import pickle
import warnings
import logging
import gymnasium as gym
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
from tianshou.highlevel.persistence import Persistence, RestoreEvent, PersistEvent
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
from tianshou.highlevel.world import World
try:

View File

@ -33,7 +33,7 @@ from tianshou.highlevel.experiment import (
def test_experiment_builder_continuous_default_params(builder_cls):
env_factory = ContinuousTestEnvFactory()
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
experiment_config = ExperimentConfig()
experiment_config = ExperimentConfig(persistence_enabled=False)
builder = builder_cls(
experiment_config=experiment_config,
env_factory=env_factory,

View File

@ -11,7 +11,6 @@ from tianshou.highlevel.experiment import (
IQNExperimentBuilder,
PPOExperimentBuilder,
)
from tianshou.utils import logging
@pytest.mark.parametrize(
@ -25,11 +24,10 @@ from tianshou.utils import logging
],
)
def test_experiment_builder_discrete_default_params(builder_cls):
logging.configure()
env_factory = DiscreteTestEnvFactory()
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
builder = builder_cls(
experiment_config=ExperimentConfig(),
experiment_config=ExperimentConfig(persistence_enabled=False),
env_factory=env_factory,
sampling_config=sampling_config,
)

View File

@ -171,7 +171,7 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
batch_size=sampling_config.batch_size,
step_per_collect=sampling_config.step_per_collect,
save_best_fn=policy_persistence.get_save_best_fn(world),
logger=world.logger.logger,
logger=world.logger,
test_in_train=False,
train_fn=train_fn,
test_fn=test_fn,
@ -211,7 +211,7 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
episode_per_test=sampling_config.num_test_envs,
batch_size=sampling_config.batch_size,
save_best_fn=policy_persistence.get_save_best_fn(world),
logger=world.logger.logger,
logger=world.logger,
update_per_step=sampling_config.update_per_step,
test_in_train=False,
train_fn=train_fn,

View File

@ -1,5 +1,5 @@
import os
import logging
import pickle
from abc import abstractmethod
from collections.abc import Callable, Sequence
from dataclasses import dataclass
@ -65,7 +65,11 @@ from tianshou.highlevel.params.policy_params import (
TRPOParams,
)
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.highlevel.persistence import PersistableConfigProtocol, PolicyPersistence, PersistenceGroup
from tianshou.highlevel.persistence import (
PersistableConfigProtocol,
PersistenceGroup,
PolicyPersistence,
)
from tianshou.highlevel.trainer import (
TrainerCallbacks,
TrainerEpochCallbackTest,
@ -75,6 +79,7 @@ from tianshou.highlevel.trainer import (
from tianshou.highlevel.world import World
from tianshou.policy import BasePolicy
from tianshou.trainer import BaseTrainer
from tianshou.utils import LazyLogger, logging
from tianshou.utils.logging import datetime_tag
from tianshou.utils.string import ToStringMixin
@ -93,7 +98,7 @@ class ExperimentConfig:
device: TDevice = "cuda" if torch.cuda.is_available() else "cpu"
"""The torch device to use"""
policy_restore_directory: str | None = None
"""Directory from which to load the policy neural network parameters (saved in a previous run)"""
"""Directory from which to load the policy neural network parameters (persistence directory of a previous run)"""
train: bool = True
"""Whether to perform training"""
watch: bool = True
@ -102,9 +107,24 @@ class ExperimentConfig:
"""Number of episodes for which to watch performance (if watch is enabled)"""
watch_render: float = 0.0
"""Milliseconds between rendered frames when watching agent performance (if watch is enabled)"""
persistence_base_dir: str = "log"
"""Base directory in which experiment data is to be stored. Every experiment run will create a subdirectory
in this directory based on the run's experiment name"""
persistence_enabled: bool = True
"""Whether persistence is enabled, allowing files to be stored"""
class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
"""Represents a reinforcement learning experiment.
An experiment is composed only of configuration and factory objects, which themselves
should be designed to contain only configuration. Therefore, experiments can easily
be stored/pickled and later restored without any problems.
"""
LOG_FILENAME = "log.txt"
EXPERIMENT_PICKLE_FILENAME = "experiment.pkl"
def __init__(
self,
config: ExperimentConfig,
@ -121,15 +141,31 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
self.logger_factory = logger_factory
self.env_config = env_config
@classmethod
def from_directory(cls, directory: str) -> Self:
"""Restores an experiment from a previously stored pickle.
:param directory: persistence directory of a previous run, in which a pickled experiment is found
"""
with open(os.path.join(directory, cls.EXPERIMENT_PICKLE_FILENAME), "rb") as f:
return pickle.load(f)
def _set_seed(self) -> None:
seed = self.config.seed
log.info(f"Setting random seed {seed}")
np.random.seed(seed)
torch.manual_seed(seed)
def _build_config_dict(self) -> dict:
return {
"experiment": self.pprints()
}
return {"experiment": self.pprints()}
def save(self, directory: str):
path = os.path.join(directory, self.EXPERIMENT_PICKLE_FILENAME)
log.info(
f"Saving serialized experiment in {path}; can be restored via Experiment.from_directory('{directory}')",
)
with open(path, "wb") as f:
pickle.dump(self, f)
def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None:
""":param experiment_name: the experiment name, which corresponds to the directory (within the logging
@ -144,66 +180,88 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
if experiment_name is None:
experiment_name = datetime_tag()
log.info(f"Working directory: {os.getcwd()}")
# initialize persistence directory
use_persistence = self.config.persistence_enabled
persistence_dir = os.path.join(self.config.persistence_base_dir, experiment_name)
if use_persistence:
os.makedirs(persistence_dir, exist_ok=True)
self._set_seed()
with logging.FileLoggerContext(
os.path.join(persistence_dir, self.LOG_FILENAME), enabled=use_persistence,
):
# log initial information
log.info(f"Running experiment (name='{experiment_name}'):\n{self.pprints()}")
log.info(f"Working directory: {os.getcwd()}")
# create environments
envs = self.env_factory(self.env_config)
log.info(f"Created {envs}")
self._set_seed()
# initialize persistence
additional_persistence = PersistenceGroup(*envs.persistence)
policy_persistence = PolicyPersistence(additional_persistence)
# create environments
envs = self.env_factory(self.env_config)
log.info(f"Created {envs}")
# initialize logger
full_config = self._build_config_dict()
full_config.update(envs.info())
logger = self.logger_factory.create_logger(
log_name=experiment_name,
run_id=logger_run_id,
config_dict=full_config,
)
# initialize persistence
additional_persistence = PersistenceGroup(*envs.persistence, enabled=use_persistence)
policy_persistence = PolicyPersistence(additional_persistence, enabled=use_persistence)
if use_persistence:
log.info(f"Persistence directory: {os.path.abspath(persistence_dir)}")
self.save(persistence_dir)
policy = self.agent_factory.create_policy(envs, self.config.device)
# initialize logger
full_config = self._build_config_dict()
full_config.update(envs.info())
if use_persistence:
logger = self.logger_factory.create_logger(
log_dir=persistence_dir,
experiment_name=experiment_name,
run_id=logger_run_id,
config_dict=full_config,
)
else:
logger = LazyLogger()
train_collector, test_collector = self.agent_factory.create_train_test_collector(
policy,
envs,
)
world = World(
envs=envs,
policy=policy,
train_collector=train_collector,
test_collector=test_collector,
logger=logger,
restore_directory=self.config.policy_restore_directory
)
if self.config.policy_restore_directory:
policy_persistence.restore(
# create policy and collectors
policy = self.agent_factory.create_policy(envs, self.config.device)
train_collector, test_collector = self.agent_factory.create_train_test_collector(
policy,
world,
self.config.device,
envs,
)
if self.config.train:
trainer = self.agent_factory.create_trainer(world, policy_persistence)
world.trainer = trainer
result = trainer.run()
pprint(result) # TODO logging
if self.config.watch:
self._watch_agent(
self.config.watch_num_episodes,
policy,
test_collector,
self.config.watch_render,
# create context object with all relevant instances (except trainer; added later)
world = World(
envs=envs,
policy=policy,
train_collector=train_collector,
test_collector=test_collector,
logger=logger,
persist_directory=persistence_dir,
restore_directory=self.config.policy_restore_directory,
)
# TODO return result
# restore policy parameters if applicable
if self.config.policy_restore_directory:
policy_persistence.restore(
policy,
world,
self.config.device,
)
# train policy
if self.config.train:
trainer = self.agent_factory.create_trainer(world, policy_persistence)
world.trainer = trainer
trainer_result = trainer.run()
pprint(trainer_result) # TODO logging
# watch agent performance
if self.config.watch:
self._watch_agent(
self.config.watch_num_episodes,
policy,
test_collector,
self.config.watch_render,
)
# TODO return result
@staticmethod
def _watch_agent(
@ -304,7 +362,6 @@ class ExperimentBuilder:
self._logger_factory,
env_config=self._env_config,
)
log.info(f"Created experiment:\n{experiment.pprints()}")
return experiment

View File

@ -1,6 +1,5 @@
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Literal, TypeAlias
from torch.utils.tensorboard import SummaryWriter
@ -11,38 +10,39 @@ from tianshou.utils.string import ToStringMixin
TLogger: TypeAlias = TensorboardLogger | WandbLogger
@dataclass
class Logger:
logger: TLogger
log_path: str
class LoggerFactory(ToStringMixin, ABC):
@abstractmethod
def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
pass
def create_logger(
self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict,
) -> TLogger:
""":param log_dir: path to the directory in which log data is to be stored
:param experiment_name: the name of the job, which may contain os.path.sep
:param run_id: a unique name, which, depending on the logging framework, may be used to identify the logger
:param config_dict: a dictionary with data that is to be logged
:return: the logger
"""
class DefaultLoggerFactory(LoggerFactory):
def __init__(
self,
log_dir: str = "log",
logger_type: Literal["tensorboard", "wandb"] = "tensorboard",
wandb_project: str | None = None,
):
if logger_type == "wandb" and wandb_project is None:
raise ValueError("Must provide 'wandb_project'")
self.log_dir = log_dir
self.logger_type = logger_type
self.wandb_project = wandb_project
def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
writer = SummaryWriter(self.log_dir)
def create_logger(
self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict,
) -> TLogger:
writer = SummaryWriter(log_dir)
writer.add_text(
"args",
str(
dict(
log_dir=self.log_dir,
log_dir=log_dir,
logger_type=self.logger_type,
wandb_project=self.wandb_project,
),
@ -52,7 +52,7 @@ class DefaultLoggerFactory(LoggerFactory):
if self.logger_type == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
name=experiment_name.replace(os.path.sep, "__"),
run_id=run_id,
config=config_dict,
project=self.wandb_project,
@ -62,6 +62,4 @@ class DefaultLoggerFactory(LoggerFactory):
logger = TensorboardLogger(writer)
else:
raise ValueError(f"Unknown logger type '{self.logger_type}'")
log_path = os.path.join(self.log_dir, log_name)
os.makedirs(log_path, exist_ok=True)
return Logger(logger=logger, log_path=log_path)
return logger

View File

@ -1,8 +1,9 @@
import logging
import os
from abc import ABC, abstractmethod
from collections.abc import Callable
from enum import Enum
from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable, Callable
from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable
import torch
@ -25,11 +26,15 @@ class PersistableConfigProtocol(Protocol):
class PersistEvent(Enum):
"""Enumeration of persistence events that Persistence objects can react to."""
PERSIST_POLICY = "persist_policy"
"""Policy neural network is persisted (new best found)"""
class RestoreEvent(Enum):
"""Enumeration of restoration events that Persistence objects can react to."""
RESTORE_POLICY = "restore_policy"
"""Policy neural network parameters are restored"""
@ -45,10 +50,13 @@ class Persistence(ABC):
class PersistenceGroup(Persistence):
def __init__(self, *p: Persistence):
def __init__(self, *p: Persistence, enabled=True):
self.items = p
self.enabled = enabled
def persist(self, event: PersistEvent, world: World) -> None:
if not self.enabled:
return
for item in self.items:
item.persist(event, world)
@ -60,10 +68,17 @@ class PersistenceGroup(Persistence):
class PolicyPersistence:
FILENAME = "policy.dat"
def __init__(self, additional_persistence: Persistence | None = None):
def __init__(self, additional_persistence: Persistence | None = None, enabled=True):
""":param additional_persistence: a persistence instance which is to be envoked whenever
this object is used to persist/restore data
:param enabled: whether persistence is enabled (restoration is always enabled)
"""
self.additional_persistence = additional_persistence
self.enabled = enabled
def persist(self, policy: torch.nn.Module, world: World) -> None:
if not self.enabled:
return
path = world.persist_path(self.FILENAME)
log.info(f"Saving policy in {path}")
torch.save(policy.state_dict(), path)

View File

@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import TypeVar
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.logger import TLogger
from tianshou.policy import BasePolicy
from tianshou.utils.string import ToStringMixin
@ -12,7 +12,7 @@ TPolicy = TypeVar("TPolicy", bound=BasePolicy)
class TrainingContext:
def __init__(self, policy: TPolicy, envs: Environments, logger: Logger):
def __init__(self, policy: TPolicy, envs: Environments, logger: TLogger):
self.policy = policy
self.envs = envs
self.logger = logger

View File

@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from tianshou.data import Collector
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.logger import TLogger
from tianshou.policy import BasePolicy
from tianshou.trainer import BaseTrainer
@ -16,15 +16,11 @@ class World:
policy: "BasePolicy"
train_collector: "Collector"
test_collector: "Collector"
logger: "Logger"
logger: "TLogger"
persist_directory: str
restore_directory: str
trainer: Optional["BaseTrainer"] = None
@property
def persist_directory(self):
return self.logger.log_path
def persist_path(self, filename: str) -> str:
return os.path.join(self.persist_directory, filename)