From 7437131d791cdd59e6370400ff83bd1616eee08f Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 19 Oct 2023 11:40:49 +0200 Subject: [PATCH] Fix tianshou.highlevel depending on jsonargparse (should be dev dependency only) by introducing a new place where jsonargparse can be configured: logging.run_cli, which is also slightly more convenient --- examples/atari/atari_dqn_hl.py | 4 +--- examples/atari/atari_iqn_hl.py | 4 +--- examples/atari/atari_ppo_hl.py | 4 +--- examples/atari/atari_sac_hl.py | 4 +--- examples/mujoco/mujoco_a2c_hl.py | 3 +-- examples/mujoco/mujoco_ddpg_hl.py | 4 +--- examples/mujoco/mujoco_npg_hl.py | 3 +-- examples/mujoco/mujoco_ppo_hl.py | 3 +-- examples/mujoco/mujoco_redq_hl.py | 4 +--- examples/mujoco/mujoco_reinforce_hl.py | 3 +-- examples/mujoco/mujoco_sac_hl.py | 4 +--- examples/mujoco/mujoco_td3_hl.py | 3 +-- examples/mujoco/mujoco_trpo_hl.py | 3 +-- tianshou/highlevel/__init__.py | 2 -- tianshou/utils/logging.py | 20 ++++++++++++++++++++ 15 files changed, 33 insertions(+), 35 deletions(-) diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index c38d30a..be981b5 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -2,8 +2,6 @@ import os -from jsonargparse import CLI - from examples.atari.atari_callbacks import ( TestEpochCallbackDQNSetEps, TrainEpochCallbackNatureDQNEpsLinearDecay, @@ -104,4 +102,4 @@ def main( if __name__ == "__main__": - logging.run_main(lambda: CLI(main)) + logging.run_cli(main) diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index b0119b0..2cd709c 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -3,8 +3,6 @@ import os from collections.abc import Sequence -from jsonargparse import CLI - from examples.atari.atari_callbacks import ( TestEpochCallbackDQNSetEps, TrainEpochCallbackNatureDQNEpsLinearDecay, @@ -96,4 +94,4 @@ def main( if __name__ == "__main__": - logging.run_main(lambda: CLI(main)) + logging.run_cli(main) diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index b039616..1c1f1ad 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -4,8 +4,6 @@ import datetime import os from collections.abc import Sequence -from jsonargparse import CLI - from examples.atari.atari_network import ( ActorFactoryAtariDQN, IntermediateModuleFactoryAtariDQNFeatures, @@ -116,4 +114,4 @@ def main( if __name__ == "__main__": - logging.run_main(lambda: CLI(main)) + logging.run_cli(main) diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index d0fd067..c602550 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -2,8 +2,6 @@ import os -from jsonargparse import CLI - from examples.atari.atari_network import ( ActorFactoryAtariDQN, IntermediateModuleFactoryAtariDQNFeatures, @@ -102,4 +100,4 @@ def main( if __name__ == "__main__": - logging.run_main(lambda: CLI(main)) + logging.run_cli(main) diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index 6982b94..1825b6c 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -4,7 +4,6 @@ import os from collections.abc import Sequence from typing import Literal -from jsonargparse import CLI from torch import nn from examples.mujoco.mujoco_env import MujocoEnvFactory @@ -83,4 +82,4 @@ def main( if __name__ == "__main__": - logging.run_main(lambda: CLI(main)) + logging.run_cli(main) diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index e7c519e..3e3fc7a 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -3,8 +3,6 @@ import os from collections.abc import Sequence -from jsonargparse import CLI - from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( @@ -75,4 +73,4 @@ def main( if __name__ == "__main__": - logging.run_main(lambda: CLI(main)) + logging.run_cli(main) diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index b6d2637..8d496f2 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -5,7 +5,6 @@ from collections.abc import Sequence from typing import Literal import torch -from jsonargparse import CLI from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import SamplingConfig @@ -85,4 +84,4 @@ def main( if __name__ == "__main__": - logging.run_main(lambda: CLI(main)) + logging.run_cli(main) diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index ba81806..f4110ae 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -5,7 +5,6 @@ from collections.abc import Sequence from typing import Literal import torch -from jsonargparse import CLI from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import SamplingConfig @@ -95,4 +94,4 @@ def main( if __name__ == "__main__": - logging.run_main(lambda: CLI(main)) + logging.run_cli(main) diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index 2136f33..66219bf 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -4,8 +4,6 @@ import os from collections.abc import Sequence from typing import Literal -from jsonargparse import CLI - from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( @@ -84,4 +82,4 @@ def main( if __name__ == "__main__": - logging.run_main(lambda: CLI(main)) + logging.run_cli(main) diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index 5097d98..1a08449 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -5,7 +5,6 @@ from collections.abc import Sequence from typing import Literal import torch -from jsonargparse import CLI from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import SamplingConfig @@ -72,4 +71,4 @@ def main( if __name__ == "__main__": - logging.run_main(lambda: CLI(main)) + logging.run_cli(main) diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index 2edca7b..e94b8ab 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -4,8 +4,6 @@ import datetime import os from collections.abc import Sequence -from jsonargparse import CLI - from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( @@ -82,4 +80,4 @@ def main( if __name__ == "__main__": - logging.run_main(lambda: CLI(main)) + logging.run_cli(main) diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index e4021d6..1f783f1 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -4,7 +4,6 @@ import os from collections.abc import Sequence import torch -from jsonargparse import CLI from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import SamplingConfig @@ -85,4 +84,4 @@ def main( if __name__ == "__main__": - logging.run_main(lambda: CLI(main)) + logging.run_cli(main) diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 34ed14b..5901cc5 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -5,7 +5,6 @@ from collections.abc import Sequence from typing import Literal import torch -from jsonargparse import CLI from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import SamplingConfig @@ -89,4 +88,4 @@ def main( if __name__ == "__main__": - logging.run_main(lambda: CLI(main)) + logging.run_cli(main) diff --git a/tianshou/highlevel/__init__.py b/tianshou/highlevel/__init__.py index 8ce1671..8b13789 100644 --- a/tianshou/highlevel/__init__.py +++ b/tianshou/highlevel/__init__.py @@ -1,3 +1 @@ -from jsonargparse import set_docstring_parse_options -set_docstring_parse_options(attribute_docstrings=True) diff --git a/tianshou/utils/logging.py b/tianshou/utils/logging.py index fbf8817..caf0ae9 100644 --- a/tianshou/utils/logging.py +++ b/tianshou/utils/logging.py @@ -74,6 +74,26 @@ def run_main(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEB log.error("Exception during script execution", exc_info=e) +def run_cli(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEBUG): + """ + Configures logging with the given parameters and runs the given main function as a + CLI using `jsonargparse` (which is configured to also parse attribute docstrings, such + that dataclasses can be used as function arguments). + Using this function requires that `jsonargparse` and `docstring_parser` be available. + Like `run_main`, two additional log messages will be logged (at the beginning and end + of the execution), and it is ensured that all exceptions will be logged. + + :param main_fn: the function to be executed + :param format: the log message format + :param level: the minimum log level + :return: the result of `main_fn` + """ + from jsonargparse import set_docstring_parse_options, CLI + + set_docstring_parse_options(attribute_docstrings=True) + return run_main(lambda: CLI(main_fn), format=format, level=level) + + def datetime_tag() -> str: """:return: a string tag for use in log file names which contains the current date and time (compact but readable)""" return datetime.now().strftime("%Y%m%d-%H%M%S")