From 49781e715eb4890aec572c0ce436f35ed086df97 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 23 Feb 2024 23:17:14 +0100 Subject: [PATCH] Fix high-level examples (#1060) The high-level examples were all broken by changes made to make mypy pass. This PR fixes them, making a type change in logging.run_cli instead to make mypy happy. --- examples/atari/atari_dqn_hl.py | 4 +--- examples/atari/atari_iqn_hl.py | 4 +--- examples/atari/atari_ppo_hl.py | 4 +--- examples/atari/atari_sac_hl.py | 4 +--- examples/mujoco/mujoco_a2c_hl.py | 4 +--- examples/mujoco/mujoco_ddpg_hl.py | 4 +--- examples/mujoco/mujoco_npg_hl.py | 4 +--- examples/mujoco/mujoco_ppo_hl.py | 4 +--- examples/mujoco/mujoco_redq_hl.py | 4 +--- examples/mujoco/mujoco_reinforce_hl.py | 4 +--- examples/mujoco/mujoco_sac_hl.py | 4 +--- examples/mujoco/mujoco_td3_hl.py | 4 +--- examples/mujoco/mujoco_trpo_hl.py | 4 +--- tianshou/utils/logging.py | 2 +- 14 files changed, 14 insertions(+), 40 deletions(-) diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index e7a1003..887ebc8 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import functools import os from examples.atari.atari_network import ( @@ -103,5 +102,4 @@ def main( if __name__ == "__main__": - run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) - logging.run_cli(run_with_default_config) + logging.run_cli(main) diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index 32492a3..dcdacf2 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import functools import os from collections.abc import Sequence @@ -95,5 +94,4 @@ def main( if __name__ == "__main__": - run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) - logging.run_cli(run_with_default_config) + logging.run_cli(main) diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 736fb1d..b492b9c 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import functools import os from collections.abc import Sequence @@ -114,5 +113,4 @@ def main( if __name__ == "__main__": - run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) - logging.run_cli(run_with_default_config) + logging.run_cli(main) diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index f2e0277..1271567 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import functools import os from examples.atari.atari_network import ( @@ -101,5 +100,4 @@ def main( if __name__ == "__main__": - run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) - logging.run_cli(run_with_default_config) + logging.run_cli(main) diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index 5b29a75..fec2e26 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import functools import os from collections.abc import Sequence from typing import Literal @@ -83,5 +82,4 @@ def main( if __name__ == "__main__": - run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) - logging.run_cli(run_with_default_config) + logging.run_cli(main) diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index 0026acf..2bbc669 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import functools import os from collections.abc import Sequence @@ -74,5 +73,4 @@ def main( if __name__ == "__main__": - run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) - logging.run_cli(run_with_default_config) + logging.run_cli(main) diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index 85451be..6ab0eb8 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import functools import os from collections.abc import Sequence from typing import Literal @@ -85,5 +84,4 @@ def main( if __name__ == "__main__": - run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) - logging.run_cli(run_with_default_config) + logging.run_cli(main) diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index 7ff4801..dbc6fb5 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import functools import os from collections.abc import Sequence from typing import Literal @@ -95,5 +94,4 @@ def main( if __name__ == "__main__": - run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) - logging.run_cli(run_with_default_config) + logging.run_cli(main) diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index 8b8234b..78e0d34 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import functools import os from collections.abc import Sequence from typing import Literal @@ -83,5 +82,4 @@ def main( if __name__ == "__main__": - run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) - logging.run_cli(run_with_default_config) + logging.run_cli(main) diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index 1497615..bc07e05 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import functools import os from collections.abc import Sequence from typing import Literal @@ -72,5 +71,4 @@ def main( if __name__ == "__main__": - run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) - logging.run_cli(run_with_default_config) + logging.run_cli(main) diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index cc22a30..c6a6a3b 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import functools import os from collections.abc import Sequence @@ -80,5 +79,4 @@ def main( if __name__ == "__main__": - run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) - logging.run_cli(run_with_default_config) + logging.run_cli(main) diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index e6ab40d..73d20fe 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import functools import os from collections.abc import Sequence @@ -85,5 +84,4 @@ def main( if __name__ == "__main__": - run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) - logging.run_cli(run_with_default_config) + logging.run_cli(main) diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 528e974..2f9a777 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import functools import os from collections.abc import Sequence from typing import Literal @@ -89,5 +88,4 @@ def main( if __name__ == "__main__": - run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) - logging.run_cli(run_with_default_config) + logging.run_cli(main) diff --git a/tianshou/utils/logging.py b/tianshou/utils/logging.py index dcda429..b2eaf3f 100644 --- a/tianshou/utils/logging.py +++ b/tianshou/utils/logging.py @@ -100,7 +100,7 @@ def run_main( def run_cli( - main_fn: Callable[[], T], format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG + main_fn: Callable[..., T], format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG ) -> T | None: """ Configures logging with the given parameters and runs the given main function as a