Refactoring, dropping the config package

Dominik Jain 2023-09-20 13:15:06 +02:00
parent 316eb3c579
commit 997b520580
12 changed files with 127 additions and 148 deletions

@@ -2,9 +2,9 @@ import warnings
 import gymnasium as gym

-from tianshou.config import BasicExperimentConfig, RLSamplingConfig
 from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
 from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
+from tianshou.highlevel.experiment import RLSamplingConfig

 try:
     import envpool
@@ -41,14 +41,15 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int
 class MujocoEnvFactory(EnvFactory):
-    def __init__(self, experiment_config: BasicExperimentConfig, sampling_config: RLSamplingConfig):
+    def __init__(self, task: str, seed: int, sampling_config: RLSamplingConfig):
+        self.task = task
         self.sampling_config = sampling_config
-        self.experiment_config = experiment_config
+        self.seed = seed

     def create_envs(self) -> ContinuousEnvironments:
         env, train_envs, test_envs = make_mujoco_env(
-            task=self.experiment_config.task,
-            seed=self.experiment_config.seed,
+            task=self.task,
+            seed=self.seed,
             num_train_envs=self.sampling_config.num_train_envs,
             num_test_envs=self.sampling_config.num_test_envs,
             obs_norm=True,

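The factory above no longer receives the full experiment config; the task name and seed are passed explicitly. A minimal usage sketch under the layout introduced by this commit (MuJoCo dependencies assumed to be installed; the concrete values are illustrative):

# Sketch: constructing the refactored factory; argument names follow the diff above.
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.experiment import RLSamplingConfig

sampling_config = RLSamplingConfig(num_train_envs=4, num_test_envs=2)
env_factory = MujocoEnvFactory(task="Ant-v4", seed=42, sampling_config=sampling_config)
envs = env_factory.create_envs()  # returns a ContinuousEnvironments instance per the signature above
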
@@ -9,17 +9,13 @@ from jsonargparse import CLI
 from torch.distributions import Independent, Normal

 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.config import (
-    BasicExperimentConfig,
-    LoggerConfig,
-    PGConfig,
-    PPOConfig,
-    RLAgentConfig,
+from tianshou.highlevel.agent import PGConfig, PPOAgentFactory, PPOConfig, RLAgentConfig
+from tianshou.highlevel.experiment import (
+    RLExperiment,
+    RLExperimentConfig,
     RLSamplingConfig,
 )
-from tianshou.highlevel.agent import PPOAgentFactory
-from tianshou.highlevel.experiment import RLExperiment
-from tianshou.highlevel.logger import DefaultLoggerFactory
+from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerConfig
 from tianshou.highlevel.module import (
     ContinuousActorProbFactory,
     ContinuousNetCriticFactory,
@@ -35,19 +31,20 @@ class NNConfig:
 def main(
-    experiment_config: BasicExperimentConfig,
+    experiment_config: RLExperimentConfig,
     logger_config: LoggerConfig,
     sampling_config: RLSamplingConfig,
     general_config: RLAgentConfig,
     pg_config: PGConfig,
     ppo_config: PPOConfig,
     nn_config: NNConfig,
+    task: str = "Ant-v4",
 ):
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
-    log_name = os.path.join(experiment_config.task, "ppo", str(experiment_config.seed), now)
+    log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)

     logger_factory = DefaultLoggerFactory(logger_config)
-    env_factory = MujocoEnvFactory(experiment_config, sampling_config)
+    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)

     def dist_fn(*logits):
         return Independent(Normal(*logits), 1)

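Since task is now an ordinary parameter of main() instead of a field on the experiment config, jsonargparse exposes it directly as --task, while the dataclass parameters keep their nested dotted options. A self-contained sketch of that CLI pattern; the dataclass below is a stand-in, not one of the tianshou configs:

# Stand-in example of the jsonargparse pattern used by the script above.
from dataclasses import dataclass

from jsonargparse import CLI


@dataclass
class ExperimentConfig:
    seed: int = 42


def main(experiment_config: ExperimentConfig, task: str = "Ant-v4") -> None:
    print(task, experiment_config.seed)


if __name__ == "__main__":
    CLI(main)  # e.g. python demo.py --task HalfCheetah-v4 --experiment_config.seed 7
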
@@ -7,14 +7,13 @@ from collections.abc import Sequence
 from jsonargparse import CLI

 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.config import (
-    BasicExperimentConfig,
-    LoggerConfig,
+from tianshou.highlevel.agent import SACAgentFactory, SACConfig
+from tianshou.highlevel.experiment import (
+    RLExperiment,
+    RLExperimentConfig,
     RLSamplingConfig,
 )
-from tianshou.highlevel.agent import SACAgentFactory
-from tianshou.highlevel.experiment import RLExperiment
-from tianshou.highlevel.logger import DefaultLoggerFactory
+from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerConfig
 from tianshou.highlevel.module import (
     ContinuousActorProbFactory,
     ContinuousNetCriticFactory,
@@ -23,17 +22,18 @@ from tianshou.highlevel.optim import AdamOptimizerFactory
 def main(
-    experiment_config: BasicExperimentConfig,
+    experiment_config: RLExperimentConfig,
     logger_config: LoggerConfig,
     sampling_config: RLSamplingConfig,
-    sac_config: SACAgentFactory.Config,
+    sac_config: SACConfig,
     hidden_sizes: Sequence[int] = (256, 256),
+    task: str = "Ant-v4",
 ):
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
-    log_name = os.path.join(experiment_config.task, "sac", str(experiment_config.seed), now)
+    log_name = os.path.join(task, "sac", str(experiment_config.seed), now)

     logger_factory = DefaultLoggerFactory(logger_config)
-    env_factory = MujocoEnvFactory(experiment_config, sampling_config)
+    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)

     actor_factory = ContinuousActorProbFactory(hidden_sizes, conditioned_sigma=True)
     critic_factory = ContinuousNetCriticFactory(hidden_sizes)

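The SAC hyperparameters previously lived in the nested SACAgentFactory.Config; the script now annotates sac_config with the module-level SACConfig from tianshou.highlevel.agent (see the hunks below). A small sketch of overriding its defaults; attribute assignment is used because the diff does not show SACConfig as a dataclass:

# Sketch: overriding SACConfig defaults shown in this commit.
from tianshou.highlevel.agent import SACConfig

sac_config = SACConfig()
sac_config.actor_lr = 3e-4      # default in the diff is 1e-3
sac_config.estimation_step = 3  # n-step target instead of the 1-step default
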
@@ -1,10 +0,0 @@
-__all__ = ["PGConfig", "PPOConfig", "RLAgentConfig", "RLSamplingConfig", "BasicExperimentConfig", "LoggerConfig"]
-
-from .config import (
-    BasicExperimentConfig,
-    PGConfig,
-    PPOConfig,
-    RLAgentConfig,
-    RLSamplingConfig,
-    LoggerConfig,
-)

@@ -1,86 +0,0 @@
-from dataclasses import dataclass
-from typing import Literal
-
-import torch
-from jsonargparse import set_docstring_parse_options
-
-set_docstring_parse_options(attribute_docstrings=True)
-
-
-@dataclass
-class BasicExperimentConfig:
-    """Generic config for setting up the experiment, not RL or training specific."""
-
-    seed: int = 42
-    task: str = "Ant-v4"
-    """Mujoco specific"""
-    render: float | None = 0.0
-    """Milliseconds between rendered frames; if None, no rendering"""
-    device: str = "cuda" if torch.cuda.is_available() else "cpu"
-    resume_id: str | None = None
-    """For restoring a model and running means of env-specifics from a checkpoint"""
-    resume_path: str | None = None
-    """For restoring a model and running means of env-specifics from a checkpoint"""
-    watch: bool = False
-    """If True, will not perform training and only watch the restored policy"""
-    watch_num_episodes = 10
-
-
-@dataclass
-class LoggerConfig:
-    """Logging config."""
-
-    logdir: str = "log"
-    logger: Literal["tensorboard", "wandb"] = "tensorboard"
-    wandb_project: str = "mujoco.benchmark"
-    """Only used if logger is wandb."""
-
-
-@dataclass
-class RLSamplingConfig:
-    """Sampling, epochs, parallelization, buffers, collectors, and batching."""
-
-    num_epochs: int = 100
-    step_per_epoch: int = 30000
-    batch_size: int = 64
-    num_train_envs: int = 64
-    num_test_envs: int = 10
-    buffer_size: int = 4096
-    step_per_collect: int = 2048
-    repeat_per_collect: int = 10
-    update_per_step: int = 1
-
-
-@dataclass
-class RLAgentConfig:
-    """Config common to most RL algorithms."""
-
-    gamma: float = 0.99
-    """Discount factor"""
-    gae_lambda: float = 0.95
-    """For Generalized Advantage Estimate (equivalent to TD(lambda))"""
-    action_bound_method: Literal["clip", "tanh"] | None = "clip"
-    """How to map original actions in range (-inf, inf) to [-1, 1]"""
-    rew_norm: bool = True
-    """Whether to normalize rewards"""
-
-
-@dataclass
-class PGConfig:
-    """Config of general policy-gradient algorithms."""
-
-    ent_coef: float = 0.0
-    vf_coef: float = 0.25
-    max_grad_norm: float = 0.5
-
-
-@dataclass
-class PPOConfig:
-    """PPO specific config."""
-
-    value_clip: bool = False
-    norm_adv: bool = False
-    """Whether to normalize advantages"""
-    eps_clip: float = 0.2
-    dual_clip: float | None = None
-    recompute_adv: bool = True

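Nothing from the deleted package is lost; each class reappears in the high-level module that consumes it, and BasicExperimentConfig is renamed to RLExperimentConfig. The import migration implied by the hunks in this commit:

# Old imports (removed by this commit):
#   from tianshou.config import BasicExperimentConfig, LoggerConfig, PGConfig, PPOConfig, RLAgentConfig, RLSamplingConfig
# New locations:
from tianshou.highlevel.agent import PGConfig, PPOConfig, RLAgentConfig
from tianshou.highlevel.experiment import RLExperimentConfig, RLSamplingConfig  # RLExperimentConfig replaces BasicExperimentConfig
from tianshou.highlevel.logger import LoggerConfig
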
@@ -0,0 +1,3 @@
+from jsonargparse import set_docstring_parse_options
+
+set_docstring_parse_options(attribute_docstrings=True)

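This three-line module keeps the side effect that previously ran on importing tianshou.config: with attribute_docstrings enabled, jsonargparse can pick up the string literals written below dataclass fields and surface them as CLI help text. A stand-in illustration (the dataclass below is not part of tianshou):

# Stand-in illustration of attribute docstrings becoming CLI help text.
from dataclasses import dataclass

from jsonargparse import CLI, set_docstring_parse_options

set_docstring_parse_options(attribute_docstrings=True)


@dataclass
class AgentConfig:
    gamma: float = 0.99
    """Discount factor"""


def main(agent_config: AgentConfig) -> None:
    print(agent_config.gamma)


if __name__ == "__main__":
    CLI(main)  # --help should show "Discount factor" for --agent_config.gamma
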
@@ -2,13 +2,14 @@ import os
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from dataclasses import dataclass
+from typing import Literal

 import torch

-from tianshou.config import PGConfig, PPOConfig, RLAgentConfig, RLSamplingConfig
 from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
 from tianshou.exploration import BaseNoise
 from tianshou.highlevel.env import Environments
+from tianshou.highlevel.experiment import RLSamplingConfig
 from tianshou.highlevel.logger import Logger
 from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
 from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory
@@ -124,6 +125,41 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
         )


+@dataclass
+class RLAgentConfig:
+    """Config common to most RL algorithms."""
+
+    gamma: float = 0.99
+    """Discount factor"""
+    gae_lambda: float = 0.95
+    """For Generalized Advantage Estimate (equivalent to TD(lambda))"""
+    action_bound_method: Literal["clip", "tanh"] | None = "clip"
+    """How to map original actions in range (-inf, inf) to [-1, 1]"""
+    rew_norm: bool = True
+    """Whether to normalize rewards"""
+
+
+@dataclass
+class PGConfig:
+    """Config of general policy-gradient algorithms."""
+
+    ent_coef: float = 0.0
+    vf_coef: float = 0.25
+    max_grad_norm: float = 0.5
+
+
+@dataclass
+class PPOConfig:
+    """PPO specific config."""
+
+    value_clip: bool = False
+    norm_adv: bool = False
+    """Whether to normalize advantages"""
+    eps_clip: float = 0.2
+    dual_clip: float | None = None
+    recompute_adv: bool = True
+
+
 class PPOAgentFactory(OnpolicyAgentFactory):
     def __init__(
         self,
@@ -186,10 +222,22 @@ class PPOAgentFactory(OnpolicyAgentFactory):
         )


+class SACConfig:
+    tau: float = 0.005
+    gamma: float = 0.99
+    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2
+    reward_normalization: bool = False
+    estimation_step: int = 1
+    deterministic_eval: bool = True
+    actor_lr: float = 1e-3
+    critic1_lr: float = 1e-3
+    critic2_lr: float = 1e-3
+
+
 class SACAgentFactory(OffpolicyAgentFactory):
     def __init__(
         self,
-        config: "SACAgentFactory.Config",
+        config: SACConfig,
         sampling_config: RLSamplingConfig,
         actor_factory: ActorFactory,
         critic1_factory: CriticFactory,
@@ -227,17 +275,3 @@
             deterministic_eval=self.config.deterministic_eval,
             exploration_noise=self.exploration_noise,
         )
-
-    @dataclass
-    class Config:
-        """SAC configuration."""
-
-        tau: float = 0.005
-        gamma: float = 0.99
-        alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2
-        reward_normalization: bool = False
-        estimation_step: int = 1
-        deterministic_eval: bool = True
-        actor_lr: float = 1e-3
-        critic1_lr: float = 1e-3
-        critic2_lr: float = 1e-3

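With RLAgentConfig, PGConfig, and PPOConfig now defined next to the factories in tianshou.highlevel.agent, a script can build them directly; the values below are simply the dataclass defaults from the hunks above, written out explicitly:

# Sketch: constructing the relocated algorithm configs with the defaults shown in this diff.
from tianshou.highlevel.agent import PGConfig, PPOConfig, RLAgentConfig

general_config = RLAgentConfig(gamma=0.99, gae_lambda=0.95, action_bound_method="clip", rew_norm=True)
pg_config = PGConfig(ent_coef=0.0, vf_coef=0.25, max_grad_norm=0.5)
ppo_config = PPOConfig(eps_clip=0.2, recompute_adv=True)  # value_clip, norm_adv, dual_clip keep their defaults
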
@@ -1,12 +1,10 @@
+from dataclasses import dataclass
 from pprint import pprint
 from typing import Generic, TypeVar

 import numpy as np
 import torch

-from tianshou.config import (
-    BasicExperimentConfig,
-)
 from tianshou.data import Collector
 from tianshou.highlevel.agent import AgentFactory
 from tianshou.highlevel.env import EnvFactory
@@ -18,10 +16,42 @@ TPolicy = TypeVar("TPolicy", bound=BasePolicy)
 TTrainer = TypeVar("TTrainer", bound=BaseTrainer)


+@dataclass
+class RLExperimentConfig:
+    """Generic config for setting up the experiment, not RL or training specific."""
+
+    seed: int = 42
+    render: float | None = 0.0
+    """Milliseconds between rendered frames; if None, no rendering"""
+    device: str = "cuda" if torch.cuda.is_available() else "cpu"
+    resume_id: str | None = None
+    """For restoring a model and running means of env-specifics from a checkpoint"""
+    resume_path: str | None = None
+    """For restoring a model and running means of env-specifics from a checkpoint"""
+    watch: bool = False
+    """If True, will not perform training and only watch the restored policy"""
+    watch_num_episodes = 10
+
+
+@dataclass
+class RLSamplingConfig:
+    """Sampling, epochs, parallelization, buffers, collectors, and batching."""
+
+    num_epochs: int = 100
+    step_per_epoch: int = 30000
+    batch_size: int = 64
+    num_train_envs: int = 64
+    num_test_envs: int = 10
+    buffer_size: int = 4096
+    step_per_collect: int = 2048
+    repeat_per_collect: int = 10
+    update_per_step: int = 1
+
+
 class RLExperiment(Generic[TPolicy, TTrainer]):
     def __init__(
         self,
-        config: BasicExperimentConfig,
+        config: RLExperimentConfig,
         env_factory: EnvFactory,
         logger_factory: LoggerFactory,
         agent_factory: AgentFactory,

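RLExperimentConfig is the former BasicExperimentConfig minus the MuJoCo-specific task field, now colocated with RLSamplingConfig and the RLExperiment class that takes it. A brief sketch using the defaults from the hunk above:

# Sketch: the experiment-level configs after the move.
from tianshou.highlevel.experiment import RLExperimentConfig, RLSamplingConfig

experiment_config = RLExperimentConfig(seed=42, watch=False)  # note: no task field anymore
sampling_config = RLSamplingConfig(num_epochs=100, step_per_epoch=30000, batch_size=64)
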
@@ -1,10 +1,10 @@
 import os
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from typing import Literal

 from torch.utils.tensorboard import SummaryWriter

-from tianshou.config import LoggerConfig
 from tianshou.utils import TensorboardLogger, WandbLogger

 TLogger = TensorboardLogger | WandbLogger
@@ -22,11 +22,21 @@ class LoggerFactory(ABC):
         pass


+@dataclass
+class LoggerConfig:
+    """Logging config."""
+
+    logdir: str = "log"
+    logger: Literal["tensorboard", "wandb"] = "tensorboard"
+    wandb_project: str = "mujoco.benchmark"
+    """Only used if logger is wandb."""
+
+
 class DefaultLoggerFactory(LoggerFactory):
     def __init__(self, config: LoggerConfig):
         self.config = config

-    def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
+    def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
         writer = SummaryWriter(self.config.logdir)
         writer.add_text("args", str(self.config))
         if self.config.logger == "wandb":

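LoggerConfig now lives beside the factory that consumes it, and run_id in create_logger is retyped from int | None to str | None, matching string-valued run identifiers such as wandb's. A usage sketch based on the signatures in the hunks above (the log_name value is illustrative):

# Sketch: building a logger via the relocated LoggerConfig.
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerConfig

logger_config = LoggerConfig(logdir="log", logger="tensorboard")
logger_factory = DefaultLoggerFactory(logger_config)
logger = logger_factory.create_logger(log_name="Ant-v4/ppo/42/230920-131506", run_id=None, config_dict={})
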
@@ -66,7 +66,7 @@ class ContinuousActorProbFactory(ContinuousActorFactory):
         actor = ActorProb(
             net_a,
             envs.get_action_shape(),
-            unbounded=True,
+            unbounded=self.unbounded,
             device=device,
             conditioned_sigma=self.conditioned_sigma,
         ).to(device)

@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
-from typing import Any, Type
+from typing import Any

 import numpy as np
 import torch
@@ -8,7 +8,7 @@ from torch import Tensor
 from torch.optim import Adam
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler

-from tianshou.config import RLSamplingConfig
+from tianshou.highlevel.experiment import RLSamplingConfig

 TParams = Iterable[Tensor] | Iterable[dict[str, Any]]