Improve persistence handling
* Add persistence/restoration of Experiment instance
* Add file logging in experiment
* Allow all persistence/logging to be disabled
* Disable persistence in tests
parent ba803296cc
commit 76e870207d
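
This diff makes persistence opt-out and experiments restorable from disk. A minimal usage sketch based only on the API introduced here (the directory value is illustrative):

    from tianshou.highlevel.experiment import Experiment, ExperimentConfig

    # Disable all persistence/logging, e.g. in tests (as the test changes below do):
    config = ExperimentConfig(persistence_enabled=False)

    # Restore a pickled experiment from the persistence directory of a previous run
    # ("log/some-run" is a hypothetical path) and run it again:
    experiment = Experiment.from_directory("log/some-run")
    experiment.run(experiment_name="some-run-repeat")
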
@@ -22,7 +22,7 @@ class TrainEpochCallbackNatureDQNEpsLinearDecay(TrainerEpochCallbackTrain):
     def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
         policy: DQNPolicy = context.policy
-        logger = context.logger.logger
+        logger = context.logger
         # nature DQN setting, linear decay in the first 1M steps
         if env_step <= 1e6:
             eps = self.eps_train - env_step / 1e6 * (self.eps_train - self.eps_train_final)
@@ -81,7 +81,7 @@ def main(
     class TrainEpochCallback(TrainerEpochCallbackTrain):
         def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
             policy: DQNPolicy = context.policy
-            logger = context.logger.logger
+            logger = context.logger
             # nature DQN setting, linear decay in the first 1M steps
             if env_step <= 1e6:
                 eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final)
@@ -1,13 +1,13 @@
+import logging
 import pickle
 import warnings
-import logging

 import gymnasium as gym

 from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
-from tianshou.highlevel.persistence import Persistence, RestoreEvent, PersistEvent
+from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
 from tianshou.highlevel.world import World

 try:
@@ -33,7 +33,7 @@ from tianshou.highlevel.experiment import (
 def test_experiment_builder_continuous_default_params(builder_cls):
     env_factory = ContinuousTestEnvFactory()
     sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
-    experiment_config = ExperimentConfig()
+    experiment_config = ExperimentConfig(persistence_enabled=False)
     builder = builder_cls(
         experiment_config=experiment_config,
         env_factory=env_factory,
@@ -11,7 +11,6 @@ from tianshou.highlevel.experiment import (
     IQNExperimentBuilder,
     PPOExperimentBuilder,
 )
-from tianshou.utils import logging


 @pytest.mark.parametrize(
@@ -25,11 +24,10 @@ from tianshou.utils import logging
     ],
 )
 def test_experiment_builder_discrete_default_params(builder_cls):
-    logging.configure()
     env_factory = DiscreteTestEnvFactory()
     sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
     builder = builder_cls(
-        experiment_config=ExperimentConfig(),
+        experiment_config=ExperimentConfig(persistence_enabled=False),
         env_factory=env_factory,
         sampling_config=sampling_config,
     )
@@ -171,7 +171,7 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
             batch_size=sampling_config.batch_size,
             step_per_collect=sampling_config.step_per_collect,
             save_best_fn=policy_persistence.get_save_best_fn(world),
-            logger=world.logger.logger,
+            logger=world.logger,
             test_in_train=False,
             train_fn=train_fn,
             test_fn=test_fn,
@@ -211,7 +211,7 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
             episode_per_test=sampling_config.num_test_envs,
             batch_size=sampling_config.batch_size,
             save_best_fn=policy_persistence.get_save_best_fn(world),
-            logger=world.logger.logger,
+            logger=world.logger,
             update_per_step=sampling_config.update_per_step,
             test_in_train=False,
             train_fn=train_fn,
@@ -1,5 +1,5 @@
+import logging
 import os
 import pickle
 from abc import abstractmethod
 from collections.abc import Callable, Sequence
 from dataclasses import dataclass
@@ -65,7 +65,11 @@ from tianshou.highlevel.params.policy_params import (
     TRPOParams,
 )
 from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
-from tianshou.highlevel.persistence import PersistableConfigProtocol, PolicyPersistence, PersistenceGroup
+from tianshou.highlevel.persistence import (
+    PersistableConfigProtocol,
+    PersistenceGroup,
+    PolicyPersistence,
+)
 from tianshou.highlevel.trainer import (
     TrainerCallbacks,
     TrainerEpochCallbackTest,
@@ -75,6 +79,7 @@ from tianshou.highlevel.trainer import (
 from tianshou.highlevel.world import World
 from tianshou.policy import BasePolicy
 from tianshou.trainer import BaseTrainer
+from tianshou.utils import LazyLogger, logging
 from tianshou.utils.logging import datetime_tag
 from tianshou.utils.string import ToStringMixin

@@ -93,7 +98,7 @@ class ExperimentConfig:
     device: TDevice = "cuda" if torch.cuda.is_available() else "cpu"
     """The torch device to use"""
     policy_restore_directory: str | None = None
-    """Directory from which to load the policy neural network parameters (saved in a previous run)"""
+    """Directory from which to load the policy neural network parameters (persistence directory of a previous run)"""
     train: bool = True
     """Whether to perform training"""
     watch: bool = True
@@ -102,9 +107,24 @@ class ExperimentConfig:
     """Number of episodes for which to watch performance (if watch is enabled)"""
     watch_render: float = 0.0
     """Milliseconds between rendered frames when watching agent performance (if watch is enabled)"""
+    persistence_base_dir: str = "log"
+    """Base directory in which experiment data is to be stored. Every experiment run will create a subdirectory
+    in this directory based on the run's experiment name"""
+    persistence_enabled: bool = True
+    """Whether persistence is enabled, allowing files to be stored"""


 class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
+    """Represents a reinforcement learning experiment.
+
+    An experiment is composed only of configuration and factory objects, which themselves
+    should be designed to contain only configuration. Therefore, experiments can easily
+    be stored/pickled and later restored without any problems.
+    """
+
+    LOG_FILENAME = "log.txt"
+    EXPERIMENT_PICKLE_FILENAME = "experiment.pkl"
+
     def __init__(
         self,
         config: ExperimentConfig,
@@ -121,15 +141,31 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
         self.logger_factory = logger_factory
         self.env_config = env_config

+    @classmethod
+    def from_directory(cls, directory: str) -> Self:
+        """Restores an experiment from a previously stored pickle.
+
+        :param directory: persistence directory of a previous run, in which a pickled experiment is found
+        """
+        with open(os.path.join(directory, cls.EXPERIMENT_PICKLE_FILENAME), "rb") as f:
+            return pickle.load(f)
+
     def _set_seed(self) -> None:
         seed = self.config.seed
         log.info(f"Setting random seed {seed}")
         np.random.seed(seed)
         torch.manual_seed(seed)

     def _build_config_dict(self) -> dict:
-        return {
-            "experiment": self.pprints()
-        }
+        return {"experiment": self.pprints()}
+
+    def save(self, directory: str):
+        path = os.path.join(directory, self.EXPERIMENT_PICKLE_FILENAME)
+        log.info(
+            f"Saving serialized experiment in {path}; can be restored via Experiment.from_directory('{directory}')",
+        )
+        with open(path, "wb") as f:
+            pickle.dump(self, f)

     def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None:
         """:param experiment_name: the experiment name, which corresponds to the directory (within the logging
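
Because an Experiment consists only of configuration and factory objects, the save/from_directory pair added above is a plain pickle round trip. A sketch, assuming `experiment` is an already-built Experiment and the directory value is illustrative:

    experiment.save("log/my-run")                       # writes log/my-run/experiment.pkl
    restored = Experiment.from_directory("log/my-run")  # unpickles the same configuration
    restored.run(experiment_name="my-run-restored")
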
@@ -144,66 +180,88 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
         if experiment_name is None:
             experiment_name = datetime_tag()

-        log.info(f"Working directory: {os.getcwd()}")
+        # initialize persistence directory
+        use_persistence = self.config.persistence_enabled
+        persistence_dir = os.path.join(self.config.persistence_base_dir, experiment_name)
+        if use_persistence:
+            os.makedirs(persistence_dir, exist_ok=True)

-        self._set_seed()
+        with logging.FileLoggerContext(
+            os.path.join(persistence_dir, self.LOG_FILENAME), enabled=use_persistence,
+        ):
+            # log initial information
+            log.info(f"Running experiment (name='{experiment_name}'):\n{self.pprints()}")
+            log.info(f"Working directory: {os.getcwd()}")

-        # create environments
-        envs = self.env_factory(self.env_config)
-        log.info(f"Created {envs}")
+            self._set_seed()

-        # initialize persistence
-        additional_persistence = PersistenceGroup(*envs.persistence)
-        policy_persistence = PolicyPersistence(additional_persistence)
+            # create environments
+            envs = self.env_factory(self.env_config)
+            log.info(f"Created {envs}")

-        # initialize logger
-        full_config = self._build_config_dict()
-        full_config.update(envs.info())
-        logger = self.logger_factory.create_logger(
-            log_name=experiment_name,
-            run_id=logger_run_id,
-            config_dict=full_config,
-        )
+            # initialize persistence
+            additional_persistence = PersistenceGroup(*envs.persistence, enabled=use_persistence)
+            policy_persistence = PolicyPersistence(additional_persistence, enabled=use_persistence)
+            if use_persistence:
+                log.info(f"Persistence directory: {os.path.abspath(persistence_dir)}")
+                self.save(persistence_dir)

-        policy = self.agent_factory.create_policy(envs, self.config.device)
+            # initialize logger
+            full_config = self._build_config_dict()
+            full_config.update(envs.info())
+            if use_persistence:
+                logger = self.logger_factory.create_logger(
+                    log_dir=persistence_dir,
+                    experiment_name=experiment_name,
+                    run_id=logger_run_id,
+                    config_dict=full_config,
+                )
+            else:
+                logger = LazyLogger()

-        train_collector, test_collector = self.agent_factory.create_train_test_collector(
-            policy,
-            envs,
-        )
-
-        world = World(
-            envs=envs,
-            policy=policy,
-            train_collector=train_collector,
-            test_collector=test_collector,
-            logger=logger,
-            restore_directory=self.config.policy_restore_directory
-        )
-
-        if self.config.policy_restore_directory:
-            policy_persistence.restore(
-                policy,
-                world,
-                self.config.device,
-            )
-
-        if self.config.train:
-            trainer = self.agent_factory.create_trainer(world, policy_persistence)
-            world.trainer = trainer
-
-            result = trainer.run()
-            pprint(result)  # TODO logging
-
-        if self.config.watch:
-            self._watch_agent(
-                self.config.watch_num_episodes,
-                policy,
-                test_collector,
-                self.config.watch_render,
-            )
-
-        # TODO return result
+            # create policy and collectors
+            policy = self.agent_factory.create_policy(envs, self.config.device)
+            train_collector, test_collector = self.agent_factory.create_train_test_collector(
+                policy,
+                envs,
+            )
+
+            # create context object with all relevant instances (except trainer; added later)
+            world = World(
+                envs=envs,
+                policy=policy,
+                train_collector=train_collector,
+                test_collector=test_collector,
+                logger=logger,
+                persist_directory=persistence_dir,
+                restore_directory=self.config.policy_restore_directory,
+            )
+
+            # restore policy parameters if applicable
+            if self.config.policy_restore_directory:
+                policy_persistence.restore(
+                    policy,
+                    world,
+                    self.config.device,
+                )
+
+            # train policy
+            if self.config.train:
+                trainer = self.agent_factory.create_trainer(world, policy_persistence)
+                world.trainer = trainer
+                trainer_result = trainer.run()
+                pprint(trainer_result)  # TODO logging
+
+            # watch agent performance
+            if self.config.watch:
+                self._watch_agent(
+                    self.config.watch_num_episodes,
+                    policy,
+                    test_collector,
+                    self.config.watch_render,
+                )
+
+        # TODO return result

     @staticmethod
     def _watch_agent(
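
The rewritten run() above wraps all work in tianshou.utils.logging.FileLoggerContext, so log output also lands in log.txt inside the persistence directory (and file handling is skipped entirely when persistence is disabled). As an illustration of the mechanism only, not tianshou's actual implementation, such a context manager can be sketched as:

    import logging
    from contextlib import contextmanager

    @contextmanager
    def file_logger_context(path: str, enabled: bool = True):
        # Attach a FileHandler to the root logger for the duration of the block.
        handler = None
        if enabled:
            handler = logging.FileHandler(path)
            logging.getLogger().addHandler(handler)
        try:
            yield
        finally:
            if handler is not None:
                logging.getLogger().removeHandler(handler)
                handler.close()
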
@@ -304,7 +362,6 @@ class ExperimentBuilder:
             self._logger_factory,
             env_config=self._env_config,
         )
-        log.info(f"Created experiment:\n{experiment.pprints()}")
         return experiment


@@ -1,6 +1,5 @@
 import os
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from typing import Literal, TypeAlias

 from torch.utils.tensorboard import SummaryWriter
@@ -11,38 +10,39 @@ from tianshou.utils.string import ToStringMixin
 TLogger: TypeAlias = TensorboardLogger | WandbLogger


-@dataclass
-class Logger:
-    logger: TLogger
-    log_path: str
-
-
 class LoggerFactory(ToStringMixin, ABC):
     @abstractmethod
-    def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
-        pass
+    def create_logger(
+        self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict,
+    ) -> TLogger:
+        """:param log_dir: path to the directory in which log data is to be stored
+        :param experiment_name: the name of the job, which may contain os.path.sep
+        :param run_id: a unique name, which, depending on the logging framework, may be used to identify the logger
+        :param config_dict: a dictionary with data that is to be logged
+        :return: the logger
+        """


 class DefaultLoggerFactory(LoggerFactory):
     def __init__(
         self,
         log_dir: str = "log",
         logger_type: Literal["tensorboard", "wandb"] = "tensorboard",
         wandb_project: str | None = None,
     ):
         if logger_type == "wandb" and wandb_project is None:
             raise ValueError("Must provide 'wandb_project'")
         self.log_dir = log_dir
         self.logger_type = logger_type
         self.wandb_project = wandb_project

-    def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
-        writer = SummaryWriter(self.log_dir)
+    def create_logger(
+        self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict,
+    ) -> TLogger:
+        writer = SummaryWriter(log_dir)
         writer.add_text(
             "args",
             str(
                 dict(
-                    log_dir=self.log_dir,
+                    log_dir=log_dir,
                     logger_type=self.logger_type,
                     wandb_project=self.wandb_project,
                 ),
@@ -52,7 +52,7 @@ class DefaultLoggerFactory(LoggerFactory):
         if self.logger_type == "wandb":
             logger = WandbLogger(
                 save_interval=1,
-                name=log_name.replace(os.path.sep, "__"),
+                name=experiment_name.replace(os.path.sep, "__"),
                 run_id=run_id,
                 config=config_dict,
                 project=self.wandb_project,
@@ -62,6 +62,4 @@ class DefaultLoggerFactory(LoggerFactory):
             logger = TensorboardLogger(writer)
         else:
             raise ValueError(f"Unknown logger type '{self.logger_type}'")
-        log_path = os.path.join(self.log_dir, log_name)
-        os.makedirs(log_path, exist_ok=True)
-        return Logger(logger=logger, log_path=log_path)
+        return logger
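
With create_logger now receiving the log directory explicitly and returning a bare TLogger, a custom factory reduces to constructing the logger itself. A minimal sketch against the new signature (the class name is hypothetical; run_id and config_dict are simply ignored here):

    from torch.utils.tensorboard import SummaryWriter

    from tianshou.highlevel.logger import LoggerFactory, TLogger
    from tianshou.utils import TensorboardLogger

    class TensorboardOnlyLoggerFactory(LoggerFactory):
        def create_logger(
            self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict,
        ) -> TLogger:
            # run_id and config_dict are unused in this sketch
            writer = SummaryWriter(log_dir)
            writer.add_text("experiment_name", experiment_name)
            return TensorboardLogger(writer)
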
@@ -1,8 +1,9 @@
+import logging
 import os
 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from enum import Enum
-from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable, Callable
+from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable

 import torch
@@ -25,11 +26,15 @@ class PersistableConfigProtocol(Protocol):


 class PersistEvent(Enum):
+    """Enumeration of persistence events that Persistence objects can react to."""
+
     PERSIST_POLICY = "persist_policy"
     """Policy neural network is persisted (new best found)"""


 class RestoreEvent(Enum):
+    """Enumeration of restoration events that Persistence objects can react to."""
+
     RESTORE_POLICY = "restore_policy"
     """Policy neural network parameters are restored"""

@@ -45,10 +50,13 @@ class Persistence(ABC):


 class PersistenceGroup(Persistence):
-    def __init__(self, *p: Persistence):
+    def __init__(self, *p: Persistence, enabled=True):
         self.items = p
+        self.enabled = enabled

     def persist(self, event: PersistEvent, world: World) -> None:
+        if not self.enabled:
+            return
         for item in self.items:
             item.persist(event, world)

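
Persistence implementations react to the events above, and PersistenceGroup can now silence all of them at once via its enabled flag. A sketch of a custom handler (class name, file name, and saved payload are hypothetical; persist follows the signature shown above, and a symmetric restore hook is assumed):

    import os

    import torch

    from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
    from tianshou.highlevel.world import World

    class ObsNormPersistence(Persistence):
        FILENAME = "obs_norm.pt"

        def persist(self, event: PersistEvent, world: World) -> None:
            if event == PersistEvent.PERSIST_POLICY:
                # illustrative payload; a real handler would save normalization state
                torch.save({"mean": 0.0, "std": 1.0}, world.persist_path(self.FILENAME))

        def restore(self, event: RestoreEvent, world: World) -> None:
            if event == RestoreEvent.RESTORE_POLICY:
                state = torch.load(os.path.join(world.restore_directory, self.FILENAME))
                # apply `state` to the environment wrappers here (omitted)
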
@@ -60,10 +68,17 @@ class PersistenceGroup(Persistence):
 class PolicyPersistence:
     FILENAME = "policy.dat"

-    def __init__(self, additional_persistence: Persistence | None = None):
+    def __init__(self, additional_persistence: Persistence | None = None, enabled=True):
+        """:param additional_persistence: a persistence instance which is to be invoked whenever
+        this object is used to persist/restore data
+        :param enabled: whether persistence is enabled (restoration is always enabled)
+        """
         self.additional_persistence = additional_persistence
+        self.enabled = enabled

     def persist(self, policy: torch.nn.Module, world: World) -> None:
+        if not self.enabled:
+            return
         path = world.persist_path(self.FILENAME)
         log.info(f"Saving policy in {path}")
         torch.save(policy.state_dict(), path)
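
Only the persist side of PolicyPersistence is visible in this hunk; its restore counterpart (invoked from run() above) plausibly mirrors it with torch.load. A hedged sketch, not the actual implementation:

    # Illustrative restore counterpart to persist():
    def restore_policy_params(policy: torch.nn.Module, path: str, device: str) -> None:
        state_dict = torch.load(path, map_location=device)
        policy.load_state_dict(state_dict)
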
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 from typing import TypeVar

 from tianshou.highlevel.env import Environments
-from tianshou.highlevel.logger import Logger
+from tianshou.highlevel.logger import TLogger
 from tianshou.policy import BasePolicy
 from tianshou.utils.string import ToStringMixin

@@ -12,7 +12,7 @@ TPolicy = TypeVar("TPolicy", bound=BasePolicy)


 class TrainingContext:
-    def __init__(self, policy: TPolicy, envs: Environments, logger: Logger):
+    def __init__(self, policy: TPolicy, envs: Environments, logger: TLogger):
         self.policy = policy
         self.envs = envs
         self.logger = logger
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional
 if TYPE_CHECKING:
     from tianshou.data import Collector
     from tianshou.highlevel.env import Environments
-    from tianshou.highlevel.logger import Logger
+    from tianshou.highlevel.logger import TLogger
     from tianshou.policy import BasePolicy
     from tianshou.trainer import BaseTrainer

@@ -16,15 +16,11 @@ class World:
     policy: "BasePolicy"
     train_collector: "Collector"
     test_collector: "Collector"
-    logger: "Logger"
+    logger: "TLogger"
+    persist_directory: str
     restore_directory: str
     trainer: Optional["BaseTrainer"] = None

-    @property
-    def persist_directory(self):
-        return self.logger.log_path
-
     def persist_path(self, filename: str) -> str:
         return os.path.join(self.persist_directory, filename)

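
With persist_directory now an explicit World field (set from the experiment's persistence directory) rather than a property derived from the logger, persist_path is a plain path join; for example (values illustrative):

    # world = World(..., persist_directory="log/my-run", ...)
    world.persist_path("policy.dat")  # -> "log/my-run/policy.dat"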