From 8ec42009cb3d2a4ad6a89adfcf03f1a37cb59ffe Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 20 Sep 2023 15:28:33 +0200 Subject: [PATCH] Move RLSamplingConfig to separate module config, fixing cyclic import --- examples/mujoco/mujoco_env.py | 2 +- examples/mujoco/mujoco_ppo_hl.py | 2 +- examples/mujoco/mujoco_sac_hl.py | 2 +- tianshou/highlevel/agent.py | 2 +- tianshou/highlevel/config.py | 18 ++++++++++++++++++ tianshou/highlevel/experiment.py | 17 ----------------- tianshou/highlevel/optim.py | 2 +- 7 files changed, 23 insertions(+), 22 deletions(-) create mode 100644 tianshou/highlevel/config.py diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index b4e4d44..886cd78 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -4,7 +4,7 @@ import gymnasium as gym from tianshou.env import ShmemVectorEnv, VectorEnvNormObs from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory -from tianshou.highlevel.experiment import RLSamplingConfig +from tianshou.highlevel.config import RLSamplingConfig try: import envpool diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index 71a0b6d..3cedac9 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -13,8 +13,8 @@ from tianshou.highlevel.agent import PGConfig, PPOAgentFactory, PPOConfig, RLAge from tianshou.highlevel.experiment import ( RLExperiment, RLExperimentConfig, - RLSamplingConfig, ) +from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.logger import DefaultLoggerFactory from tianshou.highlevel.module import ( ContinuousActorProbFactory, diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index 9be572d..37f1d6e 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -11,8 +11,8 @@ from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACAgentFactory, S from tianshou.highlevel.experiment import ( RLExperiment, RLExperimentConfig, - RLSamplingConfig, ) +from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.logger import DefaultLoggerFactory from tianshou.highlevel.module import ( ContinuousActorProbFactory, diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index ce7545f..9cd6140 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -10,7 +10,7 @@ import torch from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.exploration import BaseNoise from tianshou.highlevel.env import Environments -from tianshou.highlevel.experiment import RLSamplingConfig +from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.logger import Logger from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py new file mode 100644 index 0000000..b188f36 --- /dev/null +++ b/tianshou/highlevel/config.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass + + +@dataclass +class RLSamplingConfig: + """Sampling, epochs, parallelization, buffers, collectors, and batching.""" + + num_epochs: int = 100 + step_per_epoch: int = 30000 + batch_size: int = 64 + num_train_envs: int = 64 + num_test_envs: int = 10 + buffer_size: int = 4096 + step_per_collect: int = 2048 + repeat_per_collect: int = 10 + update_per_step: int = 1 + start_timesteps: int = 0 + start_timesteps_random: bool = False diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 76c957b..1012028 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -33,23 +33,6 @@ class RLExperimentConfig: watch_num_episodes = 10 -@dataclass -class RLSamplingConfig: - """Sampling, epochs, parallelization, buffers, collectors, and batching.""" - - num_epochs: int = 100 - step_per_epoch: int = 30000 - batch_size: int = 64 - num_train_envs: int = 64 - num_test_envs: int = 10 - buffer_size: int = 4096 - step_per_collect: int = 2048 - repeat_per_collect: int = 10 - update_per_step: int = 1 - start_timesteps: int = 0 - start_timesteps_random: bool = False - - class RLExperiment(Generic[TPolicy, TTrainer]): def __init__( self, diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py index 4e68104..edbe0a4 100644 --- a/tianshou/highlevel/optim.py +++ b/tianshou/highlevel/optim.py @@ -8,7 +8,7 @@ from torch import Tensor from torch.optim import Adam from torch.optim.lr_scheduler import LambdaLR, LRScheduler -from tianshou.highlevel.experiment import RLSamplingConfig +from tianshou.highlevel.config import RLSamplingConfig TParams = Iterable[Tensor] | Iterable[dict[str, Any]]