Remove 'RL' prefix from class names
parent 50ac385321
commit d269063e6a
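This is a pure rename refactoring with no behavioral change. A minimal before/after sketch of calling code, using only names and arguments that appear in the hunks below:

# Before this commit
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import RLExperimentConfig

sampling_config = RLSamplingConfig(num_epochs=1, step_per_epoch=100)
experiment_config = RLExperimentConfig()

# After this commit
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ExperimentConfig

sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
experiment_config = ExperimentConfig()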
@@ -10,10 +10,10 @@ from examples.atari.atari_network import (
     FeatureNetFactoryDQN,
 )
 from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     DQNExperimentBuilder,
-    RLExperimentConfig,
+    ExperimentConfig,
 )
 from tianshou.highlevel.params.policy_params import DQNParams
 from tianshou.highlevel.params.policy_wrapper import (
@@ -25,7 +25,7 @@ from tianshou.utils import logging


 def main(
-    experiment_config: RLExperimentConfig,
+    experiment_config: ExperimentConfig,
     task: str = "PongNoFrameskip-v4",
     scale_obs: int = 0,
     eps_test: float = 0.005,
@@ -52,7 +52,7 @@ def main(
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)

-    sampling_config = RLSamplingConfig(
+    sampling_config = SamplingConfig(
         num_epochs=epoch,
         step_per_epoch=step_per_epoch,
         batch_size=batch_size,
@@ -11,10 +11,10 @@ from examples.atari.atari_network import (
     FeatureNetFactoryDQN,
 )
 from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     PPOExperimentBuilder,
-    RLExperimentConfig,
+    ExperimentConfig,
 )
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
 from tianshou.highlevel.params.policy_params import PPOParams
@@ -25,7 +25,7 @@ from tianshou.utils import logging


 def main(
-    experiment_config: RLExperimentConfig,
+    experiment_config: ExperimentConfig,
     task: str = "PongNoFrameskip-v4",
     scale_obs: bool = True,
     buffer_size: int = 100000,
@@ -59,7 +59,7 @@ def main(
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)

-    sampling_config = RLSamplingConfig(
+    sampling_config = SamplingConfig(
         num_epochs=epoch,
         step_per_epoch=step_per_epoch,
         batch_size=batch_size,
@@ -9,7 +9,7 @@ import gymnasium as gym
 import numpy as np

 from tianshou.env import ShmemVectorEnv
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
 from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext

@@ -379,7 +379,7 @@ class AtariEnvFactory(EnvFactory):
         self,
         task: str,
         seed: int,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
         frame_stack: int,
         scale: int = 0,
     ):
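The factory signature change above is mechanical; a hedged sketch of a call site under the new name (frame_stack and the config values are illustrative, not taken from this commit):

from examples.atari.atari_wrapper import AtariEnvFactory
from tianshou.highlevel.config import SamplingConfig

# Illustrative values; the hunk above only shows the parameter names.
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
env_factory = AtariEnvFactory(
    task="PongNoFrameskip-v4",
    seed=42,
    sampling_config=sampling_config,
    frame_stack=4,
)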
@@ -8,10 +8,10 @@ from typing import Literal
 from jsonargparse import CLI

 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     A2CExperimentBuilder,
-    RLExperimentConfig,
+    ExperimentConfig,
 )
 from tianshou.highlevel.optim import OptimizerFactoryRMSprop
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
@@ -20,7 +20,7 @@ from tianshou.utils import logging


 def main(
-    experiment_config: RLExperimentConfig,
+    experiment_config: ExperimentConfig,
     task: str = "Ant-v3",
     buffer_size: int = 4096,
     hidden_sizes: Sequence[int] = (64, 64),
@@ -44,7 +44,7 @@ def main(
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)

-    sampling_config = RLSamplingConfig(
+    sampling_config = SamplingConfig(
         num_epochs=epoch,
         step_per_epoch=step_per_epoch,
         batch_size=batch_size,
@@ -7,10 +7,10 @@ from collections.abc import Sequence
 from jsonargparse import CLI

 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     DDPGExperimentBuilder,
-    RLExperimentConfig,
+    ExperimentConfig,
 )
 from tianshou.highlevel.params.noise import MaxActionScaledGaussian
 from tianshou.highlevel.params.policy_params import DDPGParams
@@ -18,7 +18,7 @@ from tianshou.utils import logging


 def main(
-    experiment_config: RLExperimentConfig,
+    experiment_config: ExperimentConfig,
     task: str = "Ant-v3",
     buffer_size: int = 1000000,
     hidden_sizes: Sequence[int] = (256, 256),
@@ -40,7 +40,7 @@ def main(
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)

-    sampling_config = RLSamplingConfig(
+    sampling_config = SamplingConfig(
         num_epochs=epoch,
         step_per_epoch=step_per_epoch,
         batch_size=batch_size,
@@ -3,7 +3,7 @@ import warnings
 import gymnasium as gym

 from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory

 try:
@@ -41,7 +41,7 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in


 class MujocoEnvFactory(EnvFactory):
-    def __init__(self, task: str, seed: int, sampling_config: RLSamplingConfig):
+    def __init__(self, task: str, seed: int, sampling_config: SamplingConfig):
         self.task = task
         self.sampling_config = sampling_config
         self.seed = seed
@@ -9,10 +9,10 @@ from jsonargparse import CLI
 from torch.distributions import Independent, Normal

 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     PPOExperimentBuilder,
-    RLExperimentConfig,
+    ExperimentConfig,
 )
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
 from tianshou.highlevel.params.policy_params import PPOParams
@@ -20,7 +20,7 @@ from tianshou.utils import logging


 def main(
-    experiment_config: RLExperimentConfig,
+    experiment_config: ExperimentConfig,
     task: str = "Ant-v4",
     buffer_size: int = 4096,
     hidden_sizes: Sequence[int] = (64, 64),
@@ -49,7 +49,7 @@ def main(
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)

-    sampling_config = RLSamplingConfig(
+    sampling_config = SamplingConfig(
         num_epochs=epoch,
         step_per_epoch=step_per_epoch,
         batch_size=batch_size,
@@ -7,9 +7,9 @@ from collections.abc import Sequence
 from jsonargparse import CLI

 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
-    RLExperimentConfig,
+    ExperimentConfig,
     SACExperimentBuilder,
 )
 from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
@@ -18,7 +18,7 @@ from tianshou.utils import logging


 def main(
-    experiment_config: RLExperimentConfig,
+    experiment_config: ExperimentConfig,
     task: str = "Ant-v3",
     buffer_size: int = 1000000,
     hidden_sizes: Sequence[int] = (256, 256),
@@ -42,7 +42,7 @@ def main(
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "sac", str(experiment_config.seed), now)

-    sampling_config = RLSamplingConfig(
+    sampling_config = SamplingConfig(
         num_epochs=epoch,
         step_per_epoch=step_per_epoch,
         num_train_envs=training_num,
@@ -7,9 +7,9 @@ from collections.abc import Sequence
 from jsonargparse import CLI

 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
-    RLExperimentConfig,
+    ExperimentConfig,
     TD3ExperimentBuilder,
 )
 from tianshou.highlevel.params.env_param import MaxActionScaled
@@ -21,7 +21,7 @@ from tianshou.utils import logging


 def main(
-    experiment_config: RLExperimentConfig,
+    experiment_config: ExperimentConfig,
     task: str = "Ant-v3",
     buffer_size: int = 1000000,
     hidden_sizes: Sequence[int] = (256, 256),
@@ -46,7 +46,7 @@ def main(
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "td3", str(experiment_config.seed), now)

-    sampling_config = RLSamplingConfig(
+    sampling_config = SamplingConfig(
         num_epochs=epoch,
         step_per_epoch=step_per_epoch,
         num_train_envs=training_num,
@@ -2,12 +2,12 @@ from test.highlevel.env_factory import ContinuousTestEnvFactory

 import pytest

-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     A2CExperimentBuilder,
     DDPGExperimentBuilder,
     PPOExperimentBuilder,
-    RLExperimentConfig,
+    ExperimentConfig,
     SACExperimentBuilder,
     TD3ExperimentBuilder,
 )
@@ -25,8 +25,8 @@ from tianshou.highlevel.experiment import (
 )
 def test_experiment_builder_continuous_default_params(builder_cls):
     env_factory = ContinuousTestEnvFactory()
-    sampling_config = RLSamplingConfig(num_epochs=1, step_per_epoch=100)
-    experiment_config = RLExperimentConfig()
+    sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
+    experiment_config = ExperimentConfig()
     builder = builder_cls(
         experiment_config=experiment_config,
         env_factory=env_factory,
@@ -2,12 +2,12 @@ from test.highlevel.env_factory import DiscreteTestEnvFactory

 import pytest

-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     A2CExperimentBuilder,
     DQNExperimentBuilder,
     PPOExperimentBuilder,
-    RLExperimentConfig,
+    ExperimentConfig,
 )


@@ -17,9 +17,9 @@ from tianshou.highlevel.experiment import (
 )
 def test_experiment_builder_discrete_default_params(builder_cls):
     env_factory = DiscreteTestEnvFactory()
-    sampling_config = RLSamplingConfig(num_epochs=1, step_per_epoch=100)
+    sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
     builder = builder_cls(
-        experiment_config=RLExperimentConfig(),
+        experiment_config=ExperimentConfig(),
         env_factory=env_factory,
         sampling_config=sampling_config,
     )
@@ -6,7 +6,7 @@ import gymnasium
 import torch

 from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.env import Environments
 from tianshou.highlevel.logger import Logger
 from tianshou.highlevel.module.actor import (
@@ -54,7 +54,7 @@ TPolicy = TypeVar("TPolicy", bound=BasePolicy)


 class AgentFactory(ABC, ToStringMixin):
-    def __init__(self, sampling_config: RLSamplingConfig, optim_factory: OptimizerFactory):
+    def __init__(self, sampling_config: SamplingConfig, optim_factory: OptimizerFactory):
         self.sampling_config = sampling_config
         self.optim_factory = optim_factory
         self.policy_wrapper_factory: PolicyWrapperFactory | None = None
@@ -352,7 +352,7 @@ class ActorCriticAgentFactory(
     def __init__(
         self,
         params: TParams,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
         actor_factory: ActorFactory,
         critic_factory: CriticFactory,
         optimizer_factory: OptimizerFactory,
@@ -399,7 +399,7 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
     def __init__(
         self,
         params: A2CParams,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
         actor_factory: ActorFactory,
         critic_factory: CriticFactory,
         optimizer_factory: OptimizerFactory,
@@ -423,7 +423,7 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
     def __init__(
         self,
         params: PPOParams,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
         actor_factory: ActorFactory,
         critic_factory: CriticFactory,
         optimizer_factory: OptimizerFactory,
@@ -447,7 +447,7 @@ class DQNAgentFactory(OffpolicyAgentFactory):
     def __init__(
         self,
         params: DQNParams,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
         actor_factory: ActorFactory,
         optim_factory: OptimizerFactory,
     ):
@@ -483,7 +483,7 @@ class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
     def __init__(
         self,
         params: DDPGParams,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
         actor_factory: ActorFactory,
         critic_factory: CriticFactory,
         optim_factory: OptimizerFactory,
@@ -526,7 +526,7 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
     def __init__(
         self,
         params: SACParams,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
         actor_factory: ActorFactory,
         critic1_factory: CriticFactory,
         critic2_factory: CriticFactory,
@@ -575,7 +575,7 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
     def __init__(
         self,
         params: TD3Params,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
         actor_factory: ActorFactory,
         critic1_factory: CriticFactory,
         critic2_factory: CriticFactory,
@@ -2,7 +2,7 @@ from dataclasses import dataclass


 @dataclass
-class RLSamplingConfig:
+class SamplingConfig:
     """Sampling, epochs, parallelization, buffers, collectors, and batching."""

     # TODO: What are reasonable defaults?
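Since the renamed class is a dataclass, call sites construct it with plain keywords; a minimal sketch restricted to field names that appear elsewhere in this diff (the values are illustrative):

from tianshou.highlevel.config import SamplingConfig

sampling_config = SamplingConfig(
    num_epochs=1,        # field names as used in the example scripts above
    step_per_epoch=100,
    batch_size=64,       # value illustrative
    num_train_envs=10,   # value illustrative
)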
@@ -18,7 +18,7 @@ from tianshou.highlevel.agent import (
     SACAgentFactory,
     TD3AgentFactory,
 )
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.env import EnvFactory, Environments
 from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
 from tianshou.highlevel.module.actor import (
@@ -53,7 +53,7 @@ TTrainer = TypeVar("TTrainer", bound=BaseTrainer)


 @dataclass
-class RLExperimentConfig:
+class ExperimentConfig:
     """Generic config for setting up the experiment, not RL or training specific."""

     seed: int = 42
@@ -69,10 +69,10 @@ class RLExperimentConfig:
     watch_num_episodes = 10


-class RLExperiment(Generic[TPolicy, TTrainer], ToStringMixin):
+class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
     def __init__(
         self,
-        config: RLExperimentConfig,
+        config: ExperimentConfig,
         env_factory: EnvFactory | Callable[[PersistableConfigProtocol | None], Environments],
         agent_factory: AgentFactory,
         logger_factory: LoggerFactory | None = None,
@@ -153,12 +153,12 @@ class RLExperiment(Generic[TPolicy, TTrainer], ToStringMixin):
 TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder")


-class RLExperimentBuilder:
+class ExperimentBuilder:
     def __init__(
         self,
-        experiment_config: RLExperimentConfig,
+        experiment_config: ExperimentConfig,
         env_factory: EnvFactory,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
     ):
         self._config = experiment_config
         self._env_factory = env_factory
@@ -223,12 +223,12 @@ class RLExperimentBuilder:
         else:
             return self._optim_factory

-    def build(self) -> RLExperiment:
+    def build(self) -> Experiment:
         agent_factory = self._create_agent_factory()
         agent_factory.set_trainer_callbacks(self._trainer_callbacks)
         if self._policy_wrapper_factory:
             agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
-        experiment = RLExperiment(
+        experiment = Experiment(
             self._config,
             self._env_factory,
             agent_factory,
@@ -394,15 +394,15 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):


 class A2CExperimentBuilder(
-    RLExperimentBuilder,
+    ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,
     _BuilderMixinSingleCriticCanUseActorFactory,
 ):
     def __init__(
         self,
-        experiment_config: RLExperimentConfig,
+        experiment_config: ExperimentConfig,
         env_factory: EnvFactory,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
         env_config: PersistableConfigProtocol | None = None,
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
@@ -428,15 +428,15 @@ class A2CExperimentBuilder(


 class PPOExperimentBuilder(
-    RLExperimentBuilder,
+    ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,
     _BuilderMixinSingleCriticCanUseActorFactory,
 ):
     def __init__(
         self,
-        experiment_config: RLExperimentConfig,
+        experiment_config: ExperimentConfig,
         env_factory: EnvFactory,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
@@ -460,14 +460,14 @@ class PPOExperimentBuilder(


 class DQNExperimentBuilder(
-    RLExperimentBuilder,
+    ExperimentBuilder,
     _BuilderMixinActorFactory,
 ):
     def __init__(
         self,
-        experiment_config: RLExperimentConfig,
+        experiment_config: ExperimentConfig,
         env_factory: EnvFactory,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
@@ -488,15 +488,15 @@ class DQNExperimentBuilder(


 class DDPGExperimentBuilder(
-    RLExperimentBuilder,
+    ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousDeterministic,
     _BuilderMixinSingleCriticCanUseActorFactory,
 ):
     def __init__(
         self,
-        experiment_config: RLExperimentConfig,
+        experiment_config: ExperimentConfig,
         env_factory: EnvFactory,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
@@ -519,15 +519,15 @@ class DDPGExperimentBuilder(


 class SACExperimentBuilder(
-    RLExperimentBuilder,
+    ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,
     _BuilderMixinDualCriticFactory,
 ):
     def __init__(
         self,
-        experiment_config: RLExperimentConfig,
+        experiment_config: ExperimentConfig,
         env_factory: EnvFactory,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
@@ -550,15 +550,15 @@ class SACExperimentBuilder(


 class TD3ExperimentBuilder(
-    RLExperimentBuilder,
+    ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousDeterministic,
     _BuilderMixinDualCriticFactory,
 ):
     def __init__(
         self,
-        experiment_config: RLExperimentConfig,
+        experiment_config: ExperimentConfig,
         env_factory: EnvFactory,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
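Taken together, the renamed builder flow reads as follows; a hedged end-to-end sketch assembled from the test files in this commit (only build() is shown being called, since that is the only entry point visible in these hunks):

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ExperimentConfig, PPOExperimentBuilder

experiment_config = ExperimentConfig()  # seed defaults to 42 per the hunk above
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
env_factory = MujocoEnvFactory("Ant-v4", experiment_config.seed, sampling_config)

builder = PPOExperimentBuilder(
    experiment_config=experiment_config,
    env_factory=env_factory,
    sampling_config=sampling_config,
)
experiment = builder.build()  # now returns Experiment (formerly RLExperiment)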
@@ -4,7 +4,7 @@ import numpy as np
 import torch
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler

-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.utils.string import ToStringMixin


@@ -15,7 +15,7 @@ class LRSchedulerFactory(ToStringMixin, ABC):


 class LRSchedulerFactoryLinear(LRSchedulerFactory):
-    def __init__(self, sampling_config: RLSamplingConfig):
+    def __init__(self, sampling_config: SamplingConfig):
         self.sampling_config = sampling_config

     def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: