from dataclasses import dataclass
from typing import Literal, Optional, Sequence

import torch
from jsonargparse import set_docstring_parse_options

set_docstring_parse_options(attribute_docstrings=True)
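# With attribute_docstrings enabled, jsonargparse reads the string literal
# placed directly below a dataclass field as that field's documentation, so
# the docstrings in the configs below double as CLI/config-file help text.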


@dataclass
class BasicExperimentConfig:
    """Generic config for setting up the experiment, not RL or training specific."""

    seed: int = 42
    task: str = "Ant-v4"
    """Mujoco-specific task name"""
    render: Optional[float] = 0.0
    """Milliseconds between rendered frames; if None, no rendering"""
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    resume_id: Optional[int] = None
    """Id of a run to restore a model and env-specific running means from a checkpoint"""
    resume_path: Optional[str] = None
    """Path for restoring a model and env-specific running means from a checkpoint"""
    watch: bool = False
    """If True, will not perform training and only watch the restored policy"""
    watch_num_episodes: int = 10
    """Number of episodes to watch when watch is True"""


@dataclass
class LoggerConfig:
    """Logging config"""

    logdir: str = "log"
    logger: Literal["tensorboard", "wandb"] = "tensorboard"
    wandb_project: str = "mujoco.benchmark"
    """Only used if logger is wandb."""


@dataclass
class RLSamplingConfig:
    """Sampling, epochs, parallelization, buffers, collectors, and batching."""

    num_epochs: int = 100
    step_per_epoch: int = 30000
    batch_size: int = 64
    num_train_envs: int = 64
    num_test_envs: int = 10
    buffer_size: int = 4096
    step_per_collect: int = 2048
    repeat_per_collect: int = 10
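# Note (added, not in the original file): in a typical on-policy training loop
# these values mean that each of the num_epochs epochs performs step_per_epoch
# environment steps; transitions are gathered step_per_collect at a time from
# num_train_envs parallel envs into a buffer of size buffer_size, and after
# each collect the policy is updated repeat_per_collect times on minibatches
# of batch_size.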


@dataclass
class RLAgentConfig:
    """Config common to most RL algorithms"""

    gamma: float = 0.99
    """Discount factor"""
    gae_lambda: float = 0.95
    """For Generalized Advantage Estimation (equivalent to TD(lambda))"""
    action_bound_method: Optional[Literal["clip", "tanh"]] = "clip"
    """How to map original actions in range (-inf, inf) to [-1, 1]"""
    rew_norm: bool = True
    """Whether to normalize rewards"""


@dataclass
class PGConfig:
    """Config of general policy-gradient algorithms"""

    ent_coef: float = 0.0
    """Weight of the entropy bonus in the loss"""
    vf_coef: float = 0.25
    """Weight of the value-function loss"""
    max_grad_norm: float = 0.5
    """Threshold for gradient-norm clipping"""


@dataclass
class PPOConfig:
    """PPO specific config"""

    value_clip: bool = False
    """Whether to clip the value-function update analogously to the policy ratio"""
    norm_adv: bool = False
    """Whether to normalize advantages"""
    eps_clip: float = 0.2
    """Clipping parameter of the PPO surrogate objective"""
    dual_clip: Optional[float] = None
    """If set, must be > 1; lower-bounds the surrogate objective for negative advantages"""
    recompute_adv: bool = True
    """Whether to recompute advantages after each policy update"""
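# For reference (added, not in the original file): eps_clip enters the clipped
# PPO objective as
#   L_clip = E[ min(r_t * A_t, clip(r_t, 1 - eps_clip, 1 + eps_clip) * A_t) ],
# where r_t is the probability ratio pi_new(a_t|s_t) / pi_old(a_t|s_t); when
# dual_clip is set, the objective for A_t < 0 is additionally lower-bounded
# by dual_clip * A_t.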


@dataclass
class NNConfig:
    """Config of the actor and critic networks"""

    hidden_sizes: Sequence[int] = (64, 64)
    """Sizes of the hidden layers"""
    lr: float = 3e-4
    """Learning rate"""
    lr_decay: bool = True
    """Whether to decay the learning rate during training"""
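

# Illustrative usage sketch (an assumption, not part of the original file):
# these dataclasses are designed for jsonargparse, and with
# attribute_docstrings enabled above, each field's docstring becomes the help
# text of the corresponding CLI option. The nested keys ("experiment",
# "logger", ...) are hypothetical names chosen for this example.
if __name__ == "__main__":
    from jsonargparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_dataclass_arguments(BasicExperimentConfig, "experiment")
    parser.add_dataclass_arguments(LoggerConfig, "logger")
    parser.add_dataclass_arguments(RLSamplingConfig, "sampling")
    parser.add_dataclass_arguments(RLAgentConfig, "agent")
    parser.add_dataclass_arguments(PGConfig, "pg")
    parser.add_dataclass_arguments(PPOConfig, "ppo")
    parser.add_dataclass_arguments(NNConfig, "nn")
    args = parser.parse_args()

    # Nested options are accessible as e.g. args.experiment.seed, and
    # parser.instantiate_classes(args) turns the namespaces back into
    # dataclass instances.
    print(args)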
|