Move RLSamplingConfig to separate module config, fixing cyclic import

This commit is contained in:
Dominik Jain 2023-09-20 15:28:33 +02:00
parent d26b8cb40c
commit 8ec42009cb
7 changed files with 23 additions and 22 deletions

View File

@ -4,7 +4,7 @@ import gymnasium as gym
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
from tianshou.highlevel.experiment import RLSamplingConfig
from tianshou.highlevel.config import RLSamplingConfig
try:
import envpool

View File

@ -13,8 +13,8 @@ from tianshou.highlevel.agent import PGConfig, PPOAgentFactory, PPOConfig, RLAge
from tianshou.highlevel.experiment import (
RLExperiment,
RLExperimentConfig,
RLSamplingConfig,
)
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.logger import DefaultLoggerFactory
from tianshou.highlevel.module import (
ContinuousActorProbFactory,

View File

@ -11,8 +11,8 @@ from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACAgentFactory, S
from tianshou.highlevel.experiment import (
RLExperiment,
RLExperimentConfig,
RLSamplingConfig,
)
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.logger import DefaultLoggerFactory
from tianshou.highlevel.module import (
ContinuousActorProbFactory,

View File

@ -10,7 +10,7 @@ import torch
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.exploration import BaseNoise
from tianshou.highlevel.env import Environments
from tianshou.highlevel.experiment import RLSamplingConfig
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory

View File

@ -0,0 +1,18 @@
from dataclasses import dataclass
@dataclass
class RLSamplingConfig:
"""Sampling, epochs, parallelization, buffers, collectors, and batching."""
num_epochs: int = 100
step_per_epoch: int = 30000
batch_size: int = 64
num_train_envs: int = 64
num_test_envs: int = 10
buffer_size: int = 4096
step_per_collect: int = 2048
repeat_per_collect: int = 10
update_per_step: int = 1
start_timesteps: int = 0
start_timesteps_random: bool = False

View File

@ -33,23 +33,6 @@ class RLExperimentConfig:
watch_num_episodes = 10
@dataclass
class RLSamplingConfig:
"""Sampling, epochs, parallelization, buffers, collectors, and batching."""
num_epochs: int = 100
step_per_epoch: int = 30000
batch_size: int = 64
num_train_envs: int = 64
num_test_envs: int = 10
buffer_size: int = 4096
step_per_collect: int = 2048
repeat_per_collect: int = 10
update_per_step: int = 1
start_timesteps: int = 0
start_timesteps_random: bool = False
class RLExperiment(Generic[TPolicy, TTrainer]):
def __init__(
self,

View File

@ -8,7 +8,7 @@ from torch import Tensor
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from tianshou.highlevel.experiment import RLSamplingConfig
from tianshou.highlevel.config import RLSamplingConfig
TParams = Iterable[Tensor] | Iterable[dict[str, Any]]