Move RLSamplingConfig to separate module config, fixing cyclic import
This commit is contained in:
parent
d26b8cb40c
commit
8ec42009cb
@ -4,7 +4,7 @@ import gymnasium as gym
|
||||
|
||||
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
||||
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
|
||||
from tianshou.highlevel.experiment import RLSamplingConfig
|
||||
from tianshou.highlevel.config import RLSamplingConfig
|
||||
|
||||
try:
|
||||
import envpool
|
||||
|
@ -13,8 +13,8 @@ from tianshou.highlevel.agent import PGConfig, PPOAgentFactory, PPOConfig, RLAge
|
||||
from tianshou.highlevel.experiment import (
|
||||
RLExperiment,
|
||||
RLExperimentConfig,
|
||||
RLSamplingConfig,
|
||||
)
|
||||
from tianshou.highlevel.config import RLSamplingConfig
|
||||
from tianshou.highlevel.logger import DefaultLoggerFactory
|
||||
from tianshou.highlevel.module import (
|
||||
ContinuousActorProbFactory,
|
||||
|
@ -11,8 +11,8 @@ from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACAgentFactory, S
|
||||
from tianshou.highlevel.experiment import (
|
||||
RLExperiment,
|
||||
RLExperimentConfig,
|
||||
RLSamplingConfig,
|
||||
)
|
||||
from tianshou.highlevel.config import RLSamplingConfig
|
||||
from tianshou.highlevel.logger import DefaultLoggerFactory
|
||||
from tianshou.highlevel.module import (
|
||||
ContinuousActorProbFactory,
|
||||
|
@ -10,7 +10,7 @@ import torch
|
||||
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
||||
from tianshou.exploration import BaseNoise
|
||||
from tianshou.highlevel.env import Environments
|
||||
from tianshou.highlevel.experiment import RLSamplingConfig
|
||||
from tianshou.highlevel.config import RLSamplingConfig
|
||||
from tianshou.highlevel.logger import Logger
|
||||
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
|
||||
from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory
|
||||
|
18
tianshou/highlevel/config.py
Normal file
18
tianshou/highlevel/config.py
Normal file
@ -0,0 +1,18 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLSamplingConfig:
|
||||
"""Sampling, epochs, parallelization, buffers, collectors, and batching."""
|
||||
|
||||
num_epochs: int = 100
|
||||
step_per_epoch: int = 30000
|
||||
batch_size: int = 64
|
||||
num_train_envs: int = 64
|
||||
num_test_envs: int = 10
|
||||
buffer_size: int = 4096
|
||||
step_per_collect: int = 2048
|
||||
repeat_per_collect: int = 10
|
||||
update_per_step: int = 1
|
||||
start_timesteps: int = 0
|
||||
start_timesteps_random: bool = False
|
@ -33,23 +33,6 @@ class RLExperimentConfig:
|
||||
watch_num_episodes = 10
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLSamplingConfig:
|
||||
"""Sampling, epochs, parallelization, buffers, collectors, and batching."""
|
||||
|
||||
num_epochs: int = 100
|
||||
step_per_epoch: int = 30000
|
||||
batch_size: int = 64
|
||||
num_train_envs: int = 64
|
||||
num_test_envs: int = 10
|
||||
buffer_size: int = 4096
|
||||
step_per_collect: int = 2048
|
||||
repeat_per_collect: int = 10
|
||||
update_per_step: int = 1
|
||||
start_timesteps: int = 0
|
||||
start_timesteps_random: bool = False
|
||||
|
||||
|
||||
class RLExperiment(Generic[TPolicy, TTrainer]):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -8,7 +8,7 @@ from torch import Tensor
|
||||
from torch.optim import Adam
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
from tianshou.highlevel.experiment import RLSamplingConfig
|
||||
from tianshou.highlevel.config import RLSamplingConfig
|
||||
|
||||
TParams = Iterable[Tensor] | Iterable[dict[str, Any]]
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user