Fix tianshou.highlevel depending on jsonargparse

(should be dev dependency only) by introducing a new
place where jsonargparse can be configured:
logging.run_cli, which is also slightly more convenient
This commit is contained in:
Dominik Jain 2023-10-19 11:40:49 +02:00
parent 6cbee188b8
commit 7437131d79
15 changed files with 33 additions and 35 deletions

View File

@ -2,8 +2,6 @@
import os import os
from jsonargparse import CLI
from examples.atari.atari_callbacks import ( from examples.atari.atari_callbacks import (
TestEpochCallbackDQNSetEps, TestEpochCallbackDQNSetEps,
TrainEpochCallbackNatureDQNEpsLinearDecay, TrainEpochCallbackNatureDQNEpsLinearDecay,
@ -104,4 +102,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
logging.run_main(lambda: CLI(main)) logging.run_cli(main)

View File

@ -3,8 +3,6 @@
import os import os
from collections.abc import Sequence from collections.abc import Sequence
from jsonargparse import CLI
from examples.atari.atari_callbacks import ( from examples.atari.atari_callbacks import (
TestEpochCallbackDQNSetEps, TestEpochCallbackDQNSetEps,
TrainEpochCallbackNatureDQNEpsLinearDecay, TrainEpochCallbackNatureDQNEpsLinearDecay,
@ -96,4 +94,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
logging.run_main(lambda: CLI(main)) logging.run_cli(main)

View File

@ -4,8 +4,6 @@ import datetime
import os import os
from collections.abc import Sequence from collections.abc import Sequence
from jsonargparse import CLI
from examples.atari.atari_network import ( from examples.atari.atari_network import (
ActorFactoryAtariDQN, ActorFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures, IntermediateModuleFactoryAtariDQNFeatures,
@ -116,4 +114,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
logging.run_main(lambda: CLI(main)) logging.run_cli(main)

View File

@ -2,8 +2,6 @@
import os import os
from jsonargparse import CLI
from examples.atari.atari_network import ( from examples.atari.atari_network import (
ActorFactoryAtariDQN, ActorFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures, IntermediateModuleFactoryAtariDQNFeatures,
@ -102,4 +100,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
logging.run_main(lambda: CLI(main)) logging.run_cli(main)

View File

@ -4,7 +4,6 @@ import os
from collections.abc import Sequence from collections.abc import Sequence
from typing import Literal from typing import Literal
from jsonargparse import CLI
from torch import nn from torch import nn
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
@ -83,4 +82,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
logging.run_main(lambda: CLI(main)) logging.run_cli(main)

View File

@ -3,8 +3,6 @@
import os import os
from collections.abc import Sequence from collections.abc import Sequence
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ( from tianshou.highlevel.experiment import (
@ -75,4 +73,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
logging.run_main(lambda: CLI(main)) logging.run_cli(main)

View File

@ -5,7 +5,6 @@ from collections.abc import Sequence
from typing import Literal from typing import Literal
import torch import torch
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
@ -85,4 +84,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
logging.run_main(lambda: CLI(main)) logging.run_cli(main)

View File

@ -5,7 +5,6 @@ from collections.abc import Sequence
from typing import Literal from typing import Literal
import torch import torch
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
@ -95,4 +94,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
logging.run_main(lambda: CLI(main)) logging.run_cli(main)

View File

@ -4,8 +4,6 @@ import os
from collections.abc import Sequence from collections.abc import Sequence
from typing import Literal from typing import Literal
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ( from tianshou.highlevel.experiment import (
@ -84,4 +82,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
logging.run_main(lambda: CLI(main)) logging.run_cli(main)

View File

@ -5,7 +5,6 @@ from collections.abc import Sequence
from typing import Literal from typing import Literal
import torch import torch
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
@ -72,4 +71,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
logging.run_main(lambda: CLI(main)) logging.run_cli(main)

View File

@ -4,8 +4,6 @@ import datetime
import os import os
from collections.abc import Sequence from collections.abc import Sequence
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ( from tianshou.highlevel.experiment import (
@ -82,4 +80,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
logging.run_main(lambda: CLI(main)) logging.run_cli(main)

View File

@ -4,7 +4,6 @@ import os
from collections.abc import Sequence from collections.abc import Sequence
import torch import torch
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
@ -85,4 +84,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
logging.run_main(lambda: CLI(main)) logging.run_cli(main)

View File

@ -5,7 +5,6 @@ from collections.abc import Sequence
from typing import Literal from typing import Literal
import torch import torch
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
@ -89,4 +88,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
logging.run_main(lambda: CLI(main)) logging.run_cli(main)

View File

@ -1,3 +1 @@
from jsonargparse import set_docstring_parse_options
set_docstring_parse_options(attribute_docstrings=True)

View File

@ -74,6 +74,26 @@ def run_main(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEB
log.error("Exception during script execution", exc_info=e) log.error("Exception during script execution", exc_info=e)
def run_cli(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEBUG):
"""
Configures logging with the given parameters and runs the given main function as a
CLI using `jsonargparse` (which is configured to also parse attribute docstrings, such
that dataclasses can be used as function arguments).
Using this function requires that `jsonargparse` and `docstring_parser` be available.
Like `run_main`, two additional log messages will be logged (at the beginning and end
of the execution), and it is ensured that all exceptions will be logged.
:param main_fn: the function to be executed
:param format: the log message format
:param level: the minimum log level
:return: the result of `main_fn`
"""
from jsonargparse import set_docstring_parse_options, CLI
set_docstring_parse_options(attribute_docstrings=True)
return run_main(lambda: CLI(main_fn), format=format, level=level)
def datetime_tag() -> str: def datetime_tag() -> str:
""":return: a string tag for use in log file names which contains the current date and time (compact but readable)""" """:return: a string tag for use in log file names which contains the current date and time (compact but readable)"""
return datetime.now().strftime("%Y%m%d-%H%M%S") return datetime.now().strftime("%Y%m%d-%H%M%S")