Remove 'RL' prefix from class names
This commit is contained in:
parent
50ac385321
commit
d269063e6a
@ -10,10 +10,10 @@ from examples.atari.atari_network import (
|
|||||||
FeatureNetFactoryDQN,
|
FeatureNetFactoryDQN,
|
||||||
)
|
)
|
||||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
DQNExperimentBuilder,
|
DQNExperimentBuilder,
|
||||||
RLExperimentConfig,
|
ExperimentConfig,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.params.policy_params import DQNParams
|
from tianshou.highlevel.params.policy_params import DQNParams
|
||||||
from tianshou.highlevel.params.policy_wrapper import (
|
from tianshou.highlevel.params.policy_wrapper import (
|
||||||
@ -25,7 +25,7 @@ from tianshou.utils import logging
|
|||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
experiment_config: RLExperimentConfig,
|
experiment_config: ExperimentConfig,
|
||||||
task: str = "PongNoFrameskip-v4",
|
task: str = "PongNoFrameskip-v4",
|
||||||
scale_obs: int = 0,
|
scale_obs: int = 0,
|
||||||
eps_test: float = 0.005,
|
eps_test: float = 0.005,
|
||||||
@ -52,7 +52,7 @@ def main(
|
|||||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||||
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
||||||
|
|
||||||
sampling_config = RLSamplingConfig(
|
sampling_config = SamplingConfig(
|
||||||
num_epochs=epoch,
|
num_epochs=epoch,
|
||||||
step_per_epoch=step_per_epoch,
|
step_per_epoch=step_per_epoch,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
|||||||
@ -11,10 +11,10 @@ from examples.atari.atari_network import (
|
|||||||
FeatureNetFactoryDQN,
|
FeatureNetFactoryDQN,
|
||||||
)
|
)
|
||||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
PPOExperimentBuilder,
|
PPOExperimentBuilder,
|
||||||
RLExperimentConfig,
|
ExperimentConfig,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
||||||
from tianshou.highlevel.params.policy_params import PPOParams
|
from tianshou.highlevel.params.policy_params import PPOParams
|
||||||
@ -25,7 +25,7 @@ from tianshou.utils import logging
|
|||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
experiment_config: RLExperimentConfig,
|
experiment_config: ExperimentConfig,
|
||||||
task: str = "PongNoFrameskip-v4",
|
task: str = "PongNoFrameskip-v4",
|
||||||
scale_obs: bool = True,
|
scale_obs: bool = True,
|
||||||
buffer_size: int = 100000,
|
buffer_size: int = 100000,
|
||||||
@ -59,7 +59,7 @@ def main(
|
|||||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||||
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
||||||
|
|
||||||
sampling_config = RLSamplingConfig(
|
sampling_config = SamplingConfig(
|
||||||
num_epochs=epoch,
|
num_epochs=epoch,
|
||||||
step_per_epoch=step_per_epoch,
|
step_per_epoch=step_per_epoch,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
|||||||
@ -9,7 +9,7 @@ import gymnasium as gym
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tianshou.env import ShmemVectorEnv
|
from tianshou.env import ShmemVectorEnv
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
|
from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
|
||||||
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
|
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
|
||||||
|
|
||||||
@ -379,7 +379,7 @@ class AtariEnvFactory(EnvFactory):
|
|||||||
self,
|
self,
|
||||||
task: str,
|
task: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
frame_stack: int,
|
frame_stack: int,
|
||||||
scale: int = 0,
|
scale: int = 0,
|
||||||
):
|
):
|
||||||
|
|||||||
@ -8,10 +8,10 @@ from typing import Literal
|
|||||||
from jsonargparse import CLI
|
from jsonargparse import CLI
|
||||||
|
|
||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
A2CExperimentBuilder,
|
A2CExperimentBuilder,
|
||||||
RLExperimentConfig,
|
ExperimentConfig,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.optim import OptimizerFactoryRMSprop
|
from tianshou.highlevel.optim import OptimizerFactoryRMSprop
|
||||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
||||||
@ -20,7 +20,7 @@ from tianshou.utils import logging
|
|||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
experiment_config: RLExperimentConfig,
|
experiment_config: ExperimentConfig,
|
||||||
task: str = "Ant-v3",
|
task: str = "Ant-v3",
|
||||||
buffer_size: int = 4096,
|
buffer_size: int = 4096,
|
||||||
hidden_sizes: Sequence[int] = (64, 64),
|
hidden_sizes: Sequence[int] = (64, 64),
|
||||||
@ -44,7 +44,7 @@ def main(
|
|||||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||||
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
||||||
|
|
||||||
sampling_config = RLSamplingConfig(
|
sampling_config = SamplingConfig(
|
||||||
num_epochs=epoch,
|
num_epochs=epoch,
|
||||||
step_per_epoch=step_per_epoch,
|
step_per_epoch=step_per_epoch,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
|||||||
@ -7,10 +7,10 @@ from collections.abc import Sequence
|
|||||||
from jsonargparse import CLI
|
from jsonargparse import CLI
|
||||||
|
|
||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
DDPGExperimentBuilder,
|
DDPGExperimentBuilder,
|
||||||
RLExperimentConfig,
|
ExperimentConfig,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.params.noise import MaxActionScaledGaussian
|
from tianshou.highlevel.params.noise import MaxActionScaledGaussian
|
||||||
from tianshou.highlevel.params.policy_params import DDPGParams
|
from tianshou.highlevel.params.policy_params import DDPGParams
|
||||||
@ -18,7 +18,7 @@ from tianshou.utils import logging
|
|||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
experiment_config: RLExperimentConfig,
|
experiment_config: ExperimentConfig,
|
||||||
task: str = "Ant-v3",
|
task: str = "Ant-v3",
|
||||||
buffer_size: int = 1000000,
|
buffer_size: int = 1000000,
|
||||||
hidden_sizes: Sequence[int] = (256, 256),
|
hidden_sizes: Sequence[int] = (256, 256),
|
||||||
@ -40,7 +40,7 @@ def main(
|
|||||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||||
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
||||||
|
|
||||||
sampling_config = RLSamplingConfig(
|
sampling_config = SamplingConfig(
|
||||||
num_epochs=epoch,
|
num_epochs=epoch,
|
||||||
step_per_epoch=step_per_epoch,
|
step_per_epoch=step_per_epoch,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import warnings
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
|
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -41,7 +41,7 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in
|
|||||||
|
|
||||||
|
|
||||||
class MujocoEnvFactory(EnvFactory):
|
class MujocoEnvFactory(EnvFactory):
|
||||||
def __init__(self, task: str, seed: int, sampling_config: RLSamplingConfig):
|
def __init__(self, task: str, seed: int, sampling_config: SamplingConfig):
|
||||||
self.task = task
|
self.task = task
|
||||||
self.sampling_config = sampling_config
|
self.sampling_config = sampling_config
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
|
|||||||
@ -9,10 +9,10 @@ from jsonargparse import CLI
|
|||||||
from torch.distributions import Independent, Normal
|
from torch.distributions import Independent, Normal
|
||||||
|
|
||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
PPOExperimentBuilder,
|
PPOExperimentBuilder,
|
||||||
RLExperimentConfig,
|
ExperimentConfig,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
||||||
from tianshou.highlevel.params.policy_params import PPOParams
|
from tianshou.highlevel.params.policy_params import PPOParams
|
||||||
@ -20,7 +20,7 @@ from tianshou.utils import logging
|
|||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
experiment_config: RLExperimentConfig,
|
experiment_config: ExperimentConfig,
|
||||||
task: str = "Ant-v4",
|
task: str = "Ant-v4",
|
||||||
buffer_size: int = 4096,
|
buffer_size: int = 4096,
|
||||||
hidden_sizes: Sequence[int] = (64, 64),
|
hidden_sizes: Sequence[int] = (64, 64),
|
||||||
@ -49,7 +49,7 @@ def main(
|
|||||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||||
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
||||||
|
|
||||||
sampling_config = RLSamplingConfig(
|
sampling_config = SamplingConfig(
|
||||||
num_epochs=epoch,
|
num_epochs=epoch,
|
||||||
step_per_epoch=step_per_epoch,
|
step_per_epoch=step_per_epoch,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
|||||||
@ -7,9 +7,9 @@ from collections.abc import Sequence
|
|||||||
from jsonargparse import CLI
|
from jsonargparse import CLI
|
||||||
|
|
||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
RLExperimentConfig,
|
ExperimentConfig,
|
||||||
SACExperimentBuilder,
|
SACExperimentBuilder,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
|
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
|
||||||
@ -18,7 +18,7 @@ from tianshou.utils import logging
|
|||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
experiment_config: RLExperimentConfig,
|
experiment_config: ExperimentConfig,
|
||||||
task: str = "Ant-v3",
|
task: str = "Ant-v3",
|
||||||
buffer_size: int = 1000000,
|
buffer_size: int = 1000000,
|
||||||
hidden_sizes: Sequence[int] = (256, 256),
|
hidden_sizes: Sequence[int] = (256, 256),
|
||||||
@ -42,7 +42,7 @@ def main(
|
|||||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||||
log_name = os.path.join(task, "sac", str(experiment_config.seed), now)
|
log_name = os.path.join(task, "sac", str(experiment_config.seed), now)
|
||||||
|
|
||||||
sampling_config = RLSamplingConfig(
|
sampling_config = SamplingConfig(
|
||||||
num_epochs=epoch,
|
num_epochs=epoch,
|
||||||
step_per_epoch=step_per_epoch,
|
step_per_epoch=step_per_epoch,
|
||||||
num_train_envs=training_num,
|
num_train_envs=training_num,
|
||||||
|
|||||||
@ -7,9 +7,9 @@ from collections.abc import Sequence
|
|||||||
from jsonargparse import CLI
|
from jsonargparse import CLI
|
||||||
|
|
||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
RLExperimentConfig,
|
ExperimentConfig,
|
||||||
TD3ExperimentBuilder,
|
TD3ExperimentBuilder,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.params.env_param import MaxActionScaled
|
from tianshou.highlevel.params.env_param import MaxActionScaled
|
||||||
@ -21,7 +21,7 @@ from tianshou.utils import logging
|
|||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
experiment_config: RLExperimentConfig,
|
experiment_config: ExperimentConfig,
|
||||||
task: str = "Ant-v3",
|
task: str = "Ant-v3",
|
||||||
buffer_size: int = 1000000,
|
buffer_size: int = 1000000,
|
||||||
hidden_sizes: Sequence[int] = (256, 256),
|
hidden_sizes: Sequence[int] = (256, 256),
|
||||||
@ -46,7 +46,7 @@ def main(
|
|||||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||||
log_name = os.path.join(task, "td3", str(experiment_config.seed), now)
|
log_name = os.path.join(task, "td3", str(experiment_config.seed), now)
|
||||||
|
|
||||||
sampling_config = RLSamplingConfig(
|
sampling_config = SamplingConfig(
|
||||||
num_epochs=epoch,
|
num_epochs=epoch,
|
||||||
step_per_epoch=step_per_epoch,
|
step_per_epoch=step_per_epoch,
|
||||||
num_train_envs=training_num,
|
num_train_envs=training_num,
|
||||||
|
|||||||
@ -2,12 +2,12 @@ from test.highlevel.env_factory import ContinuousTestEnvFactory
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
A2CExperimentBuilder,
|
A2CExperimentBuilder,
|
||||||
DDPGExperimentBuilder,
|
DDPGExperimentBuilder,
|
||||||
PPOExperimentBuilder,
|
PPOExperimentBuilder,
|
||||||
RLExperimentConfig,
|
ExperimentConfig,
|
||||||
SACExperimentBuilder,
|
SACExperimentBuilder,
|
||||||
TD3ExperimentBuilder,
|
TD3ExperimentBuilder,
|
||||||
)
|
)
|
||||||
@ -25,8 +25,8 @@ from tianshou.highlevel.experiment import (
|
|||||||
)
|
)
|
||||||
def test_experiment_builder_continuous_default_params(builder_cls):
|
def test_experiment_builder_continuous_default_params(builder_cls):
|
||||||
env_factory = ContinuousTestEnvFactory()
|
env_factory = ContinuousTestEnvFactory()
|
||||||
sampling_config = RLSamplingConfig(num_epochs=1, step_per_epoch=100)
|
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
|
||||||
experiment_config = RLExperimentConfig()
|
experiment_config = ExperimentConfig()
|
||||||
builder = builder_cls(
|
builder = builder_cls(
|
||||||
experiment_config=experiment_config,
|
experiment_config=experiment_config,
|
||||||
env_factory=env_factory,
|
env_factory=env_factory,
|
||||||
|
|||||||
@ -2,12 +2,12 @@ from test.highlevel.env_factory import DiscreteTestEnvFactory
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
A2CExperimentBuilder,
|
A2CExperimentBuilder,
|
||||||
DQNExperimentBuilder,
|
DQNExperimentBuilder,
|
||||||
PPOExperimentBuilder,
|
PPOExperimentBuilder,
|
||||||
RLExperimentConfig,
|
ExperimentConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -17,9 +17,9 @@ from tianshou.highlevel.experiment import (
|
|||||||
)
|
)
|
||||||
def test_experiment_builder_discrete_default_params(builder_cls):
|
def test_experiment_builder_discrete_default_params(builder_cls):
|
||||||
env_factory = DiscreteTestEnvFactory()
|
env_factory = DiscreteTestEnvFactory()
|
||||||
sampling_config = RLSamplingConfig(num_epochs=1, step_per_epoch=100)
|
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=100)
|
||||||
builder = builder_cls(
|
builder = builder_cls(
|
||||||
experiment_config=RLExperimentConfig(),
|
experiment_config=ExperimentConfig(),
|
||||||
env_factory=env_factory,
|
env_factory=env_factory,
|
||||||
sampling_config=sampling_config,
|
sampling_config=sampling_config,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import gymnasium
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.env import Environments
|
from tianshou.highlevel.env import Environments
|
||||||
from tianshou.highlevel.logger import Logger
|
from tianshou.highlevel.logger import Logger
|
||||||
from tianshou.highlevel.module.actor import (
|
from tianshou.highlevel.module.actor import (
|
||||||
@ -54,7 +54,7 @@ TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
|||||||
|
|
||||||
|
|
||||||
class AgentFactory(ABC, ToStringMixin):
|
class AgentFactory(ABC, ToStringMixin):
|
||||||
def __init__(self, sampling_config: RLSamplingConfig, optim_factory: OptimizerFactory):
|
def __init__(self, sampling_config: SamplingConfig, optim_factory: OptimizerFactory):
|
||||||
self.sampling_config = sampling_config
|
self.sampling_config = sampling_config
|
||||||
self.optim_factory = optim_factory
|
self.optim_factory = optim_factory
|
||||||
self.policy_wrapper_factory: PolicyWrapperFactory | None = None
|
self.policy_wrapper_factory: PolicyWrapperFactory | None = None
|
||||||
@ -352,7 +352,7 @@ class ActorCriticAgentFactory(
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: TParams,
|
params: TParams,
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
actor_factory: ActorFactory,
|
actor_factory: ActorFactory,
|
||||||
critic_factory: CriticFactory,
|
critic_factory: CriticFactory,
|
||||||
optimizer_factory: OptimizerFactory,
|
optimizer_factory: OptimizerFactory,
|
||||||
@ -399,7 +399,7 @@ class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: A2CParams,
|
params: A2CParams,
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
actor_factory: ActorFactory,
|
actor_factory: ActorFactory,
|
||||||
critic_factory: CriticFactory,
|
critic_factory: CriticFactory,
|
||||||
optimizer_factory: OptimizerFactory,
|
optimizer_factory: OptimizerFactory,
|
||||||
@ -423,7 +423,7 @@ class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: PPOParams,
|
params: PPOParams,
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
actor_factory: ActorFactory,
|
actor_factory: ActorFactory,
|
||||||
critic_factory: CriticFactory,
|
critic_factory: CriticFactory,
|
||||||
optimizer_factory: OptimizerFactory,
|
optimizer_factory: OptimizerFactory,
|
||||||
@ -447,7 +447,7 @@ class DQNAgentFactory(OffpolicyAgentFactory):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: DQNParams,
|
params: DQNParams,
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
actor_factory: ActorFactory,
|
actor_factory: ActorFactory,
|
||||||
optim_factory: OptimizerFactory,
|
optim_factory: OptimizerFactory,
|
||||||
):
|
):
|
||||||
@ -483,7 +483,7 @@ class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: DDPGParams,
|
params: DDPGParams,
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
actor_factory: ActorFactory,
|
actor_factory: ActorFactory,
|
||||||
critic_factory: CriticFactory,
|
critic_factory: CriticFactory,
|
||||||
optim_factory: OptimizerFactory,
|
optim_factory: OptimizerFactory,
|
||||||
@ -526,7 +526,7 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: SACParams,
|
params: SACParams,
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
actor_factory: ActorFactory,
|
actor_factory: ActorFactory,
|
||||||
critic1_factory: CriticFactory,
|
critic1_factory: CriticFactory,
|
||||||
critic2_factory: CriticFactory,
|
critic2_factory: CriticFactory,
|
||||||
@ -575,7 +575,7 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: TD3Params,
|
params: TD3Params,
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
actor_factory: ActorFactory,
|
actor_factory: ActorFactory,
|
||||||
critic1_factory: CriticFactory,
|
critic1_factory: CriticFactory,
|
||||||
critic2_factory: CriticFactory,
|
critic2_factory: CriticFactory,
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from dataclasses import dataclass
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RLSamplingConfig:
|
class SamplingConfig:
|
||||||
"""Sampling, epochs, parallelization, buffers, collectors, and batching."""
|
"""Sampling, epochs, parallelization, buffers, collectors, and batching."""
|
||||||
|
|
||||||
# TODO: What are reasonable defaults?
|
# TODO: What are reasonable defaults?
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from tianshou.highlevel.agent import (
|
|||||||
SACAgentFactory,
|
SACAgentFactory,
|
||||||
TD3AgentFactory,
|
TD3AgentFactory,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.env import EnvFactory, Environments
|
from tianshou.highlevel.env import EnvFactory, Environments
|
||||||
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
|
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
|
||||||
from tianshou.highlevel.module.actor import (
|
from tianshou.highlevel.module.actor import (
|
||||||
@ -53,7 +53,7 @@ TTrainer = TypeVar("TTrainer", bound=BaseTrainer)
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RLExperimentConfig:
|
class ExperimentConfig:
|
||||||
"""Generic config for setting up the experiment, not RL or training specific."""
|
"""Generic config for setting up the experiment, not RL or training specific."""
|
||||||
|
|
||||||
seed: int = 42
|
seed: int = 42
|
||||||
@ -69,10 +69,10 @@ class RLExperimentConfig:
|
|||||||
watch_num_episodes = 10
|
watch_num_episodes = 10
|
||||||
|
|
||||||
|
|
||||||
class RLExperiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: RLExperimentConfig,
|
config: ExperimentConfig,
|
||||||
env_factory: EnvFactory | Callable[[PersistableConfigProtocol | None], Environments],
|
env_factory: EnvFactory | Callable[[PersistableConfigProtocol | None], Environments],
|
||||||
agent_factory: AgentFactory,
|
agent_factory: AgentFactory,
|
||||||
logger_factory: LoggerFactory | None = None,
|
logger_factory: LoggerFactory | None = None,
|
||||||
@ -153,12 +153,12 @@ class RLExperiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
|||||||
TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder")
|
TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder")
|
||||||
|
|
||||||
|
|
||||||
class RLExperimentBuilder:
|
class ExperimentBuilder:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
experiment_config: RLExperimentConfig,
|
experiment_config: ExperimentConfig,
|
||||||
env_factory: EnvFactory,
|
env_factory: EnvFactory,
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
):
|
):
|
||||||
self._config = experiment_config
|
self._config = experiment_config
|
||||||
self._env_factory = env_factory
|
self._env_factory = env_factory
|
||||||
@ -223,12 +223,12 @@ class RLExperimentBuilder:
|
|||||||
else:
|
else:
|
||||||
return self._optim_factory
|
return self._optim_factory
|
||||||
|
|
||||||
def build(self) -> RLExperiment:
|
def build(self) -> Experiment:
|
||||||
agent_factory = self._create_agent_factory()
|
agent_factory = self._create_agent_factory()
|
||||||
agent_factory.set_trainer_callbacks(self._trainer_callbacks)
|
agent_factory.set_trainer_callbacks(self._trainer_callbacks)
|
||||||
if self._policy_wrapper_factory:
|
if self._policy_wrapper_factory:
|
||||||
agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
|
agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
|
||||||
experiment = RLExperiment(
|
experiment = Experiment(
|
||||||
self._config,
|
self._config,
|
||||||
self._env_factory,
|
self._env_factory,
|
||||||
agent_factory,
|
agent_factory,
|
||||||
@ -394,15 +394,15 @@ class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
|||||||
|
|
||||||
|
|
||||||
class A2CExperimentBuilder(
|
class A2CExperimentBuilder(
|
||||||
RLExperimentBuilder,
|
ExperimentBuilder,
|
||||||
_BuilderMixinActorFactory_ContinuousGaussian,
|
_BuilderMixinActorFactory_ContinuousGaussian,
|
||||||
_BuilderMixinSingleCriticCanUseActorFactory,
|
_BuilderMixinSingleCriticCanUseActorFactory,
|
||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
experiment_config: RLExperimentConfig,
|
experiment_config: ExperimentConfig,
|
||||||
env_factory: EnvFactory,
|
env_factory: EnvFactory,
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
env_config: PersistableConfigProtocol | None = None,
|
env_config: PersistableConfigProtocol | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(experiment_config, env_factory, sampling_config)
|
super().__init__(experiment_config, env_factory, sampling_config)
|
||||||
@ -428,15 +428,15 @@ class A2CExperimentBuilder(
|
|||||||
|
|
||||||
|
|
||||||
class PPOExperimentBuilder(
|
class PPOExperimentBuilder(
|
||||||
RLExperimentBuilder,
|
ExperimentBuilder,
|
||||||
_BuilderMixinActorFactory_ContinuousGaussian,
|
_BuilderMixinActorFactory_ContinuousGaussian,
|
||||||
_BuilderMixinSingleCriticCanUseActorFactory,
|
_BuilderMixinSingleCriticCanUseActorFactory,
|
||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
experiment_config: RLExperimentConfig,
|
experiment_config: ExperimentConfig,
|
||||||
env_factory: EnvFactory,
|
env_factory: EnvFactory,
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
):
|
):
|
||||||
super().__init__(experiment_config, env_factory, sampling_config)
|
super().__init__(experiment_config, env_factory, sampling_config)
|
||||||
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||||
@ -460,14 +460,14 @@ class PPOExperimentBuilder(
|
|||||||
|
|
||||||
|
|
||||||
class DQNExperimentBuilder(
|
class DQNExperimentBuilder(
|
||||||
RLExperimentBuilder,
|
ExperimentBuilder,
|
||||||
_BuilderMixinActorFactory,
|
_BuilderMixinActorFactory,
|
||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
experiment_config: RLExperimentConfig,
|
experiment_config: ExperimentConfig,
|
||||||
env_factory: EnvFactory,
|
env_factory: EnvFactory,
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
):
|
):
|
||||||
super().__init__(experiment_config, env_factory, sampling_config)
|
super().__init__(experiment_config, env_factory, sampling_config)
|
||||||
_BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
|
_BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
|
||||||
@ -488,15 +488,15 @@ class DQNExperimentBuilder(
|
|||||||
|
|
||||||
|
|
||||||
class DDPGExperimentBuilder(
|
class DDPGExperimentBuilder(
|
||||||
RLExperimentBuilder,
|
ExperimentBuilder,
|
||||||
_BuilderMixinActorFactory_ContinuousDeterministic,
|
_BuilderMixinActorFactory_ContinuousDeterministic,
|
||||||
_BuilderMixinSingleCriticCanUseActorFactory,
|
_BuilderMixinSingleCriticCanUseActorFactory,
|
||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
experiment_config: RLExperimentConfig,
|
experiment_config: ExperimentConfig,
|
||||||
env_factory: EnvFactory,
|
env_factory: EnvFactory,
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
):
|
):
|
||||||
super().__init__(experiment_config, env_factory, sampling_config)
|
super().__init__(experiment_config, env_factory, sampling_config)
|
||||||
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
|
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
|
||||||
@ -519,15 +519,15 @@ class DDPGExperimentBuilder(
|
|||||||
|
|
||||||
|
|
||||||
class SACExperimentBuilder(
|
class SACExperimentBuilder(
|
||||||
RLExperimentBuilder,
|
ExperimentBuilder,
|
||||||
_BuilderMixinActorFactory_ContinuousGaussian,
|
_BuilderMixinActorFactory_ContinuousGaussian,
|
||||||
_BuilderMixinDualCriticFactory,
|
_BuilderMixinDualCriticFactory,
|
||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
experiment_config: RLExperimentConfig,
|
experiment_config: ExperimentConfig,
|
||||||
env_factory: EnvFactory,
|
env_factory: EnvFactory,
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
):
|
):
|
||||||
super().__init__(experiment_config, env_factory, sampling_config)
|
super().__init__(experiment_config, env_factory, sampling_config)
|
||||||
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
||||||
@ -550,15 +550,15 @@ class SACExperimentBuilder(
|
|||||||
|
|
||||||
|
|
||||||
class TD3ExperimentBuilder(
|
class TD3ExperimentBuilder(
|
||||||
RLExperimentBuilder,
|
ExperimentBuilder,
|
||||||
_BuilderMixinActorFactory_ContinuousDeterministic,
|
_BuilderMixinActorFactory_ContinuousDeterministic,
|
||||||
_BuilderMixinDualCriticFactory,
|
_BuilderMixinDualCriticFactory,
|
||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
experiment_config: RLExperimentConfig,
|
experiment_config: ExperimentConfig,
|
||||||
env_factory: EnvFactory,
|
env_factory: EnvFactory,
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
):
|
):
|
||||||
super().__init__(experiment_config, env_factory, sampling_config)
|
super().__init__(experiment_config, env_factory, sampling_config)
|
||||||
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
|
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||||
|
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
|
|
||||||
@ -15,7 +15,7 @@ class LRSchedulerFactory(ToStringMixin, ABC):
|
|||||||
|
|
||||||
|
|
||||||
class LRSchedulerFactoryLinear(LRSchedulerFactory):
|
class LRSchedulerFactoryLinear(LRSchedulerFactory):
|
||||||
def __init__(self, sampling_config: RLSamplingConfig):
|
def __init__(self, sampling_config: SamplingConfig):
|
||||||
self.sampling_config = sampling_config
|
self.sampling_config = sampling_config
|
||||||
|
|
||||||
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
|
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user