Remove 'RL' prefix from class names

Author: Dominik Jain, 2023-10-06 13:50:23 +02:00
parent 50ac385321
commit d269063e6a
15 changed files with 79 additions and 79 deletions
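
In downstream code the rename is a mechanical, one-for-one substitution at import sites and call sites; no signatures change otherwise. A minimal migration sketch (illustrative only, composed from the imports and defaults changed below):

    # before: from tianshou.highlevel.config import RLSamplingConfig
    # before: from tianshou.highlevel.experiment import RLExperimentConfig
    from tianshou.highlevel.config import SamplingConfig
    from tianshou.highlevel.experiment import ExperimentConfig

    sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
    experiment_config = ExperimentConfig()  # defaults apply, e.g. seed=42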

View File

@@ -10,10 +10,10 @@ from examples.atari.atari_network import (
     FeatureNetFactoryDQN,
 )
 from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     DQNExperimentBuilder,
-    RLExperimentConfig,
+    ExperimentConfig,
 )
 from tianshou.highlevel.params.policy_params import DQNParams
 from tianshou.highlevel.params.policy_wrapper import (
@@ -25,7 +25,7 @@ from tianshou.utils import logging
 
 
 def main(
-    experiment_config: RLExperimentConfig,
+    experiment_config: ExperimentConfig,
     task: str = "PongNoFrameskip-v4",
     scale_obs: int = 0,
     eps_test: float = 0.005,
@@ -52,7 +52,7 @@ def main(
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
-    sampling_config = RLSamplingConfig(
+    sampling_config = SamplingConfig(
         num_epochs=epoch,
         step_per_epoch=step_per_epoch,
         batch_size=batch_size,

View File

@@ -11,10 +11,10 @@ from examples.atari.atari_network import (
     FeatureNetFactoryDQN,
 )
 from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     PPOExperimentBuilder,
-    RLExperimentConfig,
+    ExperimentConfig,
 )
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
 from tianshou.highlevel.params.policy_params import PPOParams
@@ -25,7 +25,7 @@ from tianshou.utils import logging
 
 
 def main(
-    experiment_config: RLExperimentConfig,
+    experiment_config: ExperimentConfig,
     task: str = "PongNoFrameskip-v4",
     scale_obs: bool = True,
     buffer_size: int = 100000,
@@ -59,7 +59,7 @@ def main(
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
-    sampling_config = RLSamplingConfig(
+    sampling_config = SamplingConfig(
         num_epochs=epoch,
         step_per_epoch=step_per_epoch,
         batch_size=batch_size,

View File

@@ -9,7 +9,7 @@ import gymnasium as gym
 import numpy as np
 
 from tianshou.env import ShmemVectorEnv
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
 from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
@@ -379,7 +379,7 @@ class AtariEnvFactory(EnvFactory):
     def __init__(
         self,
         task: str,
         seed: int,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
         frame_stack: int,
         scale: int = 0,
     ):

View File

@@ -8,10 +8,10 @@ from typing import Literal
 from jsonargparse import CLI
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     A2CExperimentBuilder,
-    RLExperimentConfig,
+    ExperimentConfig,
 )
 from tianshou.highlevel.optim import OptimizerFactoryRMSprop
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
@@ -20,7 +20,7 @@ from tianshou.utils import logging
 
 
 def main(
-    experiment_config: RLExperimentConfig,
+    experiment_config: ExperimentConfig,
     task: str = "Ant-v3",
     buffer_size: int = 4096,
     hidden_sizes: Sequence[int] = (64, 64),
@@ -44,7 +44,7 @@ def main(
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
-    sampling_config = RLSamplingConfig(
+    sampling_config = SamplingConfig(
         num_epochs=epoch,
         step_per_epoch=step_per_epoch,
         batch_size=batch_size,

View File

@@ -7,10 +7,10 @@ from collections.abc import Sequence
 from jsonargparse import CLI
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     DDPGExperimentBuilder,
-    RLExperimentConfig,
+    ExperimentConfig,
 )
 from tianshou.highlevel.params.noise import MaxActionScaledGaussian
 from tianshou.highlevel.params.policy_params import DDPGParams
@@ -18,7 +18,7 @@ from tianshou.utils import logging
 
 
 def main(
-    experiment_config: RLExperimentConfig,
+    experiment_config: ExperimentConfig,
     task: str = "Ant-v3",
     buffer_size: int = 1000000,
     hidden_sizes: Sequence[int] = (256, 256),
@@ -40,7 +40,7 @@ def main(
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
-    sampling_config = RLSamplingConfig(
+    sampling_config = SamplingConfig(
         num_epochs=epoch,
         step_per_epoch=step_per_epoch,
         batch_size=batch_size,

View File

@@ -3,7 +3,7 @@ import warnings
 import gymnasium as gym
 
 from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
 
 try:
@@ -41,7 +41,7 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int
 
 
 class MujocoEnvFactory(EnvFactory):
-    def __init__(self, task: str, seed: int, sampling_config: RLSamplingConfig):
+    def __init__(self, task: str, seed: int, sampling_config: SamplingConfig):
         self.task = task
         self.sampling_config = sampling_config
         self.seed = seed

View File

@@ -9,10 +9,10 @@ from jsonargparse import CLI
 from torch.distributions import Independent, Normal
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     PPOExperimentBuilder,
-    RLExperimentConfig,
+    ExperimentConfig,
 )
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
 from tianshou.highlevel.params.policy_params import PPOParams
@@ -20,7 +20,7 @@ from tianshou.utils import logging
 
 
 def main(
-    experiment_config: RLExperimentConfig,
+    experiment_config: ExperimentConfig,
     task: str = "Ant-v4",
     buffer_size: int = 4096,
     hidden_sizes: Sequence[int] = (64, 64),
@@ -49,7 +49,7 @@ def main(
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
-    sampling_config = RLSamplingConfig(
+    sampling_config = SamplingConfig(
         num_epochs=epoch,
         step_per_epoch=step_per_epoch,
         batch_size=batch_size,

View File

@@ -7,9 +7,9 @@ from collections.abc import Sequence
 from jsonargparse import CLI
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
-    RLExperimentConfig,
+    ExperimentConfig,
     SACExperimentBuilder,
 )
 from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
@@ -18,7 +18,7 @@ from tianshou.utils import logging
 
 
 def main(
-    experiment_config: RLExperimentConfig,
+    experiment_config: ExperimentConfig,
     task: str = "Ant-v3",
     buffer_size: int = 1000000,
     hidden_sizes: Sequence[int] = (256, 256),
@@ -42,7 +42,7 @@ def main(
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "sac", str(experiment_config.seed), now)
-    sampling_config = RLSamplingConfig(
+    sampling_config = SamplingConfig(
         num_epochs=epoch,
         step_per_epoch=step_per_epoch,
         num_train_envs=training_num,

View File

@@ -7,9 +7,9 @@ from collections.abc import Sequence
 from jsonargparse import CLI
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
-    RLExperimentConfig,
+    ExperimentConfig,
     TD3ExperimentBuilder,
 )
 from tianshou.highlevel.params.env_param import MaxActionScaled
@@ -21,7 +21,7 @@ from tianshou.utils import logging
 
 
 def main(
-    experiment_config: RLExperimentConfig,
+    experiment_config: ExperimentConfig,
     task: str = "Ant-v3",
     buffer_size: int = 1000000,
     hidden_sizes: Sequence[int] = (256, 256),
@@ -46,7 +46,7 @@ def main(
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "td3", str(experiment_config.seed), now)
-    sampling_config = RLSamplingConfig(
+    sampling_config = SamplingConfig(
         num_epochs=epoch,
         step_per_epoch=step_per_epoch,
         num_train_envs=training_num,

View File

@@ -2,12 +2,12 @@ from test.highlevel.env_factory import ContinuousTestEnvFactory
 
 import pytest
 
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     A2CExperimentBuilder,
     DDPGExperimentBuilder,
     PPOExperimentBuilder,
-    RLExperimentConfig,
+    ExperimentConfig,
     SACExperimentBuilder,
     TD3ExperimentBuilder,
 )
@@ -25,8 +25,8 @@ from tianshou.highlevel.experiment import (
 )
 def test_experiment_builder_continuous_default_params(builder_cls):
     env_factory = ContinuousTestEnvFactory()
-    sampling_config = RLSamplingConfig(num_epochs=1, step_per_epoch=100)
-    experiment_config = RLExperimentConfig()
+    sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
+    experiment_config = ExperimentConfig()
     builder = builder_cls(
         experiment_config=experiment_config,
         env_factory=env_factory,

View File

@@ -2,12 +2,12 @@ from test.highlevel.env_factory import DiscreteTestEnvFactory
 
 import pytest
 
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     A2CExperimentBuilder,
     DQNExperimentBuilder,
     PPOExperimentBuilder,
-    RLExperimentConfig,
+    ExperimentConfig,
 )
@@ -17,9 +17,9 @@ from tianshou.highlevel.experiment import (
 )
 def test_experiment_builder_discrete_default_params(builder_cls):
     env_factory = DiscreteTestEnvFactory()
-    sampling_config = RLSamplingConfig(num_epochs=1, step_per_epoch=100)
+    sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
     builder = builder_cls(
-        experiment_config=RLExperimentConfig(),
+        experiment_config=ExperimentConfig(),
         env_factory=env_factory,
         sampling_config=sampling_config,
     )

View File

@@ -6,7 +6,7 @@ import gymnasium
 import torch
 
 from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.env import Environments
 from tianshou.highlevel.logger import Logger
 from tianshou.highlevel.module.actor import (
@@ -54,7 +54,7 @@ TPolicy = TypeVar("TPolicy", bound=BasePolicy)
 
 
 class AgentFactory(ABC, ToStringMixin):
-    def __init__(self, sampling_config: RLSamplingConfig, optim_factory: OptimizerFactory):
+    def __init__(self, sampling_config: SamplingConfig, optim_factory: OptimizerFactory):
         self.sampling_config = sampling_config
         self.optim_factory = optim_factory
         self.policy_wrapper_factory: PolicyWrapperFactory | None = None
@@ -352,7 +352,7 @@ class ActorCriticAgentFactory(
     def __init__(
         self,
         params: TParams,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
         actor_factory: ActorFactory,
         critic_factory: CriticFactory,
         optimizer_factory: OptimizerFactory,
@@ -399,7 +399,7 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
     def __init__(
         self,
         params: A2CParams,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
         actor_factory: ActorFactory,
         critic_factory: CriticFactory,
         optimizer_factory: OptimizerFactory,
@@ -423,7 +423,7 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
     def __init__(
         self,
         params: PPOParams,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
         actor_factory: ActorFactory,
         critic_factory: CriticFactory,
         optimizer_factory: OptimizerFactory,
@@ -447,7 +447,7 @@ class DQNAgentFactory(OffpolicyAgentFactory):
     def __init__(
         self,
         params: DQNParams,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
         actor_factory: ActorFactory,
         optim_factory: OptimizerFactory,
     ):
@@ -483,7 +483,7 @@ class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
     def __init__(
         self,
         params: DDPGParams,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
         actor_factory: ActorFactory,
         critic_factory: CriticFactory,
         optim_factory: OptimizerFactory,
@@ -526,7 +526,7 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
     def __init__(
         self,
         params: SACParams,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
         actor_factory: ActorFactory,
         critic1_factory: CriticFactory,
         critic2_factory: CriticFactory,
@@ -575,7 +575,7 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
     def __init__(
         self,
         params: TD3Params,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
         actor_factory: ActorFactory,
         critic1_factory: CriticFactory,
         critic2_factory: CriticFactory,

View File

@@ -2,7 +2,7 @@ from dataclasses import dataclass
 
 
 @dataclass
-class RLSamplingConfig:
+class SamplingConfig:
     """Sampling, epochs, parallelization, buffers, collectors, and batching."""
 
     # TODO: What are reasonable defaults?
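
For orientation, a hedged construction sketch of the renamed dataclass; the field names below are the ones exercised by call sites elsewhere in this commit, the values are illustrative, and the class defines further fields not shown in this diff:

    from tianshou.highlevel.config import SamplingConfig

    sampling_config = SamplingConfig(
        num_epochs=100,        # training epochs
        step_per_epoch=30000,  # environment steps per epoch
        batch_size=64,         # minibatch size for gradient updates
        num_train_envs=10,     # parallel training environments
    )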

View File

@@ -18,7 +18,7 @@ from tianshou.highlevel.agent import (
     SACAgentFactory,
     TD3AgentFactory,
 )
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.env import EnvFactory, Environments
 from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
 from tianshou.highlevel.module.actor import (
@@ -53,7 +53,7 @@ TTrainer = TypeVar("TTrainer", bound=BaseTrainer)
 
 
 @dataclass
-class RLExperimentConfig:
+class ExperimentConfig:
     """Generic config for setting up the experiment, not RL or training specific."""
 
     seed: int = 42
@@ -69,10 +69,10 @@ class RLExperimentConfig:
     watch_num_episodes = 10
 
 
-class RLExperiment(Generic[TPolicy, TTrainer], ToStringMixin):
+class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
     def __init__(
         self,
-        config: RLExperimentConfig,
+        config: ExperimentConfig,
         env_factory: EnvFactory | Callable[[PersistableConfigProtocol | None], Environments],
         agent_factory: AgentFactory,
         logger_factory: LoggerFactory | None = None,
@@ -153,12 +153,12 @@ class RLExperiment(Generic[TPolicy, TTrainer], ToStringMixin):
 TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder")
 
 
-class RLExperimentBuilder:
+class ExperimentBuilder:
     def __init__(
         self,
-        experiment_config: RLExperimentConfig,
+        experiment_config: ExperimentConfig,
         env_factory: EnvFactory,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
     ):
         self._config = experiment_config
         self._env_factory = env_factory
@@ -223,12 +223,12 @@ class RLExperimentBuilder:
         else:
             return self._optim_factory
 
-    def build(self) -> RLExperiment:
+    def build(self) -> Experiment:
         agent_factory = self._create_agent_factory()
         agent_factory.set_trainer_callbacks(self._trainer_callbacks)
         if self._policy_wrapper_factory:
             agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
-        experiment = RLExperiment(
+        experiment = Experiment(
             self._config,
             self._env_factory,
             agent_factory,
@@ -394,15 +394,15 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
 
 
 class A2CExperimentBuilder(
-    RLExperimentBuilder,
+    ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,
     _BuilderMixinSingleCriticCanUseActorFactory,
 ):
     def __init__(
         self,
-        experiment_config: RLExperimentConfig,
+        experiment_config: ExperimentConfig,
         env_factory: EnvFactory,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
         env_config: PersistableConfigProtocol | None = None,
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
@@ -428,15 +428,15 @@ class A2CExperimentBuilder(
 
 
 class PPOExperimentBuilder(
-    RLExperimentBuilder,
+    ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,
     _BuilderMixinSingleCriticCanUseActorFactory,
 ):
     def __init__(
         self,
-        experiment_config: RLExperimentConfig,
+        experiment_config: ExperimentConfig,
         env_factory: EnvFactory,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
@@ -460,14 +460,14 @@ class PPOExperimentBuilder(
 
 
 class DQNExperimentBuilder(
-    RLExperimentBuilder,
+    ExperimentBuilder,
     _BuilderMixinActorFactory,
 ):
     def __init__(
         self,
-        experiment_config: RLExperimentConfig,
+        experiment_config: ExperimentConfig,
         env_factory: EnvFactory,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
@@ -488,15 +488,15 @@ class DQNExperimentBuilder(
 
 
 class DDPGExperimentBuilder(
-    RLExperimentBuilder,
+    ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousDeterministic,
     _BuilderMixinSingleCriticCanUseActorFactory,
 ):
     def __init__(
         self,
-        experiment_config: RLExperimentConfig,
+        experiment_config: ExperimentConfig,
         env_factory: EnvFactory,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
@@ -519,15 +519,15 @@ class DDPGExperimentBuilder(
 
 
 class SACExperimentBuilder(
-    RLExperimentBuilder,
+    ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousGaussian,
     _BuilderMixinDualCriticFactory,
 ):
     def __init__(
         self,
-        experiment_config: RLExperimentConfig,
+        experiment_config: ExperimentConfig,
         env_factory: EnvFactory,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
@@ -550,15 +550,15 @@ class SACExperimentBuilder(
 
 
 class TD3ExperimentBuilder(
-    RLExperimentBuilder,
+    ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousDeterministic,
     _BuilderMixinDualCriticFactory,
 ):
     def __init__(
         self,
-        experiment_config: RLExperimentConfig,
+        experiment_config: ExperimentConfig,
         env_factory: EnvFactory,
-        sampling_config: RLSamplingConfig,
+        sampling_config: SamplingConfig,
     ):
         super().__init__(experiment_config, env_factory, sampling_config)
         _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)

View File

@@ -4,7 +4,7 @@ import numpy as np
 import torch
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
 
-from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.config import SamplingConfig
 from tianshou.utils.string import ToStringMixin
@@ -15,7 +15,7 @@ class LRSchedulerFactory(ToStringMixin, ABC):
 
 
 class LRSchedulerFactoryLinear(LRSchedulerFactory):
-    def __init__(self, sampling_config: RLSamplingConfig):
+    def __init__(self, sampling_config: SamplingConfig):
         self.sampling_config = sampling_config
 
     def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
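
Only the signatures of LRSchedulerFactoryLinear appear in this diff; for context, a factory of this shape would typically decay the learning rate linearly to zero over the total number of policy updates derived from the sampling config. A hedged sketch (the method body and the step_per_collect field are assumptions, not part of this commit):

    class LRSchedulerFactoryLinear(LRSchedulerFactory):
        def __init__(self, sampling_config: SamplingConfig):
            self.sampling_config = sampling_config

        def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
            c = self.sampling_config
            # assumed: total update count = collects per epoch * number of epochs
            max_update_num = np.ceil(c.step_per_epoch / c.step_per_collect) * c.num_epochs
            # the multiplicative factor decays linearly from 1 toward 0 per update
            return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)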