Move RLSamplingConfig to separate module config, fixing cyclic import

parent d26b8cb40c
commit 8ec42009cb
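Background, for context: before this commit RLSamplingConfig was defined in tianshou.highlevel.experiment, so every module that needed the sampling settings had to import experiment, presumably while experiment depended on some of those modules in turn, the classic shape of the cyclic import the title mentions. Moving the dataclass into a leaf module that imports nothing from the package breaks the cycle. A minimal sketch of the pattern, using a hypothetical package pkg rather than tianshou's actual module graph:

# Hypothetical layout illustrating the fix; names are not tianshou's.

# pkg/config.py -- leaf module, depends only on the stdlib
from dataclasses import dataclass

@dataclass
class SamplingConfig:
    num_epochs: int = 100

# pkg/optim.py -- used to import pkg.experiment just for the config while
# pkg.experiment imported pkg.optim, which fails at import time with
# "cannot import name ... from partially initialized module".
# from pkg.config import SamplingConfig   # now a one-way edge to a leaf

# pkg/experiment.py
# from pkg.optim import OptimizerFactory  # unchanged; nothing points back here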
@@ -4,7 +4,7 @@ import gymnasium as gym
 
 from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
 from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
-from tianshou.highlevel.experiment import RLSamplingConfig
+from tianshou.highlevel.config import RLSamplingConfig
 
 try:
     import envpool

@@ -13,8 +13,8 @@ from tianshou.highlevel.agent import PGConfig, PPOAgentFactory, PPOConfig, RLAge
 from tianshou.highlevel.experiment import (
     RLExperiment,
     RLExperimentConfig,
-    RLSamplingConfig,
 )
+from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.logger import DefaultLoggerFactory
 from tianshou.highlevel.module import (
     ContinuousActorProbFactory,

@@ -11,8 +11,8 @@ from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACAgentFactory, S
 from tianshou.highlevel.experiment import (
     RLExperiment,
     RLExperimentConfig,
-    RLSamplingConfig,
 )
+from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.logger import DefaultLoggerFactory
 from tianshou.highlevel.module import (
     ContinuousActorProbFactory,

@@ -10,7 +10,7 @@ import torch
 from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
 from tianshou.exploration import BaseNoise
 from tianshou.highlevel.env import Environments
-from tianshou.highlevel.experiment import RLSamplingConfig
+from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.logger import Logger
 from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
 from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory

tianshou/highlevel/config.py (new file, 18 lines)
@@ -0,0 +1,18 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class RLSamplingConfig:
+    """Sampling, epochs, parallelization, buffers, collectors, and batching."""
+
+    num_epochs: int = 100
+    step_per_epoch: int = 30000
+    batch_size: int = 64
+    num_train_envs: int = 64
+    num_test_envs: int = 10
+    buffer_size: int = 4096
+    step_per_collect: int = 2048
+    repeat_per_collect: int = 10
+    update_per_step: int = 1
+    start_timesteps: int = 0
+    start_timesteps_random: bool = False
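Since the dataclass keeps its fields and defaults, call sites only change their import path. A short usage sketch; the import path, field names, and defaults come from the diff, while the particular overrides are illustrative:

from tianshou.highlevel.config import RLSamplingConfig

# Defaults come from the dataclass; override only what the experiment needs.
sampling_config = RLSamplingConfig(
    num_epochs=200,
    batch_size=256,
    num_train_envs=10,
)
print(sampling_config.step_per_epoch)  # 30000, the unchanged default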
@@ -33,23 +33,6 @@ class RLExperimentConfig:
     watch_num_episodes = 10
 
 
-@dataclass
-class RLSamplingConfig:
-    """Sampling, epochs, parallelization, buffers, collectors, and batching."""
-
-    num_epochs: int = 100
-    step_per_epoch: int = 30000
-    batch_size: int = 64
-    num_train_envs: int = 64
-    num_test_envs: int = 10
-    buffer_size: int = 4096
-    step_per_collect: int = 2048
-    repeat_per_collect: int = 10
-    update_per_step: int = 1
-    start_timesteps: int = 0
-    start_timesteps_random: bool = False
-
-
 class RLExperiment(Generic[TPolicy, TTrainer]):
     def __init__(
         self,

@@ -8,7 +8,7 @@ from torch import Tensor
 from torch.optim import Adam
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
 
-from tianshou.highlevel.experiment import RLSamplingConfig
+from tianshou.highlevel.config import RLSamplingConfig
 
 TParams = Iterable[Tensor] | Iterable[dict[str, Any]]
 
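Finally, a quick smoke test for the fix, a sketch assuming tianshou is installed at this commit: with the config in its own leaf module, the previously entangled high-level modules should import cleanly in any order.

# A residual cycle would surface here as ImportError
# ("cannot import name ... from partially initialized module").
import tianshou.highlevel.experiment
import tianshou.highlevel.optim

from tianshou.highlevel.config import RLSamplingConfig
print(RLSamplingConfig())  # dataclass repr showing the defaults listed above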