Improve persistence handling

* Add persistence/restoration of Experiment instance
* Add file logging in experiment
* Allow all persistence/logging to be disabled
* Disable persistence in tests
Dominik Jain 2023-10-12 17:40:16 +02:00
parent ba803296cc
commit 76e870207d
11 changed files with 160 additions and 96 deletions
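At a glance, the user-facing behavior this commit introduces (a sketch based only on the interfaces visible in the diff below; the pickle directory name is an illustrative placeholder, not output of this commit):

    from tianshou.highlevel.experiment import Experiment, ExperimentConfig

    # Persistence and logging can now be disabled entirely, e.g. in tests:
    config = ExperimentConfig(persistence_enabled=False)

    # With persistence enabled (the default), each run writes its artifacts to
    # <persistence_base_dir>/<experiment_name>: a log file (log.txt) and a
    # pickled Experiment instance (experiment.pkl). The pickle can later be
    # restored and re-run:
    experiment = Experiment.from_directory("log/2023-10-12-17-40-16")
    experiment.run(experiment_name="restored-run")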

@@ -22,7 +22,7 @@ class TrainEpochCallbackNatureDQNEpsLinearDecay(TrainerEpochCallbackTrain):
     def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
         policy: DQNPolicy = context.policy
-        logger = context.logger.logger
+        logger = context.logger
         # nature DQN setting, linear decay in the first 1M steps
         if env_step <= 1e6:
             eps = self.eps_train - env_step / 1e6 * (self.eps_train - self.eps_train_final)

@@ -81,7 +81,7 @@ def main(
    class TrainEpochCallback(TrainerEpochCallbackTrain):
        def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
            policy: DQNPolicy = context.policy
-            logger = context.logger.logger
+            logger = context.logger
            # nature DQN setting, linear decay in the first 1M steps
            if env_step <= 1e6:
                eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final)

@@ -1,13 +1,13 @@
+import logging
 import pickle
 import warnings
-import logging

 import gymnasium as gym

 from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
-from tianshou.highlevel.persistence import Persistence, RestoreEvent, PersistEvent
+from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
 from tianshou.highlevel.world import World

 try:

@@ -33,7 +33,7 @@ from tianshou.highlevel.experiment import (
 def test_experiment_builder_continuous_default_params(builder_cls):
     env_factory = ContinuousTestEnvFactory()
     sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
-    experiment_config = ExperimentConfig()
+    experiment_config = ExperimentConfig(persistence_enabled=False)
     builder = builder_cls(
         experiment_config=experiment_config,
         env_factory=env_factory,

@@ -11,7 +11,6 @@ from tianshou.highlevel.experiment import (
     IQNExperimentBuilder,
     PPOExperimentBuilder,
 )
-from tianshou.utils import logging


 @pytest.mark.parametrize(
@@ -25,11 +24,10 @@ from tianshou.utils import logging
     ],
 )
 def test_experiment_builder_discrete_default_params(builder_cls):
-    logging.configure()
     env_factory = DiscreteTestEnvFactory()
     sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
     builder = builder_cls(
-        experiment_config=ExperimentConfig(),
+        experiment_config=ExperimentConfig(persistence_enabled=False),
         env_factory=env_factory,
         sampling_config=sampling_config,
     )

@@ -171,7 +171,7 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
             batch_size=sampling_config.batch_size,
             step_per_collect=sampling_config.step_per_collect,
             save_best_fn=policy_persistence.get_save_best_fn(world),
-            logger=world.logger.logger,
+            logger=world.logger,
             test_in_train=False,
             train_fn=train_fn,
             test_fn=test_fn,
@@ -211,7 +211,7 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
             episode_per_test=sampling_config.num_test_envs,
             batch_size=sampling_config.batch_size,
             save_best_fn=policy_persistence.get_save_best_fn(world),
-            logger=world.logger.logger,
+            logger=world.logger,
             update_per_step=sampling_config.update_per_step,
             test_in_train=False,
             train_fn=train_fn,

@@ -1,5 +1,5 @@
 import os
-import logging
+import pickle
 from abc import abstractmethod
 from collections.abc import Callable, Sequence
 from dataclasses import dataclass
@@ -65,7 +65,11 @@ from tianshou.highlevel.params.policy_params import (
     TRPOParams,
 )
 from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
-from tianshou.highlevel.persistence import PersistableConfigProtocol, PolicyPersistence, PersistenceGroup
+from tianshou.highlevel.persistence import (
+    PersistableConfigProtocol,
+    PersistenceGroup,
+    PolicyPersistence,
+)
 from tianshou.highlevel.trainer import (
     TrainerCallbacks,
     TrainerEpochCallbackTest,
@@ -75,6 +79,7 @@ from tianshou.highlevel.trainer import (
 from tianshou.highlevel.world import World
 from tianshou.policy import BasePolicy
 from tianshou.trainer import BaseTrainer
+from tianshou.utils import LazyLogger, logging
 from tianshou.utils.logging import datetime_tag
 from tianshou.utils.string import ToStringMixin
@@ -93,7 +98,7 @@ class ExperimentConfig:
     device: TDevice = "cuda" if torch.cuda.is_available() else "cpu"
     """The torch device to use"""
     policy_restore_directory: str | None = None
-    """Directory from which to load the policy neural network parameters (saved in a previous run)"""
+    """Directory from which to load the policy neural network parameters (persistence directory of a previous run)"""
     train: bool = True
     """Whether to perform training"""
     watch: bool = True
@@ -102,9 +107,24 @@ class ExperimentConfig:
     """Number of episodes for which to watch performance (if watch is enabled)"""
     watch_render: float = 0.0
     """Milliseconds between rendered frames when watching agent performance (if watch is enabled)"""
+    persistence_base_dir: str = "log"
+    """Base directory in which experiment data is to be stored. Every experiment run will create a subdirectory
+    in this directory based on the run's experiment name"""
+    persistence_enabled: bool = True
+    """Whether persistence is enabled, allowing files to be stored"""


 class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
+    """Represents a reinforcement learning experiment.
+
+    An experiment is composed only of configuration and factory objects, which themselves
+    should be designed to contain only configuration. Therefore, experiments can easily
+    be stored/pickled and later restored without any problems.
+    """
+
+    LOG_FILENAME = "log.txt"
+    EXPERIMENT_PICKLE_FILENAME = "experiment.pkl"
+
     def __init__(
         self,
         config: ExperimentConfig,
@@ -121,15 +141,31 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
         self.logger_factory = logger_factory
         self.env_config = env_config

+    @classmethod
+    def from_directory(cls, directory: str) -> Self:
+        """Restores an experiment from a previously stored pickle.
+
+        :param directory: persistence directory of a previous run, in which a pickled experiment is found
+        """
+        with open(os.path.join(directory, cls.EXPERIMENT_PICKLE_FILENAME), "rb") as f:
+            return pickle.load(f)
+
     def _set_seed(self) -> None:
         seed = self.config.seed
+        log.info(f"Setting random seed {seed}")
         np.random.seed(seed)
         torch.manual_seed(seed)

     def _build_config_dict(self) -> dict:
-        return {
-            "experiment": self.pprints()
-        }
+        return {"experiment": self.pprints()}
+
+    def save(self, directory: str):
+        path = os.path.join(directory, self.EXPERIMENT_PICKLE_FILENAME)
+        log.info(
+            f"Saving serialized experiment in {path}; can be restored via Experiment.from_directory('{directory}')",
+        )
+        with open(path, "wb") as f:
+            pickle.dump(self, f)

     def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None:
         """:param experiment_name: the experiment name, which corresponds to the directory (within the logging
@@ -144,6 +180,17 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
         if experiment_name is None:
             experiment_name = datetime_tag()

+        # initialize persistence directory
+        use_persistence = self.config.persistence_enabled
+        persistence_dir = os.path.join(self.config.persistence_base_dir, experiment_name)
+        if use_persistence:
+            os.makedirs(persistence_dir, exist_ok=True)
+
+        with logging.FileLoggerContext(
+            os.path.join(persistence_dir, self.LOG_FILENAME), enabled=use_persistence,
+        ):
+            # log initial information
+            log.info(f"Running experiment (name='{experiment_name}'):\n{self.pprints()}")
             log.info(f"Working directory: {os.getcwd()}")

             self._set_seed()
@@ -153,34 +200,44 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
             log.info(f"Created {envs}")

             # initialize persistence
-            additional_persistence = PersistenceGroup(*envs.persistence)
-            policy_persistence = PolicyPersistence(additional_persistence)
+            additional_persistence = PersistenceGroup(*envs.persistence, enabled=use_persistence)
+            policy_persistence = PolicyPersistence(additional_persistence, enabled=use_persistence)
+            if use_persistence:
+                log.info(f"Persistence directory: {os.path.abspath(persistence_dir)}")
+                self.save(persistence_dir)

             # initialize logger
             full_config = self._build_config_dict()
             full_config.update(envs.info())
-            logger = self.logger_factory.create_logger(
-                log_name=experiment_name,
-                run_id=logger_run_id,
-                config_dict=full_config,
-            )
+            if use_persistence:
+                logger = self.logger_factory.create_logger(
+                    log_dir=persistence_dir,
+                    experiment_name=experiment_name,
+                    run_id=logger_run_id,
+                    config_dict=full_config,
+                )
+            else:
+                logger = LazyLogger()

+            # create policy and collectors
             policy = self.agent_factory.create_policy(envs, self.config.device)
             train_collector, test_collector = self.agent_factory.create_train_test_collector(
                 policy,
                 envs,
             )

+            # create context object with all relevant instances (except trainer; added later)
             world = World(
                 envs=envs,
                 policy=policy,
                 train_collector=train_collector,
                 test_collector=test_collector,
                 logger=logger,
-                restore_directory=self.config.policy_restore_directory
+                persist_directory=persistence_dir,
+                restore_directory=self.config.policy_restore_directory,
             )

+            # restore policy parameters if applicable
             if self.config.policy_restore_directory:
                 policy_persistence.restore(
                     policy,
@@ -188,13 +245,14 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
                     self.config.device,
                 )

+            # train policy
             if self.config.train:
                 trainer = self.agent_factory.create_trainer(world, policy_persistence)
                 world.trainer = trainer
-                result = trainer.run()
-                pprint(result)  # TODO logging
+                trainer_result = trainer.run()
+                pprint(trainer_result)  # TODO logging

+            # watch agent performance
             if self.config.watch:
                 self._watch_agent(
                     self.config.watch_num_episodes,
@@ -304,7 +362,6 @@ class ExperimentBuilder:
             self._logger_factory,
             env_config=self._env_config,
         )
-        log.info(f"Created experiment:\n{experiment.pprints()}")
         return experiment
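The file logging added above hinges on logging.FileLoggerContext, which run() now enters for its entire body: while the context is active, log output is additionally written to the given file, and enabled=False turns the context into a no-op. A minimal standalone sketch, assuming only the constructor signature visible in the diff (the path is a placeholder):

    import logging as std_logging

    from tianshou.utils import logging  # tianshou's logging utilities, as imported above

    log = std_logging.getLogger(__name__)

    with logging.FileLoggerContext("log/my-run/log.txt", enabled=True):
        log.info("written to the console and mirrored to log/my-run/log.txt")
    log.info("written to the console only")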

@@ -1,6 +1,5 @@
 import os
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from typing import Literal, TypeAlias

 from torch.utils.tensorboard import SummaryWriter
@@ -11,38 +10,39 @@ from tianshou.utils.string import ToStringMixin
 TLogger: TypeAlias = TensorboardLogger | WandbLogger


-@dataclass
-class Logger:
-    logger: TLogger
-    log_path: str
-
-
 class LoggerFactory(ToStringMixin, ABC):
     @abstractmethod
-    def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
-        pass
+    def create_logger(
+        self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict,
+    ) -> TLogger:
+        """:param log_dir: path to the directory in which log data is to be stored
+        :param experiment_name: the name of the job, which may contain os.path.sep
+        :param run_id: a unique name, which, depending on the logging framework, may be used to identify the logger
+        :param config_dict: a dictionary with data that is to be logged
+        :return: the logger
+        """


 class DefaultLoggerFactory(LoggerFactory):
     def __init__(
         self,
-        log_dir: str = "log",
         logger_type: Literal["tensorboard", "wandb"] = "tensorboard",
         wandb_project: str | None = None,
     ):
         if logger_type == "wandb" and wandb_project is None:
             raise ValueError("Must provide 'wandb_project'")
-        self.log_dir = log_dir
         self.logger_type = logger_type
         self.wandb_project = wandb_project

-    def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
-        writer = SummaryWriter(self.log_dir)
+    def create_logger(
+        self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict,
+    ) -> TLogger:
+        writer = SummaryWriter(log_dir)
         writer.add_text(
             "args",
             str(
                 dict(
-                    log_dir=self.log_dir,
+                    log_dir=log_dir,
                     logger_type=self.logger_type,
                     wandb_project=self.wandb_project,
                 ),
@@ -52,7 +52,7 @@ class DefaultLoggerFactory(LoggerFactory):
         if self.logger_type == "wandb":
             logger = WandbLogger(
                 save_interval=1,
-                name=log_name.replace(os.path.sep, "__"),
+                name=experiment_name.replace(os.path.sep, "__"),
                 run_id=run_id,
                 config=config_dict,
                 project=self.wandb_project,
@@ -62,6 +62,4 @@ class DefaultLoggerFactory(LoggerFactory):
             logger = TensorboardLogger(writer)
         else:
             raise ValueError(f"Unknown logger type '{self.logger_type}'")
-        log_path = os.path.join(self.log_dir, log_name)
-        os.makedirs(log_path, exist_ok=True)
-        return Logger(logger=logger, log_path=log_path)
+        return logger
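With the Logger dataclass removed, a factory now receives the log directory from the experiment and returns a TLogger directly. A custom factory conforming to the new contract might look like this (a sketch; the class name is hypothetical, and wandb support and error handling are omitted):

    from torch.utils.tensorboard import SummaryWriter

    from tianshou.highlevel.logger import LoggerFactory, TLogger
    from tianshou.utils import TensorboardLogger


    class PlainTensorboardLoggerFactory(LoggerFactory):
        def create_logger(
            self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict,
        ) -> TLogger:
            # log_dir is the run's persistence directory, created by Experiment.run()
            writer = SummaryWriter(log_dir)
            return TensorboardLogger(writer)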

@@ -1,8 +1,9 @@
 import logging
 import os

 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from enum import Enum
-from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable, Callable
+from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable

 import torch
@@ -25,11 +26,15 @@ class PersistableConfigProtocol(Protocol):
 class PersistEvent(Enum):
+    """Enumeration of persistence events that Persistence objects can react to."""
+
     PERSIST_POLICY = "persist_policy"
     """Policy neural network is persisted (new best found)"""


 class RestoreEvent(Enum):
+    """Enumeration of restoration events that Persistence objects can react to."""
+
     RESTORE_POLICY = "restore_policy"
     """Policy neural network parameters are restored"""
@@ -45,10 +50,13 @@ class Persistence(ABC):
 class PersistenceGroup(Persistence):
-    def __init__(self, *p: Persistence):
+    def __init__(self, *p: Persistence, enabled=True):
         self.items = p
+        self.enabled = enabled

     def persist(self, event: PersistEvent, world: World) -> None:
+        if not self.enabled:
+            return
         for item in self.items:
             item.persist(event, world)
@@ -60,10 +68,17 @@ class PersistenceGroup(Persistence):
 class PolicyPersistence:
     FILENAME = "policy.dat"

-    def __init__(self, additional_persistence: Persistence | None = None):
+    def __init__(self, additional_persistence: Persistence | None = None, enabled=True):
+        """:param additional_persistence: a persistence instance which is to be invoked whenever
+        this object is used to persist/restore data
+        :param enabled: whether persistence is enabled (restoration is always enabled)
+        """
         self.additional_persistence = additional_persistence
+        self.enabled = enabled

     def persist(self, policy: torch.nn.Module, world: World) -> None:
+        if not self.enabled:
+            return
         path = world.persist_path(self.FILENAME)
         log.info(f"Saving policy in {path}")
         torch.save(policy.state_dict(), path)
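Environment factories can contribute their own Persistence objects (see the mujoco_env imports above), and PersistenceGroup now skips all of them when disabled. A sketch of such an object, assuming Persistence declares persist/restore hooks matching the PersistEvent/RestoreEvent pairing (the class name, file name, and restore signature are illustrative, not part of this commit):

    import os
    import pickle

    from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
    from tianshou.highlevel.world import World


    class EnvStatsPersistence(Persistence):
        FILENAME = "env_stats.pkl"  # hypothetical file name

        def persist(self, event: PersistEvent, world: World) -> None:
            if event != PersistEvent.PERSIST_POLICY:
                return
            with open(world.persist_path(self.FILENAME), "wb") as f:
                pickle.dump({"obs_rms": None}, f)  # stand-in for real statistics

        def restore(self, event: RestoreEvent, world: World) -> None:
            if event != RestoreEvent.RESTORE_POLICY:
                return
            with open(os.path.join(world.restore_directory, self.FILENAME), "rb") as f:
                stats = pickle.load(f)  # apply the statistics to the environment here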

@@ -4,7 +4,7 @@ from dataclasses import dataclass
 from typing import TypeVar

 from tianshou.highlevel.env import Environments
-from tianshou.highlevel.logger import Logger
+from tianshou.highlevel.logger import TLogger
 from tianshou.policy import BasePolicy
 from tianshou.utils.string import ToStringMixin
@@ -12,7 +12,7 @@ TPolicy = TypeVar("TPolicy", bound=BasePolicy)
 class TrainingContext:
-    def __init__(self, policy: TPolicy, envs: Environments, logger: Logger):
+    def __init__(self, policy: TPolicy, envs: Environments, logger: TLogger):
         self.policy = policy
         self.envs = envs
         self.logger = logger

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional
 if TYPE_CHECKING:
     from tianshou.data import Collector
     from tianshou.highlevel.env import Environments
-    from tianshou.highlevel.logger import Logger
+    from tianshou.highlevel.logger import TLogger
     from tianshou.policy import BasePolicy
     from tianshou.trainer import BaseTrainer
@@ -16,15 +16,11 @@ class World:
     policy: "BasePolicy"
     train_collector: "Collector"
     test_collector: "Collector"
-    logger: "Logger"
+    logger: "TLogger"
+    persist_directory: str
     restore_directory: str
     trainer: Optional["BaseTrainer"] = None

-    @property
-    def persist_directory(self):
-        return self.logger.log_path
-
     def persist_path(self, filename: str) -> str:
         return os.path.join(self.persist_directory, filename)