Fix tianshou.highlevel depending on jsonargparse

(should be dev dependency only) by introducing a new
place where jsonargparse can be configured:
logging.run_cli, which is also slightly more convenient
This commit is contained in:
Dominik Jain 2023-10-19 11:40:49 +02:00
parent 6cbee188b8
commit 7437131d79
15 changed files with 33 additions and 35 deletions

View File

@ -2,8 +2,6 @@
import os
from jsonargparse import CLI
from examples.atari.atari_callbacks import (
TestEpochCallbackDQNSetEps,
TrainEpochCallbackNatureDQNEpsLinearDecay,
@ -104,4 +102,4 @@ def main(
if __name__ == "__main__":
logging.run_main(lambda: CLI(main))
logging.run_cli(main)

View File

@ -3,8 +3,6 @@
import os
from collections.abc import Sequence
from jsonargparse import CLI
from examples.atari.atari_callbacks import (
TestEpochCallbackDQNSetEps,
TrainEpochCallbackNatureDQNEpsLinearDecay,
@ -96,4 +94,4 @@ def main(
if __name__ == "__main__":
logging.run_main(lambda: CLI(main))
logging.run_cli(main)

View File

@ -4,8 +4,6 @@ import datetime
import os
from collections.abc import Sequence
from jsonargparse import CLI
from examples.atari.atari_network import (
ActorFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures,
@ -116,4 +114,4 @@ def main(
if __name__ == "__main__":
logging.run_main(lambda: CLI(main))
logging.run_cli(main)

View File

@ -2,8 +2,6 @@
import os
from jsonargparse import CLI
from examples.atari.atari_network import (
ActorFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures,
@ -102,4 +100,4 @@ def main(
if __name__ == "__main__":
logging.run_main(lambda: CLI(main))
logging.run_cli(main)

View File

@ -4,7 +4,6 @@ import os
from collections.abc import Sequence
from typing import Literal
from jsonargparse import CLI
from torch import nn
from examples.mujoco.mujoco_env import MujocoEnvFactory
@ -83,4 +82,4 @@ def main(
if __name__ == "__main__":
logging.run_main(lambda: CLI(main))
logging.run_cli(main)

View File

@ -3,8 +3,6 @@
import os
from collections.abc import Sequence
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
@ -75,4 +73,4 @@ def main(
if __name__ == "__main__":
logging.run_main(lambda: CLI(main))
logging.run_cli(main)

View File

@ -5,7 +5,6 @@ from collections.abc import Sequence
from typing import Literal
import torch
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
@ -85,4 +84,4 @@ def main(
if __name__ == "__main__":
logging.run_main(lambda: CLI(main))
logging.run_cli(main)

View File

@ -5,7 +5,6 @@ from collections.abc import Sequence
from typing import Literal
import torch
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
@ -95,4 +94,4 @@ def main(
if __name__ == "__main__":
logging.run_main(lambda: CLI(main))
logging.run_cli(main)

View File

@ -4,8 +4,6 @@ import os
from collections.abc import Sequence
from typing import Literal
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
@ -84,4 +82,4 @@ def main(
if __name__ == "__main__":
logging.run_main(lambda: CLI(main))
logging.run_cli(main)

View File

@ -5,7 +5,6 @@ from collections.abc import Sequence
from typing import Literal
import torch
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
@ -72,4 +71,4 @@ def main(
if __name__ == "__main__":
logging.run_main(lambda: CLI(main))
logging.run_cli(main)

View File

@ -4,8 +4,6 @@ import datetime
import os
from collections.abc import Sequence
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
@ -82,4 +80,4 @@ def main(
if __name__ == "__main__":
logging.run_main(lambda: CLI(main))
logging.run_cli(main)

View File

@ -4,7 +4,6 @@ import os
from collections.abc import Sequence
import torch
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
@ -85,4 +84,4 @@ def main(
if __name__ == "__main__":
logging.run_main(lambda: CLI(main))
logging.run_cli(main)

View File

@ -5,7 +5,6 @@ from collections.abc import Sequence
from typing import Literal
import torch
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
@ -89,4 +88,4 @@ def main(
if __name__ == "__main__":
logging.run_main(lambda: CLI(main))
logging.run_cli(main)

View File

@ -1,3 +1 @@
from jsonargparse import set_docstring_parse_options
set_docstring_parse_options(attribute_docstrings=True)

View File

@ -74,6 +74,26 @@ def run_main(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEB
log.error("Exception during script execution", exc_info=e)
def run_cli(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEBUG):
"""
Configures logging with the given parameters and runs the given main function as a
CLI using `jsonargparse` (which is configured to also parse attribute docstrings, such
that dataclasses can be used as function arguments).
Using this function requires that `jsonargparse` and `docstring_parser` be available.
Like `run_main`, two additional log messages will be logged (at the beginning and end
of the execution), and it is ensured that all exceptions will be logged.
:param main_fn: the function to be executed
:param format: the log message format
:param level: the minimum log level
:return: the result of `main_fn`
"""
from jsonargparse import set_docstring_parse_options, CLI
set_docstring_parse_options(attribute_docstrings=True)
return run_main(lambda: CLI(main_fn), format=format, level=level)
def datetime_tag() -> str:
""":return: a string tag for use in log file names which contains the current date and time (compact but readable)"""
return datetime.now().strftime("%Y%m%d-%H%M%S")