diff --git a/examples/atari/atari_callbacks.py b/examples/atari/atari_callbacks.py
index 41b4f83..d0b4315 100644
--- a/examples/atari/atari_callbacks.py
+++ b/examples/atari/atari_callbacks.py
@@ -22,7 +22,7 @@ class TrainEpochCallbackNatureDQNEpsLinearDecay(TrainerEpochCallbackTrain):
 
     def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
         policy: DQNPolicy = context.policy
-        logger = context.logger.logger
+        logger = context.logger
         # nature DQN setting, linear decay in the first 1M steps
         if env_step <= 1e6:
             eps = self.eps_train - env_step / 1e6 * (self.eps_train - self.eps_train_final)
diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py
index b2722cd..61a8fdf 100644
--- a/examples/atari/atari_dqn_hl.py
+++ b/examples/atari/atari_dqn_hl.py
@@ -81,7 +81,7 @@ def main(
     class TrainEpochCallback(TrainerEpochCallbackTrain):
         def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
             policy: DQNPolicy = context.policy
-            logger = context.logger.logger
+            logger = context.logger
             # nature DQN setting, linear decay in the first 1M steps
             if env_step <= 1e6:
                 eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final)
diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py
index fe08529..cecbc01 100644
--- a/examples/mujoco/mujoco_env.py
+++ b/examples/mujoco/mujoco_env.py
@@ -1,13 +1,13 @@
+import logging
 import pickle
 import warnings
-import logging
 
 import gymnasium as gym
 
 from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
-from tianshou.highlevel.persistence import Persistence, RestoreEvent, PersistEvent
+from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
 from tianshou.highlevel.world import World
 
 try:
diff --git a/test/highlevel/test_continuous.py b/test/highlevel/test_continuous.py
index 2b7e927..06995a7 100644
--- a/test/highlevel/test_continuous.py
+++ b/test/highlevel/test_continuous.py
@@ -33,7 +33,7 @@ from tianshou.highlevel.experiment import (
 def test_experiment_builder_continuous_default_params(builder_cls):
     env_factory = ContinuousTestEnvFactory()
     sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
-    experiment_config = ExperimentConfig()
+    experiment_config = ExperimentConfig(persistence_enabled=False)
     builder = builder_cls(
         experiment_config=experiment_config,
         env_factory=env_factory,
diff --git a/test/highlevel/test_discrete.py b/test/highlevel/test_discrete.py
index 52517aa..d9592af 100644
--- a/test/highlevel/test_discrete.py
+++ b/test/highlevel/test_discrete.py
@@ -11,7 +11,6 @@ from tianshou.highlevel.experiment import (
     IQNExperimentBuilder,
     PPOExperimentBuilder,
 )
-from tianshou.utils import logging
 
 
 @pytest.mark.parametrize(
@@ -25,11 +24,10 @@ from tianshou.utils import logging
     ],
 )
 def test_experiment_builder_discrete_default_params(builder_cls):
-    logging.configure()
     env_factory = DiscreteTestEnvFactory()
     sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
     builder = builder_cls(
-        experiment_config=ExperimentConfig(),
+        experiment_config=ExperimentConfig(persistence_enabled=False),
         env_factory=env_factory,
         sampling_config=sampling_config,
     )
diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index 116173d..f9e493b 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -171,7 +171,7 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
             batch_size=sampling_config.batch_size,
             step_per_collect=sampling_config.step_per_collect,
             save_best_fn=policy_persistence.get_save_best_fn(world),
-            logger=world.logger.logger,
+            logger=world.logger,
             test_in_train=False,
             train_fn=train_fn,
             test_fn=test_fn,
@@ -211,7 +211,7 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
             episode_per_test=sampling_config.num_test_envs,
             batch_size=sampling_config.batch_size,
             save_best_fn=policy_persistence.get_save_best_fn(world),
-            logger=world.logger.logger,
+            logger=world.logger,
             update_per_step=sampling_config.update_per_step,
             test_in_train=False,
             train_fn=train_fn,
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 0c9a91d..8394111 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -1,5 +1,5 @@
 import os
-import logging
+import pickle
 from abc import abstractmethod
 from collections.abc import Callable, Sequence
 from dataclasses import dataclass
@@ -65,7 +65,11 @@ from tianshou.highlevel.params.policy_params import (
     TRPOParams,
 )
 from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
-from tianshou.highlevel.persistence import PersistableConfigProtocol, PolicyPersistence, PersistenceGroup
+from tianshou.highlevel.persistence import (
+    PersistableConfigProtocol,
+    PersistenceGroup,
+    PolicyPersistence,
+)
 from tianshou.highlevel.trainer import (
     TrainerCallbacks,
     TrainerEpochCallbackTest,
@@ -75,6 +79,7 @@ from tianshou.highlevel.trainer import (
 from tianshou.highlevel.world import World
 from tianshou.policy import BasePolicy
 from tianshou.trainer import BaseTrainer
+from tianshou.utils import LazyLogger, logging
 from tianshou.utils.logging import datetime_tag
 from tianshou.utils.string import ToStringMixin
 
@@ -93,7 +98,7 @@ class ExperimentConfig:
     device: TDevice = "cuda" if torch.cuda.is_available() else "cpu"
     """The torch device to use"""
     policy_restore_directory: str | None = None
-    """Directory from which to load the policy neural network parameters (saved in a previous run)"""
+    """Directory from which to load the policy neural network parameters (persistence directory of a previous run)"""
     train: bool = True
     """Whether to perform training"""
     watch: bool = True
@@ -102,9 +107,24 @@
     """Number of episodes for which to watch performance (if watch is enabled)"""
     watch_render: float = 0.0
     """Milliseconds between rendered frames when watching agent performance (if watch is enabled)"""
+    persistence_base_dir: str = "log"
+    """Base directory in which experiment data is to be stored. Every experiment run will create a subdirectory
+    in this directory based on the run's experiment name"""
+    persistence_enabled: bool = True
+    """Whether persistence is enabled, allowing files to be stored"""
 
 
 class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
+    """Represents a reinforcement learning experiment.
+
+    An experiment is composed only of configuration and factory objects, which themselves
+    should be designed to contain only configuration. Therefore, experiments can easily
+    be pickled and later restored.
+    """
+
+    LOG_FILENAME = "log.txt"
+    EXPERIMENT_PICKLE_FILENAME = "experiment.pkl"
+
     def __init__(
         self,
         config: ExperimentConfig,
@@ -121,15 +141,31 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
         self.logger_factory = logger_factory
         self.env_config = env_config
 
+    @classmethod
+    def from_directory(cls, directory: str) -> Self:
+        """Restores an experiment from a previously stored pickle.
+
+        :param directory: persistence directory of a previous run, in which a pickled experiment is found
+        """
+        with open(os.path.join(directory, cls.EXPERIMENT_PICKLE_FILENAME), "rb") as f:
+            return pickle.load(f)
+
     def _set_seed(self) -> None:
         seed = self.config.seed
+        log.info(f"Setting random seed {seed}")
         np.random.seed(seed)
         torch.manual_seed(seed)
 
     def _build_config_dict(self) -> dict:
-        return {
-            "experiment": self.pprints()
-        }
+        return {"experiment": self.pprints()}
+
+    def save(self, directory: str) -> None:
+        path = os.path.join(directory, self.EXPERIMENT_PICKLE_FILENAME)
+        log.info(
+            f"Saving serialized experiment in {path}; can be restored via Experiment.from_directory('{directory}')",
+        )
+        with open(path, "wb") as f:
+            pickle.dump(self, f)
 
     def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None:
         """:param experiment_name: the experiment name, which corresponds to the directory (within the logging
@@ -144,66 +180,88 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
         if experiment_name is None:
             experiment_name = datetime_tag()
 
-        log.info(f"Working directory: {os.getcwd()}")
+        # initialize persistence directory
+        use_persistence = self.config.persistence_enabled
+        persistence_dir = os.path.join(self.config.persistence_base_dir, experiment_name)
+        if use_persistence:
+            os.makedirs(persistence_dir, exist_ok=True)
 
-        self._set_seed()
+        with logging.FileLoggerContext(
+            os.path.join(persistence_dir, self.LOG_FILENAME), enabled=use_persistence,
+        ):
+            # log initial information
+            log.info(f"Running experiment (name='{experiment_name}'):\n{self.pprints()}")
+            log.info(f"Working directory: {os.getcwd()}")
 
-        # create environments
-        envs = self.env_factory(self.env_config)
-        log.info(f"Created {envs}")
+            self._set_seed()
 
-        # initialize persistence
-        additional_persistence = PersistenceGroup(*envs.persistence)
-        policy_persistence = PolicyPersistence(additional_persistence)
+            # create environments
+            envs = self.env_factory(self.env_config)
+            log.info(f"Created {envs}")
 
-        # initialize logger
-        full_config = self._build_config_dict()
-        full_config.update(envs.info())
-        logger = self.logger_factory.create_logger(
-            log_name=experiment_name,
-            run_id=logger_run_id,
-            config_dict=full_config,
-        )
+            # initialize persistence
+            additional_persistence = PersistenceGroup(*envs.persistence, enabled=use_persistence)
+            policy_persistence = PolicyPersistence(additional_persistence, enabled=use_persistence)
+            if use_persistence:
+                log.info(f"Persistence directory: {os.path.abspath(persistence_dir)}")
+                self.save(persistence_dir)
 
-        policy = self.agent_factory.create_policy(envs, self.config.device)
+            # initialize logger
+            full_config = self._build_config_dict()
+            full_config.update(envs.info())
+            if use_persistence:
+                logger = self.logger_factory.create_logger(
+                    log_dir=persistence_dir,
+                    experiment_name=experiment_name,
+                    run_id=logger_run_id,
+                    config_dict=full_config,
+                )
+            else:
+                logger = LazyLogger()
 
-        train_collector, test_collector = self.agent_factory.create_train_test_collector(
-            policy,
-            envs,
-        )
-
-        world = World(
-            envs=envs,
-            policy=policy,
-            train_collector=train_collector,
-            test_collector=test_collector,
-            logger=logger,
-            restore_directory=self.config.policy_restore_directory
-        )
-
-        if self.config.policy_restore_directory:
-            policy_persistence.restore(
+            # create policy and collectors
+            policy = self.agent_factory.create_policy(envs, self.config.device)
+            train_collector, test_collector = self.agent_factory.create_train_test_collector(
                 policy,
-                world,
-                self.config.device,
+                envs,
             )
 
-        if self.config.train:
-            trainer = self.agent_factory.create_trainer(world, policy_persistence)
-            world.trainer = trainer
-
-            result = trainer.run()
-            pprint(result)  # TODO logging
-
-        if self.config.watch:
-            self._watch_agent(
-                self.config.watch_num_episodes,
-                policy,
-                test_collector,
-                self.config.watch_render,
+            # create context object with all relevant instances (except trainer; added later)
+            world = World(
+                envs=envs,
+                policy=policy,
+                train_collector=train_collector,
+                test_collector=test_collector,
+                logger=logger,
+                persist_directory=persistence_dir,
+                restore_directory=self.config.policy_restore_directory,
             )
-
-        # TODO return result
+
+            # restore policy parameters if applicable
+            if self.config.policy_restore_directory:
+                policy_persistence.restore(
+                    policy,
+                    world,
+                    self.config.device,
+                )
+
+            # train policy
+            if self.config.train:
+                trainer = self.agent_factory.create_trainer(world, policy_persistence)
+                world.trainer = trainer
+                trainer_result = trainer.run()
+                pprint(trainer_result)  # TODO logging
+
+            # watch agent performance
+            if self.config.watch:
+                self._watch_agent(
+                    self.config.watch_num_episodes,
+                    policy,
+                    test_collector,
+                    self.config.watch_render,
+                )
+
+            # TODO return result
 
     @staticmethod
     def _watch_agent(
@@ -304,7 +362,6 @@ class ExperimentBuilder:
             self._logger_factory,
             env_config=self._env_config,
         )
-        log.info(f"Created experiment:\n{experiment.pprints()}")
         return experiment
 
 
diff --git a/tianshou/highlevel/logger.py b/tianshou/highlevel/logger.py
index c913a9b..6260977 100644
--- a/tianshou/highlevel/logger.py
+++ b/tianshou/highlevel/logger.py
@@ -1,6 +1,5 @@
 import os
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from typing import Literal, TypeAlias
 
 from torch.utils.tensorboard import SummaryWriter
@@ -11,38 +10,39 @@ from tianshou.utils.string import ToStringMixin
 TLogger: TypeAlias = TensorboardLogger | WandbLogger
 
 
-@dataclass
-class Logger:
-    logger: TLogger
-    log_path: str
-
-
 class LoggerFactory(ToStringMixin, ABC):
     @abstractmethod
-    def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
-        pass
+    def create_logger(
+        self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict,
+    ) -> TLogger:
+        """:param log_dir: path to the directory in which log data is to be stored
+        :param experiment_name: the name of the experiment, which may contain os.path.sep
+        :param run_id: a unique name, which, depending on the logging framework, may be used to identify the logger
+        :param config_dict: a dictionary with data that is to be logged
+        :return: the logger
+        """
 
 
 class DefaultLoggerFactory(LoggerFactory):
     def __init__(
         self,
-        log_dir: str = "log",
         logger_type: Literal["tensorboard", "wandb"] = "tensorboard",
         wandb_project: str | None = None,
     ):
         if logger_type == "wandb" and wandb_project is None:
             raise ValueError("Must provide 'wandb_project'")
-        self.log_dir = log_dir
         self.logger_type = logger_type
         self.wandb_project = wandb_project
 
-    def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
-        writer = SummaryWriter(self.log_dir)
+    def create_logger(
+        self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict,
+    ) -> TLogger:
+        writer = SummaryWriter(log_dir)
         writer.add_text(
             "args",
             str(
                 dict(
-                    log_dir=self.log_dir,
+                    log_dir=log_dir,
                     logger_type=self.logger_type,
                     wandb_project=self.wandb_project,
                 ),
@@ -52,7 +52,7 @@ class DefaultLoggerFactory(LoggerFactory):
"wandb": logger = WandbLogger( save_interval=1, - name=log_name.replace(os.path.sep, "__"), + name=experiment_name.replace(os.path.sep, "__"), run_id=run_id, config=config_dict, project=self.wandb_project, @@ -62,6 +62,4 @@ class DefaultLoggerFactory(LoggerFactory): logger = TensorboardLogger(writer) else: raise ValueError(f"Unknown logger type '{self.logger_type}'") - log_path = os.path.join(self.log_dir, log_name) - os.makedirs(log_path, exist_ok=True) - return Logger(logger=logger, log_path=log_path) + return logger diff --git a/tianshou/highlevel/persistence.py b/tianshou/highlevel/persistence.py index 4e4f606..225b1e8 100644 --- a/tianshou/highlevel/persistence.py +++ b/tianshou/highlevel/persistence.py @@ -1,8 +1,9 @@ import logging import os from abc import ABC, abstractmethod +from collections.abc import Callable from enum import Enum -from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable, Callable +from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable import torch @@ -25,11 +26,15 @@ class PersistableConfigProtocol(Protocol): class PersistEvent(Enum): + """Enumeration of persistence events that Persistence objects can react to.""" + PERSIST_POLICY = "persist_policy" """Policy neural network is persisted (new best found)""" class RestoreEvent(Enum): + """Enumeration of restoration events that Persistence objects can react to.""" + RESTORE_POLICY = "restore_policy" """Policy neural network parameters are restored""" @@ -45,10 +50,13 @@ class Persistence(ABC): class PersistenceGroup(Persistence): - def __init__(self, *p: Persistence): + def __init__(self, *p: Persistence, enabled=True): self.items = p + self.enabled = enabled def persist(self, event: PersistEvent, world: World) -> None: + if not self.enabled: + return for item in self.items: item.persist(event, world) @@ -60,10 +68,17 @@ class PersistenceGroup(Persistence): class PolicyPersistence: FILENAME = "policy.dat" - def __init__(self, additional_persistence: Persistence | None = None): + def __init__(self, additional_persistence: Persistence | None = None, enabled=True): + """:param additional_persistence: a persistence instance which is to be envoked whenever + this object is used to persist/restore data + :param enabled: whether persistence is enabled (restoration is always enabled) + """ self.additional_persistence = additional_persistence + self.enabled = enabled def persist(self, policy: torch.nn.Module, world: World) -> None: + if not self.enabled: + return path = world.persist_path(self.FILENAME) log.info(f"Saving policy in {path}") torch.save(policy.state_dict(), path) diff --git a/tianshou/highlevel/trainer.py b/tianshou/highlevel/trainer.py index 876752b..770c92b 100644 --- a/tianshou/highlevel/trainer.py +++ b/tianshou/highlevel/trainer.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import TypeVar from tianshou.highlevel.env import Environments -from tianshou.highlevel.logger import Logger +from tianshou.highlevel.logger import TLogger from tianshou.policy import BasePolicy from tianshou.utils.string import ToStringMixin @@ -12,7 +12,7 @@ TPolicy = TypeVar("TPolicy", bound=BasePolicy) class TrainingContext: - def __init__(self, policy: TPolicy, envs: Environments, logger: Logger): + def __init__(self, policy: TPolicy, envs: Environments, logger: TLogger): self.policy = policy self.envs = envs self.logger = logger diff --git a/tianshou/highlevel/world.py b/tianshou/highlevel/world.py index c5cb369..aa278db 100644 --- a/tianshou/highlevel/world.py +++ 
+++ b/tianshou/highlevel/world.py
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional
 if TYPE_CHECKING:
     from tianshou.data import Collector
     from tianshou.highlevel.env import Environments
-    from tianshou.highlevel.logger import Logger
+    from tianshou.highlevel.logger import TLogger
     from tianshou.policy import BasePolicy
     from tianshou.trainer import BaseTrainer
 
@@ -16,15 +16,11 @@ class World:
     policy: "BasePolicy"
     train_collector: "Collector"
     test_collector: "Collector"
-    logger: "Logger"
+    logger: "TLogger"
+    persist_directory: str
     restore_directory: str
     trainer: Optional["BaseTrainer"] = None
-
-    @property
-    def persist_directory(self):
-        return self.logger.log_path
-
 
     def persist_path(self, filename: str) -> str:
         return os.path.join(self.persist_directory, filename)
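
Usage sketch (illustrative, not part of the patch): with the changes above, Experiment.run() pickles the experiment as experiment.pkl in its persistence directory, and Experiment.from_directory() restores it from there. Assuming a previous run named "dqn_pong" under the default persistence_base_dir "log", a restore-and-rerun could look like this:

    import os

    from tianshou.highlevel.experiment import Experiment

    # Restore the experiment pickled by a previous run; the run name "dqn_pong"
    # and the base directory "log" (the ExperimentConfig default) are assumptions.
    experiment = Experiment.from_directory(os.path.join("log", "dqn_pong"))

    # Re-run under a new experiment name so that a fresh persistence
    # subdirectory (log/dqn_pong_rerun) is created for this run's artifacts.
    experiment.run(experiment_name="dqn_pong_rerun")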
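
Similarly, LoggerFactory implementations now receive the persistence directory and return a TLogger directly, since the Logger wrapper dataclass is gone. A minimal custom factory against the new create_logger signature might look as follows (the factory name is hypothetical; TensorboardLogger is tianshou's existing logger class):

    from torch.utils.tensorboard import SummaryWriter

    from tianshou.highlevel.logger import LoggerFactory, TLogger
    from tianshou.utils import TensorboardLogger


    class PlainTensorboardLoggerFactory(LoggerFactory):
        # Hypothetical factory: ignores run_id and config_dict and writes
        # tensorboard data directly to the run's persistence directory.
        def create_logger(
            self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict,
        ) -> TLogger:
            writer = SummaryWriter(log_dir)
            writer.add_text("experiment_name", experiment_name)
            return TensorboardLogger(writer)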