Unify PPO configuration objects, use experiment-specific configuration
in mujoco_ppo_hl
This commit is contained in:
parent
8ec42009cb
commit
3fd60f9e70
@ -3,8 +3,8 @@ import warnings
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
||||||
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
|
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import RLSamplingConfig
|
||||||
|
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import envpool
|
import envpool
|
||||||
|
@ -3,18 +3,18 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import os
|
import os
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from typing import Literal
|
||||||
|
|
||||||
from jsonargparse import CLI
|
from jsonargparse import CLI
|
||||||
from torch.distributions import Independent, Normal
|
from torch.distributions import Independent, Normal
|
||||||
|
|
||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
from tianshou.highlevel.agent import PGConfig, PPOAgentFactory, PPOConfig, RLAgentConfig
|
from tianshou.highlevel.agent import PPOAgentFactory, PPOConfig
|
||||||
|
from tianshou.highlevel.config import RLSamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
RLExperiment,
|
RLExperiment,
|
||||||
RLExperimentConfig,
|
RLExperimentConfig,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
|
||||||
from tianshou.highlevel.logger import DefaultLoggerFactory
|
from tianshou.highlevel.logger import DefaultLoggerFactory
|
||||||
from tianshou.highlevel.module import (
|
from tianshou.highlevel.module import (
|
||||||
ContinuousActorProbFactory,
|
ContinuousActorProbFactory,
|
||||||
@ -23,45 +23,79 @@ from tianshou.highlevel.module import (
|
|||||||
from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory
|
from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class NNConfig:
|
|
||||||
hidden_sizes: Sequence[int] = (64, 64)
|
|
||||||
lr: float = 3e-4
|
|
||||||
lr_decay: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
experiment_config: RLExperimentConfig,
|
experiment_config: RLExperimentConfig,
|
||||||
sampling_config: RLSamplingConfig,
|
|
||||||
general_config: RLAgentConfig,
|
|
||||||
pg_config: PGConfig,
|
|
||||||
ppo_config: PPOConfig,
|
|
||||||
nn_config: NNConfig,
|
|
||||||
task: str = "Ant-v4",
|
task: str = "Ant-v4",
|
||||||
|
buffer_size: int = 4096,
|
||||||
|
hidden_sizes: Sequence[int] = (64, 64),
|
||||||
|
lr: float = 3e-4,
|
||||||
|
gamma: float = 0.99,
|
||||||
|
epoch: int = 100,
|
||||||
|
step_per_epoch: int = 30000,
|
||||||
|
step_per_collect: int = 2048,
|
||||||
|
repeat_per_collect: int = 10,
|
||||||
|
batch_size: int = 64,
|
||||||
|
training_num: int = 64,
|
||||||
|
test_num: int = 10,
|
||||||
|
rew_norm: bool = True,
|
||||||
|
vf_coef: float = 0.25,
|
||||||
|
ent_coef: float = 0.0,
|
||||||
|
gae_lambda: float = 0.95,
|
||||||
|
bound_action_method: Literal["clip", "tanh"] | None = "clip",
|
||||||
|
lr_decay: bool = True,
|
||||||
|
max_grad_norm: float = 0.5,
|
||||||
|
eps_clip: float = 0.2,
|
||||||
|
dual_clip: float | None = None,
|
||||||
|
value_clip: bool = False,
|
||||||
|
norm_adv: bool = False,
|
||||||
|
recompute_adv: bool = True,
|
||||||
):
|
):
|
||||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||||
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
||||||
logger_factory = DefaultLoggerFactory()
|
logger_factory = DefaultLoggerFactory()
|
||||||
|
|
||||||
|
sampling_config = RLSamplingConfig(
|
||||||
|
num_epochs=epoch,
|
||||||
|
step_per_epoch=step_per_epoch,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_train_envs=training_num,
|
||||||
|
num_test_envs=test_num,
|
||||||
|
buffer_size=buffer_size,
|
||||||
|
step_per_collect=step_per_collect,
|
||||||
|
repeat_per_collect=repeat_per_collect,
|
||||||
|
)
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||||
|
|
||||||
def dist_fn(*logits):
|
def dist_fn(*logits):
|
||||||
return Independent(Normal(*logits), 1)
|
return Independent(Normal(*logits), 1)
|
||||||
|
|
||||||
actor_factory = ContinuousActorProbFactory(nn_config.hidden_sizes)
|
ppo_config = PPOConfig(
|
||||||
critic_factory = ContinuousNetCriticFactory(nn_config.hidden_sizes)
|
gamma=gamma,
|
||||||
|
gae_lambda=gae_lambda,
|
||||||
|
action_bound_method=bound_action_method,
|
||||||
|
rew_norm=rew_norm,
|
||||||
|
ent_coef=ent_coef,
|
||||||
|
vf_coef=vf_coef,
|
||||||
|
max_grad_norm=max_grad_norm,
|
||||||
|
value_clip=value_clip,
|
||||||
|
norm_adv=norm_adv,
|
||||||
|
eps_clip=eps_clip,
|
||||||
|
dual_clip=dual_clip,
|
||||||
|
recompute_adv=recompute_adv,
|
||||||
|
)
|
||||||
|
actor_factory = ContinuousActorProbFactory(hidden_sizes)
|
||||||
|
critic_factory = ContinuousNetCriticFactory(hidden_sizes)
|
||||||
optim_factory = AdamOptimizerFactory()
|
optim_factory = AdamOptimizerFactory()
|
||||||
lr_scheduler_factory = LinearLRSchedulerFactory(sampling_config) if nn_config.lr_decay else None
|
lr_scheduler_factory = LinearLRSchedulerFactory(sampling_config) if lr_decay else None
|
||||||
agent_factory = PPOAgentFactory(
|
agent_factory = PPOAgentFactory(
|
||||||
general_config,
|
|
||||||
pg_config,
|
|
||||||
ppo_config,
|
ppo_config,
|
||||||
sampling_config,
|
sampling_config,
|
||||||
actor_factory,
|
actor_factory,
|
||||||
critic_factory,
|
critic_factory,
|
||||||
optim_factory,
|
optim_factory,
|
||||||
dist_fn,
|
dist_fn,
|
||||||
nn_config.lr,
|
lr,
|
||||||
lr_scheduler_factory,
|
lr_scheduler_factory,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -8,11 +8,11 @@ from jsonargparse import CLI
|
|||||||
|
|
||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACAgentFactory, SACConfig
|
from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACAgentFactory, SACConfig
|
||||||
|
from tianshou.highlevel.config import RLSamplingConfig
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
RLExperiment,
|
RLExperiment,
|
||||||
RLExperimentConfig,
|
RLExperimentConfig,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
|
||||||
from tianshou.highlevel.logger import DefaultLoggerFactory
|
from tianshou.highlevel.logger import DefaultLoggerFactory
|
||||||
from tianshou.highlevel.module import (
|
from tianshou.highlevel.module import (
|
||||||
ContinuousActorProbFactory,
|
ContinuousActorProbFactory,
|
||||||
|
@ -9,8 +9,8 @@ import torch
|
|||||||
|
|
||||||
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
||||||
from tianshou.exploration import BaseNoise
|
from tianshou.exploration import BaseNoise
|
||||||
from tianshou.highlevel.env import Environments
|
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import RLSamplingConfig
|
||||||
|
from tianshou.highlevel.env import Environments
|
||||||
from tianshou.highlevel.logger import Logger
|
from tianshou.highlevel.logger import Logger
|
||||||
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
|
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
|
||||||
from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory
|
from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory
|
||||||
@ -143,7 +143,7 @@ class RLAgentConfig:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PGConfig:
|
class PGConfig(RLAgentConfig):
|
||||||
"""Config of general policy-gradient algorithms."""
|
"""Config of general policy-gradient algorithms."""
|
||||||
|
|
||||||
ent_coef: float = 0.0
|
ent_coef: float = 0.0
|
||||||
@ -152,7 +152,7 @@ class PGConfig:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PPOConfig:
|
class PPOConfig(PGConfig):
|
||||||
"""PPO specific config."""
|
"""PPO specific config."""
|
||||||
|
|
||||||
value_clip: bool = False
|
value_clip: bool = False
|
||||||
@ -166,9 +166,7 @@ class PPOConfig:
|
|||||||
class PPOAgentFactory(OnpolicyAgentFactory):
|
class PPOAgentFactory(OnpolicyAgentFactory):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
general_config: RLAgentConfig,
|
config: PPOConfig,
|
||||||
pg_config: PGConfig,
|
|
||||||
ppo_config: PPOConfig,
|
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: RLSamplingConfig,
|
||||||
actor_factory: ActorFactory,
|
actor_factory: ActorFactory,
|
||||||
critic_factory: CriticFactory,
|
critic_factory: CriticFactory,
|
||||||
@ -181,9 +179,7 @@ class PPOAgentFactory(OnpolicyAgentFactory):
|
|||||||
self.optimizer_factory = optimizer_factory
|
self.optimizer_factory = optimizer_factory
|
||||||
self.critic_factory = critic_factory
|
self.critic_factory = critic_factory
|
||||||
self.actor_factory = actor_factory
|
self.actor_factory = actor_factory
|
||||||
self.ppo_config = ppo_config
|
self.config = config
|
||||||
self.pg_config = pg_config
|
|
||||||
self.general_config = general_config
|
|
||||||
self.lr = lr
|
self.lr = lr
|
||||||
self.lr_scheduler_factory = lr_scheduler_factory
|
self.lr_scheduler_factory = lr_scheduler_factory
|
||||||
self.dist_fn = dist_fn
|
self.dist_fn = dist_fn
|
||||||
@ -208,27 +204,30 @@ class PPOAgentFactory(OnpolicyAgentFactory):
|
|||||||
action_space=envs.get_action_space(),
|
action_space=envs.get_action_space(),
|
||||||
action_scaling=True,
|
action_scaling=True,
|
||||||
# general_config
|
# general_config
|
||||||
discount_factor=self.general_config.gamma,
|
discount_factor=self.config.gamma,
|
||||||
gae_lambda=self.general_config.gae_lambda,
|
gae_lambda=self.config.gae_lambda,
|
||||||
reward_normalization=self.general_config.rew_norm,
|
reward_normalization=self.config.rew_norm,
|
||||||
action_bound_method=self.general_config.action_bound_method,
|
action_bound_method=self.config.action_bound_method,
|
||||||
# pg_config
|
# pg_config
|
||||||
max_grad_norm=self.pg_config.max_grad_norm,
|
max_grad_norm=self.config.max_grad_norm,
|
||||||
vf_coef=self.pg_config.vf_coef,
|
vf_coef=self.config.vf_coef,
|
||||||
ent_coef=self.pg_config.ent_coef,
|
ent_coef=self.config.ent_coef,
|
||||||
# ppo_config
|
# ppo_config
|
||||||
eps_clip=self.ppo_config.eps_clip,
|
eps_clip=self.config.eps_clip,
|
||||||
value_clip=self.ppo_config.value_clip,
|
value_clip=self.config.value_clip,
|
||||||
dual_clip=self.ppo_config.dual_clip,
|
dual_clip=self.config.dual_clip,
|
||||||
advantage_normalization=self.ppo_config.norm_adv,
|
advantage_normalization=self.config.norm_adv,
|
||||||
recompute_advantage=self.ppo_config.recompute_adv,
|
recompute_advantage=self.config.recompute_adv,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AutoAlphaFactory(ABC):
|
class AutoAlphaFactory(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_auto_alpha(
|
def create_auto_alpha(
|
||||||
self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice,
|
self,
|
||||||
|
envs: Environments,
|
||||||
|
optim_factory: OptimizerFactory,
|
||||||
|
device: TDevice,
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -238,7 +237,10 @@ class DefaultAutoAlphaFactory(AutoAlphaFactory): # TODO better name?
|
|||||||
self.lr = lr
|
self.lr = lr
|
||||||
|
|
||||||
def create_auto_alpha(
|
def create_auto_alpha(
|
||||||
self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice,
|
self,
|
||||||
|
envs: Environments,
|
||||||
|
optim_factory: OptimizerFactory,
|
||||||
|
device: TDevice,
|
||||||
) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
|
) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
|
||||||
target_entropy = -np.prod(envs.get_action_shape())
|
target_entropy = -np.prod(envs.get_action_shape())
|
||||||
log_alpha = torch.zeros(1, requires_grad=True, device=device)
|
log_alpha = torch.zeros(1, requires_grad=True, device=device)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user