From 9f0a410bb13c3cb61ecba34c6d19f0bb3ad454b2 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 3 Oct 2023 21:14:22 +0200 Subject: [PATCH] Log full experiment configuration, adding string representations to relevant classes --- examples/atari/atari_ppo_hl.py | 3 ++- examples/mujoco/mujoco_a2c_hl.py | 3 ++- examples/mujoco/mujoco_ddpg_hl.py | 3 ++- examples/mujoco/mujoco_ppo_hl.py | 3 ++- examples/mujoco/mujoco_sac_hl.py | 3 ++- examples/mujoco/mujoco_td3_hl.py | 3 ++- tianshou/highlevel/agent.py | 3 ++- tianshou/highlevel/experiment.py | 9 +++++++-- tianshou/highlevel/logger.py | 3 ++- tianshou/highlevel/module/actor.py | 3 ++- tianshou/highlevel/module/critic.py | 3 ++- tianshou/highlevel/module/module_opt.py | 5 +++-- tianshou/highlevel/optim.py | 4 +++- 13 files changed, 33 insertions(+), 15 deletions(-) diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 12939e9..12733bb 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -21,6 +21,7 @@ from tianshou.highlevel.params.policy_params import PPOParams from tianshou.highlevel.params.policy_wrapper import ( PolicyWrapperFactoryIntrinsicCuriosity, ) +from tianshou.utils import logging def main( @@ -114,4 +115,4 @@ def main( if __name__ == "__main__": - CLI(main) + logging.run_main(lambda: CLI(main)) diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index 2d6a583..a6f48f6 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -16,6 +16,7 @@ from tianshou.highlevel.experiment import ( from tianshou.highlevel.optim import OptimizerFactoryRMSprop from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.policy_params import A2CParams +from tianshou.utils import logging def main( @@ -82,4 +83,4 @@ def main( if __name__ == "__main__": - CLI(main) + logging.run_main(lambda: CLI(main)) diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index 097be5d..0b173b4 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -14,6 +14,7 @@ from tianshou.highlevel.experiment import ( ) from tianshou.highlevel.params.noise import MaxActionScaledGaussian from tianshou.highlevel.params.policy_params import DDPGParams +from tianshou.utils import logging def main( @@ -75,4 +76,4 @@ def main( if __name__ == "__main__": - CLI(main) + logging.run_main(lambda: CLI(main)) diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index 814792b..3272b0f 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -16,6 +16,7 @@ from tianshou.highlevel.experiment import ( ) from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.policy_params import PPOParams +from tianshou.utils import logging def main( @@ -95,4 +96,4 @@ def main( if __name__ == "__main__": - CLI(main) + logging.run_main(lambda: CLI(main)) diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index f4bf327..1996689 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -14,6 +14,7 @@ from tianshou.highlevel.experiment import ( ) from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault from tianshou.highlevel.params.policy_params import SACParams +from tianshou.utils import logging def main( @@ -81,4 +82,4 @@ def main( if __name__ == "__main__": - CLI(main) + logging.run_main(lambda: CLI(main)) diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 9e0b33b..f2c906a 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -17,6 +17,7 @@ from tianshou.highlevel.params.noise import ( MaxActionScaledGaussian, ) from tianshou.highlevel.params.policy_params import TD3Params +from tianshou.utils import logging def main( @@ -84,4 +85,4 @@ def main( if __name__ == "__main__": - CLI(main) + logging.run_main(lambda: CLI(main)) diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index 87163cc..b97f5ed 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -41,6 +41,7 @@ from tianshou.policy import ( from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer from tianshou.utils.net import continuous, discrete from tianshou.utils.net.common import ActorCritic +from tianshou.utils.string import ToStringMixin CHECKPOINT_DICT_KEY_MODEL = "model" CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms" @@ -48,7 +49,7 @@ TParams = TypeVar("TParams", bound=Params) TPolicy = TypeVar("TPolicy", bound=BasePolicy) -class AgentFactory(ABC): +class AgentFactory(ABC, ToStringMixin): def __init__(self, sampling_config: RLSamplingConfig, optim_factory: OptimizerFactory): self.sampling_config = sampling_config self.optim_factory = optim_factory diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 4113cdd..3d86203 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -1,3 +1,4 @@ +import logging from abc import abstractmethod from collections.abc import Callable, Sequence from dataclasses import dataclass @@ -37,7 +38,9 @@ from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory from tianshou.highlevel.persistence import PersistableConfigProtocol from tianshou.policy import BasePolicy from tianshou.trainer import BaseTrainer +from tianshou.utils.string import ToStringMixin +log = logging.getLogger(__name__) TPolicy = TypeVar("TPolicy", bound=BasePolicy) TTrainer = TypeVar("TTrainer", bound=BaseTrainer) @@ -59,7 +62,7 @@ class RLExperimentConfig: watch_num_episodes = 10 -class RLExperiment(Generic[TPolicy, TTrainer]): +class RLExperiment(Generic[TPolicy, TTrainer], ToStringMixin): def __init__( self, config: RLExperimentConfig, @@ -204,13 +207,15 @@ class RLExperimentBuilder: agent_factory = self._create_agent_factory() if self._policy_wrapper_factory: agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory) - return RLExperiment( + experiment = RLExperiment( self._config, self._env_factory, agent_factory, self._logger_factory, env_config=self._env_config, ) + log.info(f"Created experiment:\n{experiment.pprints()}") + return experiment class _BuilderMixinActorFactory: diff --git a/tianshou/highlevel/logger.py b/tianshou/highlevel/logger.py index c556f4f..8de2033 100644 --- a/tianshou/highlevel/logger.py +++ b/tianshou/highlevel/logger.py @@ -6,6 +6,7 @@ from typing import Literal, TypeAlias from torch.utils.tensorboard import SummaryWriter from tianshou.utils import TensorboardLogger, WandbLogger +from tianshou.utils.string import ToStringMixin TLogger: TypeAlias = TensorboardLogger | WandbLogger @@ -16,7 +17,7 @@ class Logger: log_path: str -class LoggerFactory(ABC): +class LoggerFactory(ToStringMixin, ABC): @abstractmethod def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger: pass diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index 488928f..8849f38 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -8,6 +8,7 @@ from tianshou.highlevel.env import Environments, EnvType from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal from tianshou.utils.net import continuous, discrete from tianshou.utils.net.common import BaseActor, Net +from tianshou.utils.string import ToStringMixin class ContinuousActorType: @@ -15,7 +16,7 @@ class ContinuousActorType: DETERMINISTIC = "deterministic" -class ActorFactory(ABC): +class ActorFactory(ToStringMixin, ABC): @abstractmethod def create_module(self, envs: Environments, device: TDevice) -> BaseActor: pass diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py index f15adbe..092c172 100644 --- a/tianshou/highlevel/module/critic.py +++ b/tianshou/highlevel/module/critic.py @@ -7,9 +7,10 @@ from tianshou.highlevel.env import Environments, EnvType from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal from tianshou.utils.net import continuous from tianshou.utils.net.common import Net +from tianshou.utils.string import ToStringMixin -class CriticFactory(ABC): +class CriticFactory(ToStringMixin, ABC): @abstractmethod def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: pass diff --git a/tianshou/highlevel/module/module_opt.py b/tianshou/highlevel/module/module_opt.py index 8feff85..802d305 100644 --- a/tianshou/highlevel/module/module_opt.py +++ b/tianshou/highlevel/module/module_opt.py @@ -8,6 +8,7 @@ from tianshou.highlevel.module.core import TDevice from tianshou.highlevel.module.critic import CriticFactory from tianshou.highlevel.optim import OptimizerFactory from tianshou.utils.net.common import ActorCritic +from tianshou.utils.string import ToStringMixin @dataclass @@ -30,7 +31,7 @@ class ActorCriticModuleOpt: return self.actor_critic_module.critic -class ActorModuleOptFactory: +class ActorModuleOptFactory(ToStringMixin): def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory): self.actor_factory = actor_factory self.optim_factory = optim_factory @@ -41,7 +42,7 @@ class ActorModuleOptFactory: return ModuleOpt(actor, opt) -class CriticModuleOptFactory: +class CriticModuleOptFactory(ToStringMixin): def __init__( self, critic_factory: CriticFactory, diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py index ef3071d..de5c434 100644 --- a/tianshou/highlevel/optim.py +++ b/tianshou/highlevel/optim.py @@ -4,8 +4,10 @@ from typing import Any import torch from torch.optim import Adam, RMSprop +from tianshou.utils.string import ToStringMixin -class OptimizerFactory(ABC): + +class OptimizerFactory(ABC, ToStringMixin): # TODO: Is it OK to assume that all optimizers have a learning rate argument? # Right now, the learning rate is typically a configuration parameter. # If we drop the assumption, we can't have that and will need to move the parameter