diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index b927cda..116173d 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -1,15 +1,12 @@
+import logging
 from abc import ABC, abstractmethod
-from collections.abc import Callable
-from os import PathLike
 from typing import Any, Generic, TypeVar, cast
 
 import gymnasium
-import torch
 
 from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.env import Environments
-from tianshou.highlevel.logger import Logger
 from tianshou.highlevel.module.actor import (
     ActorFactory,
 )
@@ -41,7 +38,9 @@ from tianshou.highlevel.params.policy_params import (
     TRPOParams,
 )
 from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
+from tianshou.highlevel.persistence import PolicyPersistence
 from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext
+from tianshou.highlevel.world import World
 from tianshou.policy import (
     A2CPolicy,
     BasePolicy,
@@ -71,6 +70,7 @@ TDiscreteCriticOnlyParams = TypeVar(
     bound=ParamsMixinLearningRateWithScheduler,
 )
 TPolicy = TypeVar("TPolicy", bound=BasePolicy)
+log = logging.getLogger(__name__)
 
 
 class AgentFactory(ABC, ToStringMixin):
@@ -133,58 +133,20 @@ class AgentFactory(ABC, ToStringMixin):
         )
         return policy
 
-    @staticmethod
-    def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
-        def save_best_fn(pol: torch.nn.Module) -> None:
-            pass
-            # TODO: Fix saving in general (code works only for mujoco)
-            # state = {
-            #     CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
-            #     CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
-            # }
-            # torch.save(state, os.path.join(log_path, "policy.pth"))
-
-        return save_best_fn
-
-    @staticmethod
-    def load_checkpoint(
-        policy: torch.nn.Module,
-        path: str | PathLike,
-        envs: Environments,
-        device: TDevice,
-    ) -> None:
-        ckpt = torch.load(path, map_location=device)
-        policy.load_state_dict(ckpt[CHECKPOINT_DICT_KEY_MODEL])
-        if envs.train_envs:
-            envs.train_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
-        if envs.test_envs:
-            envs.test_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
-        print("Loaded agent and obs. running means from: ", path)  # TODO logging
-
     @abstractmethod
-    def create_trainer(
-        self,
-        policy: BasePolicy,
-        train_collector: Collector,
-        test_collector: Collector,
-        envs: Environments,
-        logger: Logger,
-    ) -> BaseTrainer:
+    def create_trainer(self, world: World, policy_persistence: PolicyPersistence) -> BaseTrainer:
         pass
 
 
 class OnpolicyAgentFactory(AgentFactory, ABC):
     def create_trainer(
         self,
-        policy: BasePolicy,
-        train_collector: Collector,
-        test_collector: Collector,
-        envs: Environments,
-        logger: Logger,
+        world: World,
+        policy_persistence: PolicyPersistence,
     ) -> OnpolicyTrainer:
         sampling_config = self.sampling_config
         callbacks = self.trainer_callbacks
-        context = TrainingContext(policy, envs, logger)
+        context = TrainingContext(world.policy, world.envs, world.logger)
         train_fn = (
             callbacks.epoch_callback_train.get_trainer_fn(context)
             if callbacks.epoch_callback_train
@@ -199,17 +161,17 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
             callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
         )
         return OnpolicyTrainer(
-            policy=policy,
-            train_collector=train_collector,
-            test_collector=test_collector,
+            policy=world.policy,
+            train_collector=world.train_collector,
+            test_collector=world.test_collector,
             max_epoch=sampling_config.num_epochs,
             step_per_epoch=sampling_config.step_per_epoch,
             repeat_per_collect=sampling_config.repeat_per_collect,
             episode_per_test=sampling_config.num_test_envs,
             batch_size=sampling_config.batch_size,
             step_per_collect=sampling_config.step_per_collect,
-            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
-            logger=logger.logger,
+            save_best_fn=policy_persistence.get_save_best_fn(world),
+            logger=world.logger.logger,
             test_in_train=False,
             train_fn=train_fn,
             test_fn=test_fn,
@@ -220,15 +182,12 @@
 class OffpolicyAgentFactory(AgentFactory, ABC):
     def create_trainer(
         self,
-        policy: BasePolicy,
-        train_collector: Collector,
-        test_collector: Collector,
-        envs: Environments,
-        logger: Logger,
+        world: World,
+        policy_persistence: PolicyPersistence,
     ) -> OffpolicyTrainer:
         sampling_config = self.sampling_config
         callbacks = self.trainer_callbacks
-        context = TrainingContext(policy, envs, logger)
+        context = TrainingContext(world.policy, world.envs, world.logger)
         train_fn = (
             callbacks.epoch_callback_train.get_trainer_fn(context)
             if callbacks.epoch_callback_train
@@ -243,16 +202,16 @@
             callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
         )
         return OffpolicyTrainer(
-            policy=policy,
-            train_collector=train_collector,
-            test_collector=test_collector,
+            policy=world.policy,
+            train_collector=world.train_collector,
+            test_collector=world.test_collector,
             max_epoch=sampling_config.num_epochs,
             step_per_epoch=sampling_config.step_per_epoch,
             step_per_collect=sampling_config.step_per_collect,
             episode_per_test=sampling_config.num_test_envs,
             batch_size=sampling_config.batch_size,
-            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
-            logger=logger.logger,
+            save_best_fn=policy_persistence.get_save_best_fn(world),
+            logger=world.logger.logger,
             update_per_step=sampling_config.update_per_step,
             test_in_train=False,
             train_fn=train_fn,
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index ad242d6..a3cd264 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -39,6 +39,7 @@ from tianshou.highlevel.module.actor import (
 from tianshou.highlevel.module.core import (
     ImplicitQuantileNetworkFactory,
     IntermediateModuleFactory,
+    TDevice,
 )
 from tianshou.highlevel.module.critic import (
     CriticEnsembleFactory,
@@ -63,15 +64,17 @@ from tianshou.highlevel.params.policy_params import (
     TRPOParams,
 )
 from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
-from tianshou.highlevel.persistence import PersistableConfigProtocol
+from tianshou.highlevel.persistence import PersistableConfigProtocol, PolicyPersistence
 from tianshou.highlevel.trainer import (
     TrainerCallbacks,
     TrainerEpochCallbackTest,
     TrainerEpochCallbackTrain,
     TrainerStopCallback,
 )
+from tianshou.highlevel.world import World
 from tianshou.policy import BasePolicy
 from tianshou.trainer import BaseTrainer
+from tianshou.utils.logging import datetime_tag
 from tianshou.utils.string import ToStringMixin
 
 log = logging.getLogger(__name__)
@@ -86,14 +89,18 @@ class ExperimentConfig:
     seed: int = 42
     render: float | None = 0.0
     """Milliseconds between rendered frames; if None, no rendering"""
-    device: str = "cuda" if torch.cuda.is_available() else "cpu"
-    resume_id: str | None = None
-    """For restoring a model and running means of env-specifics from a checkpoint"""
-    resume_path: str | None = None
-    """For restoring a model and running means of env-specifics from a checkpoint"""
-    watch: bool = False
-    """If True, will not perform training and only watch the restored policy"""
+    device: TDevice = "cuda" if torch.cuda.is_available() else "cpu"
+    """The torch device to use"""
+    policy_restore_directory: str | None = None
+    """Directory from which to load the policy neural network parameters (saved in a previous run)"""
+    train: bool = True
+    """Whether to perform training"""
+    watch: bool = True
+    """Whether to watch agent performance (after training)"""
     watch_num_episodes = 10
+    """Number of episodes for which to watch performance (if watch is enabled)"""
+    watch_render: float = 0.0
+    """Milliseconds between rendered frames when watching agent performance (if watch is enabled)"""
 
 
 class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
@@ -123,55 +130,71 @@
             # TODO
         }
 
-    def run(self, log_name: str) -> None:
+    def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None:
+        """:param experiment_name: the experiment name, which corresponds to the directory (within the logging
+            directory) where all results associated with the experiment will be saved.
+            The name may contain path separators (os.path.sep, used by os.path.join), in which case
+            a nested directory structure will be created.
+            If None, use a name containing the current date and time.
+        :param logger_run_id: Run identifier to use for logger initialization/resumption (applies when
+            using wandb, in particular).
+        :return:
+        """
+        if experiment_name is None:
+            experiment_name = datetime_tag()
+
         self._set_seed()
         envs = self.env_factory(self.env_config)
+        policy_persistence = PolicyPersistence()
         log.info(f"Created {envs}")
 
         full_config = self._build_config_dict()
         full_config.update(envs.info())
 
-        run_id = self.config.resume_id
         logger = self.logger_factory.create_logger(
-            log_name=log_name,
-            run_id=run_id,
+            log_name=experiment_name,
+            run_id=logger_run_id,
             config_dict=full_config,
         )
 
         policy = self.agent_factory.create_policy(envs, self.config.device)
-        if self.config.resume_path:
-            self.agent_factory.load_checkpoint(
-                policy,
-                self.config.resume_path,
-                envs,
-                self.config.device,
-            )
 
         train_collector, test_collector = self.agent_factory.create_train_test_collector(
             policy,
             envs,
         )
 
-        if not self.config.watch:
-            trainer = self.agent_factory.create_trainer(
+        world = World(
+            envs=envs,
+            policy=policy,
+            train_collector=train_collector,
+            test_collector=test_collector,
+            logger=logger,
+        )
+
+        if self.config.policy_restore_directory:
+            policy_persistence.restore(
                 policy,
-                train_collector,
-                test_collector,
-                envs,
-                logger,
+                self.config.policy_restore_directory,
+                self.config.device,
             )
+
+        if self.config.train:
+            trainer = self.agent_factory.create_trainer(world, policy_persistence)
+            world.trainer = trainer
+
             result = trainer.run()
             pprint(result)  # TODO logging
 
-        render = self.config.render
-        if render is None:
-            render = 0.0  # TODO: Perhaps we should have a second render parameter for watch mode?
-        self._watch_agent(
-            self.config.watch_num_episodes,
-            policy,
-            test_collector,
-            render,
-        )
+        if self.config.watch:
+            self._watch_agent(
+                self.config.watch_num_episodes,
+                policy,
+                test_collector,
+                self.config.watch_render,
+            )
+
+        # TODO return result
 
     @staticmethod
     def _watch_agent(
diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py
index 8bd7b5f..e02b280 100644
--- a/tianshou/highlevel/module/actor.py
+++ b/tianshou/highlevel/module/actor.py
@@ -120,7 +120,8 @@ class ActorFactoryDefault(ActorFactory):
             return factory.create_module(envs, device)
         elif env_type == EnvType.DISCRETE:
             factory = ActorFactoryDiscreteNet(
-                self.DEFAULT_HIDDEN_SIZES, softmax_output=self.discrete_softmax,
+                self.DEFAULT_HIDDEN_SIZES,
+                softmax_output=self.discrete_softmax,
             )
             return factory.create_module(envs, device)
         else:
diff --git a/tianshou/highlevel/persistence.py b/tianshou/highlevel/persistence.py
index 71db019..8617c82 100644
--- a/tianshou/highlevel/persistence.py
+++ b/tianshou/highlevel/persistence.py
@@ -1,5 +1,16 @@
+import logging
 import os
-from typing import Protocol, Self, runtime_checkable
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable
+
+import torch
+
+from tianshou.highlevel.world import World
+
+if TYPE_CHECKING:
+    from tianshou.highlevel.module.core import TDevice
+
+log = logging.getLogger(__name__)
 
 
 @runtime_checkable
@@ -10,3 +21,50 @@
 
     def save(self, path: os.PathLike[str]) -> None:
         pass
+
+
+class Persistence(ABC):
+    def path(self, world: World, filename: str) -> str:
+        return os.path.join(world.directory, filename)
+
+    @abstractmethod
+    def persist(self, world: World) -> None:
+        pass
+
+    @abstractmethod
+    def restore(self, world: World):
+        pass
+
+
+class PersistenceGroup(Persistence):
+    def __init__(self, *p: Persistence):
+        self.items = p
+
+    def persist(self, world: World) -> None:
+        for item in self.items:
+            item.persist(world)
+
+    def restore(self, world: World):
+        for item in self.items:
+            item.restore(world)
+
+
+class PolicyPersistence:
+    FILENAME = "policy.dat"
+
+    def persist(self, policy: torch.nn.Module, directory: str) -> None:
+        path = os.path.join(directory, self.FILENAME)
+        log.info(f"Saving policy in {path}")
+        torch.save(policy.state_dict(), path)
+
+    def restore(self, policy: torch.nn.Module, directory: str, device: "TDevice") -> None:
+        path = os.path.join(directory, self.FILENAME)
+        log.info(f"Restoring policy from {path}")
+        state_dict = torch.load(path, map_location=device)
+        policy.load_state_dict(state_dict)
+
+    def get_save_best_fn(self, world):
+        def save_best_fn(pol: torch.nn.Module) -> None:
+            self.persist(pol, world.directory)
+
+        return save_best_fn
diff --git a/tianshou/highlevel/world.py b/tianshou/highlevel/world.py
new file mode 100644
index 0000000..5d69a07
--- /dev/null
+++ b/tianshou/highlevel/world.py
@@ -0,0 +1,27 @@
+import os
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from tianshou.data import Collector
+    from tianshou.highlevel.env import Environments
+    from tianshou.highlevel.logger import Logger
+    from tianshou.policy import BasePolicy
+    from tianshou.trainer import BaseTrainer
+
+
+@dataclass
+class World:
+    envs: "Environments"
+    policy: "BasePolicy"
+    train_collector: "Collector"
+    test_collector: "Collector"
+    logger: "Logger"
+    trainer: Optional["BaseTrainer"] = None
+
+    @property
+    def directory(self):
+        return self.logger.log_path
+
+    def path(self, filename: str) -> str:
+        return os.path.join(self.directory, filename)
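
Usage sketch (not part of the patch above): a minimal illustration of how the PolicyPersistence class added in tianshou/highlevel/persistence.py can be exercised on its own. The stand-in torch.nn.Linear module and the "log/demo" directory are hypothetical placeholders; within the high-level API, Experiment.run() issues these calls itself via policy_persistence.restore(...) and save_best_fn obtained from get_save_best_fn(world).

    import os

    import torch

    from tianshou.highlevel.persistence import PolicyPersistence

    # Hypothetical stand-in for a policy network; in practice this is the
    # BasePolicy built by an AgentFactory.
    policy = torch.nn.Linear(4, 2)
    log_dir = "log/demo"  # hypothetical output directory
    os.makedirs(log_dir, exist_ok=True)

    persistence = PolicyPersistence()
    persistence.persist(policy, log_dir)  # writes log/demo/policy.dat
    persistence.restore(policy, log_dir, "cpu")  # loads the parameters back onto the module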