Remove 'RL' prefix from class names

This commit is contained in:
Dominik Jain 2023-10-06 13:50:23 +02:00
parent 50ac385321
commit d269063e6a
15 changed files with 79 additions and 79 deletions

View File

@ -10,10 +10,10 @@ from examples.atari.atari_network import (
FeatureNetFactoryDQN, FeatureNetFactoryDQN,
) )
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ( from tianshou.highlevel.experiment import (
DQNExperimentBuilder, DQNExperimentBuilder,
RLExperimentConfig, ExperimentConfig,
) )
from tianshou.highlevel.params.policy_params import DQNParams from tianshou.highlevel.params.policy_params import DQNParams
from tianshou.highlevel.params.policy_wrapper import ( from tianshou.highlevel.params.policy_wrapper import (
@ -25,7 +25,7 @@ from tianshou.utils import logging
def main( def main(
experiment_config: RLExperimentConfig, experiment_config: ExperimentConfig,
task: str = "PongNoFrameskip-v4", task: str = "PongNoFrameskip-v4",
scale_obs: int = 0, scale_obs: int = 0,
eps_test: float = 0.005, eps_test: float = 0.005,
@ -52,7 +52,7 @@ def main(
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now) log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
sampling_config = RLSamplingConfig( sampling_config = SamplingConfig(
num_epochs=epoch, num_epochs=epoch,
step_per_epoch=step_per_epoch, step_per_epoch=step_per_epoch,
batch_size=batch_size, batch_size=batch_size,

View File

@ -11,10 +11,10 @@ from examples.atari.atari_network import (
FeatureNetFactoryDQN, FeatureNetFactoryDQN,
) )
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ( from tianshou.highlevel.experiment import (
PPOExperimentBuilder, PPOExperimentBuilder,
RLExperimentConfig, ExperimentConfig,
) )
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams from tianshou.highlevel.params.policy_params import PPOParams
@ -25,7 +25,7 @@ from tianshou.utils import logging
def main( def main(
experiment_config: RLExperimentConfig, experiment_config: ExperimentConfig,
task: str = "PongNoFrameskip-v4", task: str = "PongNoFrameskip-v4",
scale_obs: bool = True, scale_obs: bool = True,
buffer_size: int = 100000, buffer_size: int = 100000,
@ -59,7 +59,7 @@ def main(
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now) log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
sampling_config = RLSamplingConfig( sampling_config = SamplingConfig(
num_epochs=epoch, num_epochs=epoch,
step_per_epoch=step_per_epoch, step_per_epoch=step_per_epoch,
batch_size=batch_size, batch_size=batch_size,

View File

@ -9,7 +9,7 @@ import gymnasium as gym
import numpy as np import numpy as np
from tianshou.env import ShmemVectorEnv from tianshou.env import ShmemVectorEnv
from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
@ -379,7 +379,7 @@ class AtariEnvFactory(EnvFactory):
self, self,
task: str, task: str,
seed: int, seed: int,
sampling_config: RLSamplingConfig, sampling_config: SamplingConfig,
frame_stack: int, frame_stack: int,
scale: int = 0, scale: int = 0,
): ):

View File

@ -8,10 +8,10 @@ from typing import Literal
from jsonargparse import CLI from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ( from tianshou.highlevel.experiment import (
A2CExperimentBuilder, A2CExperimentBuilder,
RLExperimentConfig, ExperimentConfig,
) )
from tianshou.highlevel.optim import OptimizerFactoryRMSprop from tianshou.highlevel.optim import OptimizerFactoryRMSprop
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
@ -20,7 +20,7 @@ from tianshou.utils import logging
def main( def main(
experiment_config: RLExperimentConfig, experiment_config: ExperimentConfig,
task: str = "Ant-v3", task: str = "Ant-v3",
buffer_size: int = 4096, buffer_size: int = 4096,
hidden_sizes: Sequence[int] = (64, 64), hidden_sizes: Sequence[int] = (64, 64),
@ -44,7 +44,7 @@ def main(
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now) log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
sampling_config = RLSamplingConfig( sampling_config = SamplingConfig(
num_epochs=epoch, num_epochs=epoch,
step_per_epoch=step_per_epoch, step_per_epoch=step_per_epoch,
batch_size=batch_size, batch_size=batch_size,

View File

@ -7,10 +7,10 @@ from collections.abc import Sequence
from jsonargparse import CLI from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ( from tianshou.highlevel.experiment import (
DDPGExperimentBuilder, DDPGExperimentBuilder,
RLExperimentConfig, ExperimentConfig,
) )
from tianshou.highlevel.params.noise import MaxActionScaledGaussian from tianshou.highlevel.params.noise import MaxActionScaledGaussian
from tianshou.highlevel.params.policy_params import DDPGParams from tianshou.highlevel.params.policy_params import DDPGParams
@ -18,7 +18,7 @@ from tianshou.utils import logging
def main( def main(
experiment_config: RLExperimentConfig, experiment_config: ExperimentConfig,
task: str = "Ant-v3", task: str = "Ant-v3",
buffer_size: int = 1000000, buffer_size: int = 1000000,
hidden_sizes: Sequence[int] = (256, 256), hidden_sizes: Sequence[int] = (256, 256),
@ -40,7 +40,7 @@ def main(
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now) log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
sampling_config = RLSamplingConfig( sampling_config = SamplingConfig(
num_epochs=epoch, num_epochs=epoch,
step_per_epoch=step_per_epoch, step_per_epoch=step_per_epoch,
batch_size=batch_size, batch_size=batch_size,

View File

@ -3,7 +3,7 @@ import warnings
import gymnasium as gym import gymnasium as gym
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
try: try:
@ -41,7 +41,7 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in
class MujocoEnvFactory(EnvFactory): class MujocoEnvFactory(EnvFactory):
def __init__(self, task: str, seed: int, sampling_config: RLSamplingConfig): def __init__(self, task: str, seed: int, sampling_config: SamplingConfig):
self.task = task self.task = task
self.sampling_config = sampling_config self.sampling_config = sampling_config
self.seed = seed self.seed = seed

View File

@ -9,10 +9,10 @@ from jsonargparse import CLI
from torch.distributions import Independent, Normal from torch.distributions import Independent, Normal
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ( from tianshou.highlevel.experiment import (
PPOExperimentBuilder, PPOExperimentBuilder,
RLExperimentConfig, ExperimentConfig,
) )
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams from tianshou.highlevel.params.policy_params import PPOParams
@ -20,7 +20,7 @@ from tianshou.utils import logging
def main( def main(
experiment_config: RLExperimentConfig, experiment_config: ExperimentConfig,
task: str = "Ant-v4", task: str = "Ant-v4",
buffer_size: int = 4096, buffer_size: int = 4096,
hidden_sizes: Sequence[int] = (64, 64), hidden_sizes: Sequence[int] = (64, 64),
@ -49,7 +49,7 @@ def main(
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now) log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
sampling_config = RLSamplingConfig( sampling_config = SamplingConfig(
num_epochs=epoch, num_epochs=epoch,
step_per_epoch=step_per_epoch, step_per_epoch=step_per_epoch,
batch_size=batch_size, batch_size=batch_size,

View File

@ -7,9 +7,9 @@ from collections.abc import Sequence
from jsonargparse import CLI from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ( from tianshou.highlevel.experiment import (
RLExperimentConfig, ExperimentConfig,
SACExperimentBuilder, SACExperimentBuilder,
) )
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
@ -18,7 +18,7 @@ from tianshou.utils import logging
def main( def main(
experiment_config: RLExperimentConfig, experiment_config: ExperimentConfig,
task: str = "Ant-v3", task: str = "Ant-v3",
buffer_size: int = 1000000, buffer_size: int = 1000000,
hidden_sizes: Sequence[int] = (256, 256), hidden_sizes: Sequence[int] = (256, 256),
@ -42,7 +42,7 @@ def main(
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "sac", str(experiment_config.seed), now) log_name = os.path.join(task, "sac", str(experiment_config.seed), now)
sampling_config = RLSamplingConfig( sampling_config = SamplingConfig(
num_epochs=epoch, num_epochs=epoch,
step_per_epoch=step_per_epoch, step_per_epoch=step_per_epoch,
num_train_envs=training_num, num_train_envs=training_num,

View File

@ -7,9 +7,9 @@ from collections.abc import Sequence
from jsonargparse import CLI from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ( from tianshou.highlevel.experiment import (
RLExperimentConfig, ExperimentConfig,
TD3ExperimentBuilder, TD3ExperimentBuilder,
) )
from tianshou.highlevel.params.env_param import MaxActionScaled from tianshou.highlevel.params.env_param import MaxActionScaled
@ -21,7 +21,7 @@ from tianshou.utils import logging
def main( def main(
experiment_config: RLExperimentConfig, experiment_config: ExperimentConfig,
task: str = "Ant-v3", task: str = "Ant-v3",
buffer_size: int = 1000000, buffer_size: int = 1000000,
hidden_sizes: Sequence[int] = (256, 256), hidden_sizes: Sequence[int] = (256, 256),
@ -46,7 +46,7 @@ def main(
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "td3", str(experiment_config.seed), now) log_name = os.path.join(task, "td3", str(experiment_config.seed), now)
sampling_config = RLSamplingConfig( sampling_config = SamplingConfig(
num_epochs=epoch, num_epochs=epoch,
step_per_epoch=step_per_epoch, step_per_epoch=step_per_epoch,
num_train_envs=training_num, num_train_envs=training_num,

View File

@ -2,12 +2,12 @@ from test.highlevel.env_factory import ContinuousTestEnvFactory
import pytest import pytest
from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ( from tianshou.highlevel.experiment import (
A2CExperimentBuilder, A2CExperimentBuilder,
DDPGExperimentBuilder, DDPGExperimentBuilder,
PPOExperimentBuilder, PPOExperimentBuilder,
RLExperimentConfig, ExperimentConfig,
SACExperimentBuilder, SACExperimentBuilder,
TD3ExperimentBuilder, TD3ExperimentBuilder,
) )
@ -25,8 +25,8 @@ from tianshou.highlevel.experiment import (
) )
def test_experiment_builder_continuous_default_params(builder_cls): def test_experiment_builder_continuous_default_params(builder_cls):
env_factory = ContinuousTestEnvFactory() env_factory = ContinuousTestEnvFactory()
sampling_config = RLSamplingConfig(num_epochs=1, step_per_epoch=100) sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
experiment_config = RLExperimentConfig() experiment_config = ExperimentConfig()
builder = builder_cls( builder = builder_cls(
experiment_config=experiment_config, experiment_config=experiment_config,
env_factory=env_factory, env_factory=env_factory,

View File

@ -2,12 +2,12 @@ from test.highlevel.env_factory import DiscreteTestEnvFactory
import pytest import pytest
from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ( from tianshou.highlevel.experiment import (
A2CExperimentBuilder, A2CExperimentBuilder,
DQNExperimentBuilder, DQNExperimentBuilder,
PPOExperimentBuilder, PPOExperimentBuilder,
RLExperimentConfig, ExperimentConfig,
) )
@ -17,9 +17,9 @@ from tianshou.highlevel.experiment import (
) )
def test_experiment_builder_discrete_default_params(builder_cls): def test_experiment_builder_discrete_default_params(builder_cls):
env_factory = DiscreteTestEnvFactory() env_factory = DiscreteTestEnvFactory()
sampling_config = RLSamplingConfig(num_epochs=1, step_per_epoch=100) sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
builder = builder_cls( builder = builder_cls(
experiment_config=RLExperimentConfig(), experiment_config=ExperimentConfig(),
env_factory=env_factory, env_factory=env_factory,
sampling_config=sampling_config, sampling_config=sampling_config,
) )

View File

@ -6,7 +6,7 @@ import gymnasium
import torch import torch
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import Environments from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module.actor import ( from tianshou.highlevel.module.actor import (
@ -54,7 +54,7 @@ TPolicy = TypeVar("TPolicy", bound=BasePolicy)
class AgentFactory(ABC, ToStringMixin): class AgentFactory(ABC, ToStringMixin):
def __init__(self, sampling_config: RLSamplingConfig, optim_factory: OptimizerFactory): def __init__(self, sampling_config: SamplingConfig, optim_factory: OptimizerFactory):
self.sampling_config = sampling_config self.sampling_config = sampling_config
self.optim_factory = optim_factory self.optim_factory = optim_factory
self.policy_wrapper_factory: PolicyWrapperFactory | None = None self.policy_wrapper_factory: PolicyWrapperFactory | None = None
@ -352,7 +352,7 @@ class ActorCriticAgentFactory(
def __init__( def __init__(
self, self,
params: TParams, params: TParams,
sampling_config: RLSamplingConfig, sampling_config: SamplingConfig,
actor_factory: ActorFactory, actor_factory: ActorFactory,
critic_factory: CriticFactory, critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory, optimizer_factory: OptimizerFactory,
@ -399,7 +399,7 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
def __init__( def __init__(
self, self,
params: A2CParams, params: A2CParams,
sampling_config: RLSamplingConfig, sampling_config: SamplingConfig,
actor_factory: ActorFactory, actor_factory: ActorFactory,
critic_factory: CriticFactory, critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory, optimizer_factory: OptimizerFactory,
@ -423,7 +423,7 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
def __init__( def __init__(
self, self,
params: PPOParams, params: PPOParams,
sampling_config: RLSamplingConfig, sampling_config: SamplingConfig,
actor_factory: ActorFactory, actor_factory: ActorFactory,
critic_factory: CriticFactory, critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory, optimizer_factory: OptimizerFactory,
@ -447,7 +447,7 @@ class DQNAgentFactory(OffpolicyAgentFactory):
def __init__( def __init__(
self, self,
params: DQNParams, params: DQNParams,
sampling_config: RLSamplingConfig, sampling_config: SamplingConfig,
actor_factory: ActorFactory, actor_factory: ActorFactory,
optim_factory: OptimizerFactory, optim_factory: OptimizerFactory,
): ):
@ -483,7 +483,7 @@ class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
def __init__( def __init__(
self, self,
params: DDPGParams, params: DDPGParams,
sampling_config: RLSamplingConfig, sampling_config: SamplingConfig,
actor_factory: ActorFactory, actor_factory: ActorFactory,
critic_factory: CriticFactory, critic_factory: CriticFactory,
optim_factory: OptimizerFactory, optim_factory: OptimizerFactory,
@ -526,7 +526,7 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
def __init__( def __init__(
self, self,
params: SACParams, params: SACParams,
sampling_config: RLSamplingConfig, sampling_config: SamplingConfig,
actor_factory: ActorFactory, actor_factory: ActorFactory,
critic1_factory: CriticFactory, critic1_factory: CriticFactory,
critic2_factory: CriticFactory, critic2_factory: CriticFactory,
@ -575,7 +575,7 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
def __init__( def __init__(
self, self,
params: TD3Params, params: TD3Params,
sampling_config: RLSamplingConfig, sampling_config: SamplingConfig,
actor_factory: ActorFactory, actor_factory: ActorFactory,
critic1_factory: CriticFactory, critic1_factory: CriticFactory,
critic2_factory: CriticFactory, critic2_factory: CriticFactory,

View File

@ -2,7 +2,7 @@ from dataclasses import dataclass
@dataclass @dataclass
class RLSamplingConfig: class SamplingConfig:
"""Sampling, epochs, parallelization, buffers, collectors, and batching.""" """Sampling, epochs, parallelization, buffers, collectors, and batching."""
# TODO: What are reasonable defaults? # TODO: What are reasonable defaults?

View File

@ -18,7 +18,7 @@ from tianshou.highlevel.agent import (
SACAgentFactory, SACAgentFactory,
TD3AgentFactory, TD3AgentFactory,
) )
from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import EnvFactory, Environments from tianshou.highlevel.env import EnvFactory, Environments
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
from tianshou.highlevel.module.actor import ( from tianshou.highlevel.module.actor import (
@ -53,7 +53,7 @@ TTrainer = TypeVar("TTrainer", bound=BaseTrainer)
@dataclass @dataclass
class RLExperimentConfig: class ExperimentConfig:
"""Generic config for setting up the experiment, not RL or training specific.""" """Generic config for setting up the experiment, not RL or training specific."""
seed: int = 42 seed: int = 42
@ -69,10 +69,10 @@ class RLExperimentConfig:
watch_num_episodes = 10 watch_num_episodes = 10
class RLExperiment(Generic[TPolicy, TTrainer], ToStringMixin): class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
def __init__( def __init__(
self, self,
config: RLExperimentConfig, config: ExperimentConfig,
env_factory: EnvFactory | Callable[[PersistableConfigProtocol | None], Environments], env_factory: EnvFactory | Callable[[PersistableConfigProtocol | None], Environments],
agent_factory: AgentFactory, agent_factory: AgentFactory,
logger_factory: LoggerFactory | None = None, logger_factory: LoggerFactory | None = None,
@ -153,12 +153,12 @@ class RLExperiment(Generic[TPolicy, TTrainer], ToStringMixin):
TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder") TBuilder = TypeVar("TBuilder", bound="ExperimentBuilder")
class RLExperimentBuilder: class ExperimentBuilder:
def __init__( def __init__(
self, self,
experiment_config: RLExperimentConfig, experiment_config: ExperimentConfig,
env_factory: EnvFactory, env_factory: EnvFactory,
sampling_config: RLSamplingConfig, sampling_config: SamplingConfig,
): ):
self._config = experiment_config self._config = experiment_config
self._env_factory = env_factory self._env_factory = env_factory
@ -223,12 +223,12 @@ class RLExperimentBuilder:
else: else:
return self._optim_factory return self._optim_factory
def build(self) -> RLExperiment: def build(self) -> Experiment:
agent_factory = self._create_agent_factory() agent_factory = self._create_agent_factory()
agent_factory.set_trainer_callbacks(self._trainer_callbacks) agent_factory.set_trainer_callbacks(self._trainer_callbacks)
if self._policy_wrapper_factory: if self._policy_wrapper_factory:
agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory) agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
experiment = RLExperiment( experiment = Experiment(
self._config, self._config,
self._env_factory, self._env_factory,
agent_factory, agent_factory,
@ -394,15 +394,15 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
class A2CExperimentBuilder( class A2CExperimentBuilder(
RLExperimentBuilder, ExperimentBuilder,
_BuilderMixinActorFactory_ContinuousGaussian, _BuilderMixinActorFactory_ContinuousGaussian,
_BuilderMixinSingleCriticCanUseActorFactory, _BuilderMixinSingleCriticCanUseActorFactory,
): ):
def __init__( def __init__(
self, self,
experiment_config: RLExperimentConfig, experiment_config: ExperimentConfig,
env_factory: EnvFactory, env_factory: EnvFactory,
sampling_config: RLSamplingConfig, sampling_config: SamplingConfig,
env_config: PersistableConfigProtocol | None = None, env_config: PersistableConfigProtocol | None = None,
): ):
super().__init__(experiment_config, env_factory, sampling_config) super().__init__(experiment_config, env_factory, sampling_config)
@ -428,15 +428,15 @@ class A2CExperimentBuilder(
class PPOExperimentBuilder( class PPOExperimentBuilder(
RLExperimentBuilder, ExperimentBuilder,
_BuilderMixinActorFactory_ContinuousGaussian, _BuilderMixinActorFactory_ContinuousGaussian,
_BuilderMixinSingleCriticCanUseActorFactory, _BuilderMixinSingleCriticCanUseActorFactory,
): ):
def __init__( def __init__(
self, self,
experiment_config: RLExperimentConfig, experiment_config: ExperimentConfig,
env_factory: EnvFactory, env_factory: EnvFactory,
sampling_config: RLSamplingConfig, sampling_config: SamplingConfig,
): ):
super().__init__(experiment_config, env_factory, sampling_config) super().__init__(experiment_config, env_factory, sampling_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
@ -460,14 +460,14 @@ class PPOExperimentBuilder(
class DQNExperimentBuilder( class DQNExperimentBuilder(
RLExperimentBuilder, ExperimentBuilder,
_BuilderMixinActorFactory, _BuilderMixinActorFactory,
): ):
def __init__( def __init__(
self, self,
experiment_config: RLExperimentConfig, experiment_config: ExperimentConfig,
env_factory: EnvFactory, env_factory: EnvFactory,
sampling_config: RLSamplingConfig, sampling_config: SamplingConfig,
): ):
super().__init__(experiment_config, env_factory, sampling_config) super().__init__(experiment_config, env_factory, sampling_config)
_BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED) _BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
@ -488,15 +488,15 @@ class DQNExperimentBuilder(
class DDPGExperimentBuilder( class DDPGExperimentBuilder(
RLExperimentBuilder, ExperimentBuilder,
_BuilderMixinActorFactory_ContinuousDeterministic, _BuilderMixinActorFactory_ContinuousDeterministic,
_BuilderMixinSingleCriticCanUseActorFactory, _BuilderMixinSingleCriticCanUseActorFactory,
): ):
def __init__( def __init__(
self, self,
experiment_config: RLExperimentConfig, experiment_config: ExperimentConfig,
env_factory: EnvFactory, env_factory: EnvFactory,
sampling_config: RLSamplingConfig, sampling_config: SamplingConfig,
): ):
super().__init__(experiment_config, env_factory, sampling_config) super().__init__(experiment_config, env_factory, sampling_config)
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self) _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
@ -519,15 +519,15 @@ class DDPGExperimentBuilder(
class SACExperimentBuilder( class SACExperimentBuilder(
RLExperimentBuilder, ExperimentBuilder,
_BuilderMixinActorFactory_ContinuousGaussian, _BuilderMixinActorFactory_ContinuousGaussian,
_BuilderMixinDualCriticFactory, _BuilderMixinDualCriticFactory,
): ):
def __init__( def __init__(
self, self,
experiment_config: RLExperimentConfig, experiment_config: ExperimentConfig,
env_factory: EnvFactory, env_factory: EnvFactory,
sampling_config: RLSamplingConfig, sampling_config: SamplingConfig,
): ):
super().__init__(experiment_config, env_factory, sampling_config) super().__init__(experiment_config, env_factory, sampling_config)
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
@ -550,15 +550,15 @@ class SACExperimentBuilder(
class TD3ExperimentBuilder( class TD3ExperimentBuilder(
RLExperimentBuilder, ExperimentBuilder,
_BuilderMixinActorFactory_ContinuousDeterministic, _BuilderMixinActorFactory_ContinuousDeterministic,
_BuilderMixinDualCriticFactory, _BuilderMixinDualCriticFactory,
): ):
def __init__( def __init__(
self, self,
experiment_config: RLExperimentConfig, experiment_config: ExperimentConfig,
env_factory: EnvFactory, env_factory: EnvFactory,
sampling_config: RLSamplingConfig, sampling_config: SamplingConfig,
): ):
super().__init__(experiment_config, env_factory, sampling_config) super().__init__(experiment_config, env_factory, sampling_config)
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self) _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)

View File

@ -4,7 +4,7 @@ import numpy as np
import torch import torch
from torch.optim.lr_scheduler import LambdaLR, LRScheduler from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.utils.string import ToStringMixin from tianshou.utils.string import ToStringMixin
@ -15,7 +15,7 @@ class LRSchedulerFactory(ToStringMixin, ABC):
class LRSchedulerFactoryLinear(LRSchedulerFactory): class LRSchedulerFactoryLinear(LRSchedulerFactory):
def __init__(self, sampling_config: RLSamplingConfig): def __init__(self, sampling_config: SamplingConfig):
self.sampling_config = sampling_config self.sampling_config = sampling_config
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: