Unify PPO configuration objects, use experiment-specific configuration in mujoco_ppo_hl
Author: Dominik Jain, 2023-09-20 15:45:09 +02:00
parent 8ec42009cb
commit 3fd60f9e70
4 changed files with 82 additions and 46 deletions
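
The gist of the change: PPOAgentFactory previously took three separate config objects (RLAgentConfig, PGConfig, PPOConfig); they are now arranged in a single inheritance chain, so one PPOConfig instance carries everything. Below is a minimal sketch of the resulting hierarchy as implied by the diff; the field-to-class assignment follows the "general_config" / "pg_config" / "ppo_config" comments in _create_policy, and the default values are taken from the example script's CLI defaults, so treat them as illustrative rather than authoritative.

from dataclasses import dataclass
from typing import Literal


@dataclass
class RLAgentConfig:
    """General RL agent settings (abbreviated sketch)."""

    gamma: float = 0.99
    gae_lambda: float = 0.95
    rew_norm: bool = True
    action_bound_method: Literal["clip", "tanh"] | None = "clip"


@dataclass
class PGConfig(RLAgentConfig):
    """Config of general policy-gradient algorithms."""

    ent_coef: float = 0.0
    vf_coef: float = 0.25
    max_grad_norm: float = 0.5


@dataclass
class PPOConfig(PGConfig):
    """PPO specific config."""

    eps_clip: float = 0.2
    dual_clip: float | None = None
    value_clip: bool = False
    norm_adv: bool = False
    recompute_adv: bool = True


# One object now replaces the former (general_config, pg_config, ppo_config) triple:
config = PPOConfig(gamma=0.99, max_grad_norm=0.5, eps_clip=0.2)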


@ -3,8 +3,8 @@ import warnings
import gymnasium as gym
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
try:
import envpool


@ -3,18 +3,18 @@
import datetime
import os
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal
from jsonargparse import CLI
from torch.distributions import Independent, Normal
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.agent import PGConfig, PPOAgentFactory, PPOConfig, RLAgentConfig
from tianshou.highlevel.agent import PPOAgentFactory, PPOConfig
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
RLExperiment,
RLExperimentConfig,
)
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.logger import DefaultLoggerFactory
from tianshou.highlevel.module import (
ContinuousActorProbFactory,
@ -23,45 +23,79 @@ from tianshou.highlevel.module import (
from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory
@dataclass
class NNConfig:
hidden_sizes: Sequence[int] = (64, 64)
lr: float = 3e-4
lr_decay: bool = True
def main(
experiment_config: RLExperimentConfig,
sampling_config: RLSamplingConfig,
general_config: RLAgentConfig,
pg_config: PGConfig,
ppo_config: PPOConfig,
nn_config: NNConfig,
task: str = "Ant-v4",
buffer_size: int = 4096,
hidden_sizes: Sequence[int] = (64, 64),
lr: float = 3e-4,
gamma: float = 0.99,
epoch: int = 100,
step_per_epoch: int = 30000,
step_per_collect: int = 2048,
repeat_per_collect: int = 10,
batch_size: int = 64,
training_num: int = 64,
test_num: int = 10,
rew_norm: bool = True,
vf_coef: float = 0.25,
ent_coef: float = 0.0,
gae_lambda: float = 0.95,
bound_action_method: Literal["clip", "tanh"] | None = "clip",
lr_decay: bool = True,
max_grad_norm: float = 0.5,
eps_clip: float = 0.2,
dual_clip: float | None = None,
value_clip: bool = False,
norm_adv: bool = False,
recompute_adv: bool = True,
):
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
logger_factory = DefaultLoggerFactory()
sampling_config = RLSamplingConfig(
num_epochs=epoch,
step_per_epoch=step_per_epoch,
batch_size=batch_size,
num_train_envs=training_num,
num_test_envs=test_num,
buffer_size=buffer_size,
step_per_collect=step_per_collect,
repeat_per_collect=repeat_per_collect,
)
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
def dist_fn(*logits):
return Independent(Normal(*logits), 1)
actor_factory = ContinuousActorProbFactory(nn_config.hidden_sizes)
critic_factory = ContinuousNetCriticFactory(nn_config.hidden_sizes)
ppo_config = PPOConfig(
gamma=gamma,
gae_lambda=gae_lambda,
action_bound_method=bound_action_method,
rew_norm=rew_norm,
ent_coef=ent_coef,
vf_coef=vf_coef,
max_grad_norm=max_grad_norm,
value_clip=value_clip,
norm_adv=norm_adv,
eps_clip=eps_clip,
dual_clip=dual_clip,
recompute_adv=recompute_adv,
)
actor_factory = ContinuousActorProbFactory(hidden_sizes)
critic_factory = ContinuousNetCriticFactory(hidden_sizes)
optim_factory = AdamOptimizerFactory()
lr_scheduler_factory = LinearLRSchedulerFactory(sampling_config) if nn_config.lr_decay else None
lr_scheduler_factory = LinearLRSchedulerFactory(sampling_config) if lr_decay else None
agent_factory = PPOAgentFactory(
general_config,
pg_config,
ppo_config,
sampling_config,
actor_factory,
critic_factory,
optim_factory,
dist_fn,
nn_config.lr,
lr,
lr_scheduler_factory,
)
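
In the MuJoCo PPO example above, main() no longer expects pre-built sampling_config/general_config/pg_config/ppo_config/nn_config objects; it takes flat, experiment-specific parameters (task, lr, gamma, eps_clip, ...) and assembles RLSamplingConfig and PPOConfig itself. Since the script imports jsonargparse's CLI, each of those parameters presumably becomes a command-line option via CLI(main). The following is a self-contained sketch of that pattern; apart from CLI itself, all names here are illustrative stand-ins, not taken from tianshou.

from dataclasses import dataclass

from jsonargparse import CLI


@dataclass
class ExperimentConfig:
    """Illustrative stand-in for RLExperimentConfig (hypothetical fields)."""

    seed: int = 42
    watch: bool = False


def main(
    experiment_config: ExperimentConfig,
    task: str = "Ant-v4",
    lr: float = 3e-4,
    gamma: float = 0.99,
) -> None:
    # jsonargparse exposes these as command-line options, e.g.
    #   --task HalfCheetah-v4 --lr 1e-4 --experiment_config.seed 7
    print(experiment_config, task, lr, gamma)


if __name__ == "__main__":
    CLI(main)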


@ -8,11 +8,11 @@ from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACAgentFactory, SACConfig
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.experiment import (
RLExperiment,
RLExperimentConfig,
)
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.logger import DefaultLoggerFactory
from tianshou.highlevel.module import (
ContinuousActorProbFactory,


@ -9,8 +9,8 @@ import torch
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.exploration import BaseNoise
from tianshou.highlevel.env import Environments
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory
@ -143,7 +143,7 @@ class RLAgentConfig:
@dataclass
class PGConfig:
class PGConfig(RLAgentConfig):
"""Config of general policy-gradient algorithms."""
ent_coef: float = 0.0
@ -152,7 +152,7 @@ class PGConfig:
@dataclass
class PPOConfig:
class PPOConfig(PGConfig):
"""PPO specific config."""
value_clip: bool = False
@ -166,9 +166,7 @@ class PPOConfig:
class PPOAgentFactory(OnpolicyAgentFactory):
def __init__(
self,
general_config: RLAgentConfig,
pg_config: PGConfig,
ppo_config: PPOConfig,
config: PPOConfig,
sampling_config: RLSamplingConfig,
actor_factory: ActorFactory,
critic_factory: CriticFactory,
@ -181,9 +179,7 @@ class PPOAgentFactory(OnpolicyAgentFactory):
self.optimizer_factory = optimizer_factory
self.critic_factory = critic_factory
self.actor_factory = actor_factory
self.ppo_config = ppo_config
self.pg_config = pg_config
self.general_config = general_config
self.config = config
self.lr = lr
self.lr_scheduler_factory = lr_scheduler_factory
self.dist_fn = dist_fn
@ -208,27 +204,30 @@ class PPOAgentFactory(OnpolicyAgentFactory):
action_space=envs.get_action_space(),
action_scaling=True,
# general_config
discount_factor=self.general_config.gamma,
gae_lambda=self.general_config.gae_lambda,
reward_normalization=self.general_config.rew_norm,
action_bound_method=self.general_config.action_bound_method,
discount_factor=self.config.gamma,
gae_lambda=self.config.gae_lambda,
reward_normalization=self.config.rew_norm,
action_bound_method=self.config.action_bound_method,
# pg_config
max_grad_norm=self.pg_config.max_grad_norm,
vf_coef=self.pg_config.vf_coef,
ent_coef=self.pg_config.ent_coef,
max_grad_norm=self.config.max_grad_norm,
vf_coef=self.config.vf_coef,
ent_coef=self.config.ent_coef,
# ppo_config
eps_clip=self.ppo_config.eps_clip,
value_clip=self.ppo_config.value_clip,
dual_clip=self.ppo_config.dual_clip,
advantage_normalization=self.ppo_config.norm_adv,
recompute_advantage=self.ppo_config.recompute_adv,
eps_clip=self.config.eps_clip,
value_clip=self.config.value_clip,
dual_clip=self.config.dual_clip,
advantage_normalization=self.config.norm_adv,
recompute_advantage=self.config.recompute_adv,
)
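
With the constructor change above, the factory keeps a single self.config and _create_policy reads every hyperparameter from it. The sketch below (not tianshou source) restates the resulting mapping from config fields to the PPO policy's keyword arguments, following the kwargs shown in the diff; config is expected to be a PPOConfig like the one sketched after the commit header.

def ppo_policy_kwargs(config) -> dict:
    """Collect PPO policy kwargs from the unified config object (sketch only)."""
    return dict(
        # formerly general_config
        discount_factor=config.gamma,
        gae_lambda=config.gae_lambda,
        reward_normalization=config.rew_norm,
        action_bound_method=config.action_bound_method,
        # formerly pg_config
        max_grad_norm=config.max_grad_norm,
        vf_coef=config.vf_coef,
        ent_coef=config.ent_coef,
        # formerly ppo_config
        eps_clip=config.eps_clip,
        value_clip=config.value_clip,
        dual_clip=config.dual_clip,
        advantage_normalization=config.norm_adv,
        recompute_advantage=config.recompute_adv,
    )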
class AutoAlphaFactory(ABC):
@abstractmethod
def create_auto_alpha(
self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice,
self,
envs: Environments,
optim_factory: OptimizerFactory,
device: TDevice,
):
pass
@ -238,7 +237,10 @@ class DefaultAutoAlphaFactory(AutoAlphaFactory): # TODO better name?
self.lr = lr
def create_auto_alpha(
self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice,
self,
envs: Environments,
optim_factory: OptimizerFactory,
device: TDevice,
) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
target_entropy = -np.prod(envs.get_action_shape())
log_alpha = torch.zeros(1, requires_grad=True, device=device)
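
The hunk stops mid-method; given the declared return type tuple[float, torch.Tensor, torch.optim.Optimizer] and the usual automatic-entropy-tuning setup for SAC, the method presumably finishes by building an optimizer over log_alpha and returning the triple. A standalone sketch of that pattern follows; the choice of Adam and the learning-rate default are assumptions, not taken from the diff.

import numpy as np
import torch


def create_auto_alpha_sketch(action_shape, lr: float = 3e-4, device: str = "cpu"):
    """Standalone sketch of the usual SAC auto-alpha construction."""
    # Target entropy heuristic: minus the action dimensionality.
    target_entropy = float(-np.prod(action_shape))
    # Trainable log(alpha), tuned during training.
    log_alpha = torch.zeros(1, requires_grad=True, device=device)
    # Assumption: an Adam optimizer over log_alpha, as in tianshou's SAC examples.
    alpha_optim = torch.optim.Adam([log_alpha], lr=lr)
    return target_entropy, log_alpha, alpha_optim


# Example: a 6-dimensional continuous action space.
target_entropy, log_alpha, alpha_optim = create_auto_alpha_sketch((6,))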