Refactoring, dropping package config

Dominik Jain 2023-09-20 13:15:06 +02:00
parent 316eb3c579
commit 997b520580
12 changed files with 127 additions and 148 deletions

View File

@@ -2,9 +2,9 @@ import warnings
import gymnasium as gym
from tianshou.config import BasicExperimentConfig, RLSamplingConfig
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
from tianshou.highlevel.experiment import RLSamplingConfig
try:
import envpool
@@ -41,14 +41,15 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in
class MujocoEnvFactory(EnvFactory):
def __init__(self, experiment_config: BasicExperimentConfig, sampling_config: RLSamplingConfig):
def __init__(self, task: str, seed: int, sampling_config: RLSamplingConfig):
self.task = task
self.sampling_config = sampling_config
self.experiment_config = experiment_config
self.seed = seed
def create_envs(self) -> ContinuousEnvironments:
env, train_envs, test_envs = make_mujoco_env(
task=self.experiment_config.task,
seed=self.experiment_config.seed,
task=self.task,
seed=self.seed,
num_train_envs=self.sampling_config.num_train_envs,
num_test_envs=self.sampling_config.num_test_envs,
obs_norm=True,
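(Illustrative sketch, not part of this commit: with the new constructor above, call sites pass the task name and seed directly instead of a BasicExperimentConfig. Names and signatures are the ones visible in this diff.)

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.experiment import RLSamplingConfig

# New signature: explicit task and seed, plus the sampling config.
sampling_config = RLSamplingConfig(num_train_envs=8, num_test_envs=2)
env_factory = MujocoEnvFactory(task="Ant-v4", seed=42, sampling_config=sampling_config)
envs = env_factory.create_envs()  # ContinuousEnvironments bundling the train/test envs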

View File

@@ -9,17 +9,13 @@ from jsonargparse import CLI
from torch.distributions import Independent, Normal
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.config import (
BasicExperimentConfig,
LoggerConfig,
PGConfig,
PPOConfig,
RLAgentConfig,
from tianshou.highlevel.agent import PGConfig, PPOAgentFactory, PPOConfig, RLAgentConfig
from tianshou.highlevel.experiment import (
RLExperiment,
RLExperimentConfig,
RLSamplingConfig,
)
from tianshou.highlevel.agent import PPOAgentFactory
from tianshou.highlevel.experiment import RLExperiment
from tianshou.highlevel.logger import DefaultLoggerFactory
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerConfig
from tianshou.highlevel.module import (
ContinuousActorProbFactory,
ContinuousNetCriticFactory,
@@ -35,19 +31,20 @@ class NNConfig:
def main(
experiment_config: BasicExperimentConfig,
experiment_config: RLExperimentConfig,
logger_config: LoggerConfig,
sampling_config: RLSamplingConfig,
general_config: RLAgentConfig,
pg_config: PGConfig,
ppo_config: PPOConfig,
nn_config: NNConfig,
task: str = "Ant-v4",
):
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(experiment_config.task, "ppo", str(experiment_config.seed), now)
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
logger_factory = DefaultLoggerFactory(logger_config)
env_factory = MujocoEnvFactory(experiment_config, sampling_config)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
def dist_fn(*logits):
return Independent(Normal(*logits), 1)
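(Aside, not part of this commit: the dist_fn above builds a diagonal Gaussian over the action dimensions. A small standalone check of that behavior, using torch only:)

import torch
from torch.distributions import Independent, Normal

def dist_fn(*logits):
    # logits are expected to be (mean, sigma), each of shape (batch, action_dim)
    return Independent(Normal(*logits), 1)

mean, sigma = torch.zeros(4, 6), torch.ones(4, 6)
dist = dist_fn(mean, sigma)
print(dist.batch_shape, dist.event_shape)      # torch.Size([4]) torch.Size([6])
print(dist.log_prob(torch.zeros(4, 6)).shape)  # torch.Size([4]): one log-prob per batch element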

View File

@@ -7,14 +7,13 @@ from collections.abc import Sequence
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.config import (
BasicExperimentConfig,
LoggerConfig,
from tianshou.highlevel.agent import SACAgentFactory, SACConfig
from tianshou.highlevel.experiment import (
RLExperiment,
RLExperimentConfig,
RLSamplingConfig,
)
from tianshou.highlevel.agent import SACAgentFactory
from tianshou.highlevel.experiment import RLExperiment
from tianshou.highlevel.logger import DefaultLoggerFactory
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerConfig
from tianshou.highlevel.module import (
ContinuousActorProbFactory,
ContinuousNetCriticFactory,
@@ -23,17 +22,18 @@ from tianshou.highlevel.optim import AdamOptimizerFactory
def main(
experiment_config: BasicExperimentConfig,
experiment_config: RLExperimentConfig,
logger_config: LoggerConfig,
sampling_config: RLSamplingConfig,
sac_config: SACAgentFactory.Config,
sac_config: SACConfig,
hidden_sizes: Sequence[int] = (256, 256),
task: str = "Ant-v4",
):
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(experiment_config.task, "sac", str(experiment_config.seed), now)
log_name = os.path.join(task, "sac", str(experiment_config.seed), now)
logger_factory = DefaultLoggerFactory(logger_config)
env_factory = MujocoEnvFactory(experiment_config, sampling_config)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
actor_factory = ContinuousActorProbFactory(hidden_sizes, conditioned_sigma=True)
critic_factory = ContinuousNetCriticFactory(hidden_sizes)

View File

@@ -1,10 +0,0 @@
__all__ = ["PGConfig", "PPOConfig", "RLAgentConfig", "RLSamplingConfig", "BasicExperimentConfig", "LoggerConfig"]
from .config import (
BasicExperimentConfig,
PGConfig,
PPOConfig,
RLAgentConfig,
RLSamplingConfig,
LoggerConfig,
)

View File

@@ -1,86 +0,0 @@
from dataclasses import dataclass
from typing import Literal
import torch
from jsonargparse import set_docstring_parse_options
set_docstring_parse_options(attribute_docstrings=True)
@dataclass
class BasicExperimentConfig:
"""Generic config for setting up the experiment, not RL or training specific."""
seed: int = 42
task: str = "Ant-v4"
"""Mujoco specific"""
render: float | None = 0.0
"""Milliseconds between rendered frames; if None, no rendering"""
device: str = "cuda" if torch.cuda.is_available() else "cpu"
resume_id: str | None = None
"""For restoring a model and running means of env-specifics from a checkpoint"""
resume_path: str | None = None
"""For restoring a model and running means of env-specifics from a checkpoint"""
watch: bool = False
"""If True, will not perform training and only watch the restored policy"""
watch_num_episodes = 10
@dataclass
class LoggerConfig:
"""Logging config."""
logdir: str = "log"
logger: Literal["tensorboard", "wandb"] = "tensorboard"
wandb_project: str = "mujoco.benchmark"
"""Only used if logger is wandb."""
@dataclass
class RLSamplingConfig:
"""Sampling, epochs, parallelization, buffers, collectors, and batching."""
num_epochs: int = 100
step_per_epoch: int = 30000
batch_size: int = 64
num_train_envs: int = 64
num_test_envs: int = 10
buffer_size: int = 4096
step_per_collect: int = 2048
repeat_per_collect: int = 10
update_per_step: int = 1
@dataclass
class RLAgentConfig:
"""Config common to most RL algorithms."""
gamma: float = 0.99
"""Discount factor"""
gae_lambda: float = 0.95
"""For Generalized Advantage Estimate (equivalent to TD(lambda))"""
action_bound_method: Literal["clip", "tanh"] | None = "clip"
"""How to map original actions in range (-inf, inf) to [-1, 1]"""
rew_norm: bool = True
"""Whether to normalize rewards"""
@dataclass
class PGConfig:
"""Config of general policy-gradient algorithms."""
ent_coef: float = 0.0
vf_coef: float = 0.25
max_grad_norm: float = 0.5
@dataclass
class PPOConfig:
"""PPO specific config."""
value_clip: bool = False
norm_adv: bool = False
"""Whether to normalize advantages"""
eps_clip: float = 0.2
dual_clip: float | None = None
recompute_adv: bool = True

View File

@@ -0,0 +1,3 @@
from jsonargparse import set_docstring_parse_options
set_docstring_parse_options(attribute_docstrings=True)
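(Illustrative sketch, not in this commit: enabling attribute_docstrings at package import time means the per-field docstrings of the config dataclasses become --help text when jsonargparse builds the CLI for the example scripts. A minimal self-contained demonstration; DemoConfig and run are hypothetical:)

from dataclasses import dataclass

from jsonargparse import CLI, set_docstring_parse_options  # needs jsonargparse[signatures] for docstring parsing

set_docstring_parse_options(attribute_docstrings=True)

@dataclass
class DemoConfig:
    seed: int = 42
    """Random seed; shown in --help because attribute docstrings are parsed."""

def run(config: DemoConfig) -> None:
    print(config.seed)

if __name__ == "__main__":
    CLI(run)  # e.g. `python demo.py --config.seed 7`; `--help` lists the docstring above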

View File

@@ -2,13 +2,14 @@ import os
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import Literal
import torch
from tianshou.config import PGConfig, PPOConfig, RLAgentConfig, RLSamplingConfig
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.exploration import BaseNoise
from tianshou.highlevel.env import Environments
from tianshou.highlevel.experiment import RLSamplingConfig
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory
@@ -124,6 +125,41 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
)
@dataclass
class RLAgentConfig:
"""Config common to most RL algorithms."""
gamma: float = 0.99
"""Discount factor"""
gae_lambda: float = 0.95
"""For Generalized Advantage Estimate (equivalent to TD(lambda))"""
action_bound_method: Literal["clip", "tanh"] | None = "clip"
"""How to map original actions in range (-inf, inf) to [-1, 1]"""
rew_norm: bool = True
"""Whether to normalize rewards"""
@dataclass
class PGConfig:
"""Config of general policy-gradient algorithms."""
ent_coef: float = 0.0
vf_coef: float = 0.25
max_grad_norm: float = 0.5
@dataclass
class PPOConfig:
"""PPO specific config."""
value_clip: bool = False
norm_adv: bool = False
"""Whether to normalize advantages"""
eps_clip: float = 0.2
dual_clip: float | None = None
recompute_adv: bool = True
class PPOAgentFactory(OnpolicyAgentFactory):
def __init__(
self,
@@ -186,10 +222,22 @@ class PPOAgentFactory(OnpolicyAgentFactory):
)
class SACConfig:
tau: float = 0.005
gamma: float = 0.99
alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2
reward_normalization: bool = False
estimation_step: int = 1
deterministic_eval: bool = True
actor_lr: float = 1e-3
critic1_lr: float = 1e-3
critic2_lr: float = 1e-3
class SACAgentFactory(OffpolicyAgentFactory):
def __init__(
self,
config: "SACAgentFactory.Config",
config: SACConfig,
sampling_config: RLSamplingConfig,
actor_factory: ActorFactory,
critic1_factory: CriticFactory,
@@ -227,17 +275,3 @@ class SACAgentFactory(OffpolicyAgentFactory):
deterministic_eval=self.config.deterministic_eval,
exploration_noise=self.exploration_noise,
)
@dataclass
class Config:
"""SAC configuration."""
tau: float = 0.005
gamma: float = 0.99
alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2
reward_normalization: bool = False
estimation_step: int = 1
deterministic_eval: bool = True
actor_lr: float = 1e-3
critic1_lr: float = 1e-3
critic2_lr: float = 1e-3
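(Hedged sketch, not part of this commit: the nested SACAgentFactory.Config is promoted to the module-level SACConfig above, so call sites import and construct it directly. Field names are the ones listed in this diff; SACConfig is assumed to be a dataclass like the class it replaces.)

from tianshou.highlevel.agent import SACConfig

# before: sac_config = SACAgentFactory.Config(actor_lr=3e-4, ...)
sac_config = SACConfig(actor_lr=3e-4, critic1_lr=3e-4, critic2_lr=3e-4, tau=0.005)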

View File

@@ -1,12 +1,10 @@
from dataclasses import dataclass
from pprint import pprint
from typing import Generic, TypeVar
import numpy as np
import torch
from tianshou.config import (
BasicExperimentConfig,
)
from tianshou.data import Collector
from tianshou.highlevel.agent import AgentFactory
from tianshou.highlevel.env import EnvFactory
@@ -18,10 +16,42 @@ TPolicy = TypeVar("TPolicy", bound=BasePolicy)
TTrainer = TypeVar("TTrainer", bound=BaseTrainer)
@dataclass
class RLExperimentConfig:
"""Generic config for setting up the experiment, not RL or training specific."""
seed: int = 42
render: float | None = 0.0
"""Milliseconds between rendered frames; if None, no rendering"""
device: str = "cuda" if torch.cuda.is_available() else "cpu"
resume_id: str | None = None
"""For restoring a model and running means of env-specifics from a checkpoint"""
resume_path: str | None = None
"""For restoring a model and running means of env-specifics from a checkpoint"""
watch: bool = False
"""If True, will not perform training and only watch the restored policy"""
watch_num_episodes = 10
@dataclass
class RLSamplingConfig:
"""Sampling, epochs, parallelization, buffers, collectors, and batching."""
num_epochs: int = 100
step_per_epoch: int = 30000
batch_size: int = 64
num_train_envs: int = 64
num_test_envs: int = 10
buffer_size: int = 4096
step_per_collect: int = 2048
repeat_per_collect: int = 10
update_per_step: int = 1
class RLExperiment(Generic[TPolicy, TTrainer]):
def __init__(
self,
config: BasicExperimentConfig,
config: RLExperimentConfig,
env_factory: EnvFactory,
logger_factory: LoggerFactory,
agent_factory: AgentFactory,
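(Hedged sketch, not in this commit: RLExperimentConfig and RLSamplingConfig now live next to RLExperiment and are the objects jsonargparse populates for the example scripts' main functions. Constructing them directly, using only fields shown above:)

from tianshou.highlevel.experiment import RLExperimentConfig, RLSamplingConfig

experiment_config = RLExperimentConfig(seed=7, render=None, watch=False)
sampling_config = RLSamplingConfig(num_epochs=50, batch_size=128, num_train_envs=16)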

View File

@@ -1,10 +1,10 @@
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Literal
from torch.utils.tensorboard import SummaryWriter
from tianshou.config import LoggerConfig
from tianshou.utils import TensorboardLogger, WandbLogger
TLogger = TensorboardLogger | WandbLogger
@@ -22,11 +22,21 @@ class LoggerFactory(ABC):
pass
@dataclass
class LoggerConfig:
"""Logging config."""
logdir: str = "log"
logger: Literal["tensorboard", "wandb"] = "tensorboard"
wandb_project: str = "mujoco.benchmark"
"""Only used if logger is wandb."""
class DefaultLoggerFactory(LoggerFactory):
def __init__(self, config: LoggerConfig):
self.config = config
def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
writer = SummaryWriter(self.config.logdir)
writer.add_text("args", str(self.config))
if self.config.logger == "wandb":
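(Hedged sketch, not part of this commit: LoggerConfig now lives in tianshou.highlevel.logger, and create_logger takes run_id as str | None. Intended usage per the signatures above; the log_name string follows the os.path.join pattern used in the example scripts.)

from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerConfig

logger_factory = DefaultLoggerFactory(LoggerConfig(logdir="log", logger="tensorboard"))
logger = logger_factory.create_logger(
    log_name="Ant-v4/ppo/42/230920-131506",  # task/algo/seed/timestamp, as built in the examples
    run_id=None,       # now typed str | None, e.g. a wandb run id when resuming
    config_dict={},
)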

View File

@@ -66,7 +66,7 @@ class ContinuousActorProbFactory(ContinuousActorFactory):
actor = ActorProb(
net_a,
envs.get_action_shape(),
unbounded=True,
unbounded=self.unbounded,
device=device,
conditioned_sigma=self.conditioned_sigma,
).to(device)

View File

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Any, Type
from typing import Any
import numpy as np
import torch
@@ -8,7 +8,7 @@ from torch import Tensor
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from tianshou.config import RLSamplingConfig
from tianshou.highlevel.experiment import RLSamplingConfig
TParams = Iterable[Tensor] | Iterable[dict[str, Any]]