Use experiment-specific config in mujoco_sac_hl, adding auto-alpha

Author: Dominik Jain, 2023-09-20 15:13:05 +02:00
Parent: adc324038a
Commit: d26b8cb40c
4 changed files with 92 additions and 8 deletions

File 1 of 4: high-level MuJoCo SAC example script (mujoco_sac_hl)

@@ -7,7 +7,7 @@ from collections.abc import Sequence
 from jsonargparse import CLI
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.agent import SACAgentFactory, SACConfig
+from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACAgentFactory, SACConfig
 from tianshou.highlevel.experiment import (
     RLExperiment,
     RLExperimentConfig,
@@ -23,17 +23,56 @@ from tianshou.highlevel.optim import AdamOptimizerFactory
 def main(
     experiment_config: RLExperimentConfig,
-    sampling_config: RLSamplingConfig,
-    sac_config: SACConfig,
+    task: str = "Ant-v3",
+    buffer_size: int = 1000000,
     hidden_sizes: Sequence[int] = (256, 256),
-    task: str = "Ant-v4",
+    actor_lr: float = 1e-3,
+    critic_lr: float = 1e-3,
+    gamma: float = 0.99,
+    tau: float = 0.005,
+    alpha: float = 0.2,
+    auto_alpha: bool = False,
+    alpha_lr: float = 3e-4,
+    start_timesteps: int = 10000,
+    epoch: int = 200,
+    step_per_epoch: int = 5000,
+    step_per_collect: int = 1,
+    update_per_step: int = 1,
+    n_step: int = 1,
+    batch_size: int = 256,
+    training_num: int = 1,
+    test_num: int = 10,
 ):
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "sac", str(experiment_config.seed), now)
     logger_factory = DefaultLoggerFactory()
+    sampling_config = RLSamplingConfig(
+        num_epochs=epoch,
+        step_per_epoch=step_per_epoch,
+        num_train_envs=training_num,
+        num_test_envs=test_num,
+        buffer_size=buffer_size,
+        batch_size=batch_size,
+        step_per_collect=step_per_collect,
+        update_per_step=update_per_step,
+        start_timesteps=start_timesteps,
+        start_timesteps_random=True,
+    )
     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
+    if auto_alpha:
+        alpha = DefaultAutoAlphaFactory(lr=alpha_lr)
+    sac_config = SACConfig(
+        tau=tau,
+        gamma=gamma,
+        alpha=alpha,
+        estimation_step=n_step,
+        actor_lr=actor_lr,
+        critic1_lr=critic_lr,
+        critic2_lr=critic_lr,
+    )
     actor_factory = ContinuousActorProbFactory(hidden_sizes, conditioned_sigma=True)
     critic_factory = ContinuousNetCriticFactory(hidden_sizes)
     optim_factory = AdamOptimizerFactory()
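Since the script wraps main with jsonargparse's CLI (imported above), each new keyword argument becomes a command-line option, so auto-alpha can be switched on without editing the file. A hedged sketch of the entry point and an illustrative invocation (the flag spelling follows jsonargparse's conventions and is not part of this diff):

if __name__ == "__main__":
    CLI(main)

# Illustrative invocation:
#   python mujoco_sac_hl.py --task Ant-v3 --auto_alpha true --alpha_lr 3e-4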

File 2 of 4: tianshou.highlevel.agent

@@ -4,6 +4,7 @@ from collections.abc import Callable
 from dataclasses import dataclass
 from typing import Literal
 
+import numpy as np
 import torch
 
 from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
@@ -34,6 +35,8 @@ class AgentFactory(ABC):
         buffer = ReplayBuffer(buffer_size)
         train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
         test_collector = Collector(policy, envs.test_envs)
+        if self.sampling_config.start_timesteps > 0:
+            train_collector.collect(n_step=self.sampling_config.start_timesteps, random=True)
         return train_collector, test_collector
 
     @abstractmethod
@@ -222,10 +225,32 @@ class PPOAgentFactory(OnpolicyAgentFactory):
         )
 
 
+class AutoAlphaFactory(ABC):
+    @abstractmethod
+    def create_auto_alpha(
+        self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice,
+    ):
+        pass
+
+
+class DefaultAutoAlphaFactory(AutoAlphaFactory):  # TODO better name?
+    def __init__(self, lr: float = 3e-4):
+        self.lr = lr
+
+    def create_auto_alpha(
+        self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice,
+    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
+        target_entropy = -np.prod(envs.get_action_shape())
+        log_alpha = torch.zeros(1, requires_grad=True, device=device)
+        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
+        return target_entropy, log_alpha, alpha_optim
+
+
 @dataclass
 class SACConfig:
     tau: float = 0.005
     gamma: float = 0.99
-    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2
+    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
     reward_normalization: bool = False
     estimation_step: int = 1
     deterministic_eval: bool = True
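The (target_entropy, log_alpha, alpha_optim) triple returned by create_auto_alpha corresponds to the standard SAC automatic-temperature scheme. A hedged, generic illustration of the gradient step such a triple enables (textbook SAC, not the library's exact update code):

import torch

def update_alpha(
    log_alpha: torch.Tensor,
    alpha_optim: torch.optim.Optimizer,
    log_prob: torch.Tensor,
    target_entropy: float,
) -> torch.Tensor:
    # Minimize E[-log_alpha * (log_prob + target_entropy)], the usual SAC temperature loss.
    alpha_loss = -(log_alpha * (log_prob + target_entropy).detach()).mean()
    alpha_optim.zero_grad()
    alpha_loss.backward()
    alpha_optim.step()
    # The exponentiated value is the temperature used in the actor/critic losses.
    return log_alpha.detach().exp()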
@@ -260,6 +285,10 @@ class SACAgentFactory(OffpolicyAgentFactory):
         actor_optim = self.optim_factory.create_optimizer(actor, lr=self.config.actor_lr)
         critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.config.critic1_lr)
         critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.config.critic2_lr)
+        if isinstance(self.config.alpha, AutoAlphaFactory):
+            alpha = self.config.alpha.create_auto_alpha(envs, self.optim_factory, device)
+        else:
+            alpha = self.config.alpha
         return SACPolicy(
             actor,
             actor_optim,
@@ -269,7 +298,7 @@ class SACAgentFactory(OffpolicyAgentFactory):
             critic2_optim,
             tau=self.config.tau,
             gamma=self.config.gamma,
-            alpha=self.config.alpha,
+            alpha=alpha,
             estimation_step=self.config.estimation_step,
             action_space=envs.get_action_space(),
             deterministic_eval=self.config.deterministic_eval,
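With these changes, SACConfig.alpha accepts a fixed float, a pre-built (target_entropy, log_alpha, alpha_optim) triple, or an AutoAlphaFactory that SACAgentFactory resolves when the policy is built. A hedged usage sketch, assuming the remaining SACConfig fields keep their defaults:

from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACConfig

# Fixed entropy temperature:
config_fixed = SACConfig(alpha=0.2)

# Automatic temperature tuning; the factory is resolved via create_auto_alpha(...)
# by SACAgentFactory at policy-creation time:
config_auto = SACConfig(alpha=DefaultAutoAlphaFactory(lr=3e-4))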

File 3 of 4: the module defining RLSamplingConfig and RLExperiment

@@ -46,6 +46,8 @@ class RLSamplingConfig:
     step_per_collect: int = 2048
     repeat_per_collect: int = 10
     update_per_step: int = 1
+    start_timesteps: int = 0
+    start_timesteps_random: bool = False
 
 
 class RLExperiment(Generic[TPolicy, TTrainer]):
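These fields feed the collector change in tianshou.highlevel.agent above: when start_timesteps is positive, the agent factory pre-fills the replay buffer before training starts (the collection shown there always uses random actions; start_timesteps_random is presumably intended to make that configurable). A hedged configuration sketch mirroring the example script, assuming the other fields keep their defaults:

sampling_config = RLSamplingConfig(
    num_epochs=200,
    step_per_epoch=5000,
    buffer_size=1000000,
    batch_size=256,
    start_timesteps=10000,        # warm-up transitions collected before the first update
    start_timesteps_random=True,  # warm-up uses random actions
)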

File 4 of 4: the module defining LoggerFactory and DefaultLoggerFactory

@@ -23,7 +23,12 @@ class LoggerFactory(ABC):
 class DefaultLoggerFactory(LoggerFactory):
-    def __init__(self, log_dir: str = "log", logger_type: Literal["tensorboard", "wandb"] = "tensorboard", wandb_project: str | None = None):
+    def __init__(
+        self,
+        log_dir: str = "log",
+        logger_type: Literal["tensorboard", "wandb"] = "tensorboard",
+        wandb_project: str | None = None,
+    ):
         if logger_type == "wandb" and wandb_project is None:
             raise ValueError("Must provide 'wand_project'")
         self.log_dir = log_dir
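For reference, a hedged construction sketch based on the signature above (the project name is a placeholder):

# TensorBoard logging (the defaults):
logger_factory = DefaultLoggerFactory()

# Weights & Biases logging; wandb_project is required in this mode:
logger_factory = DefaultLoggerFactory(logger_type="wandb", wandb_project="my-project")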
@@ -32,7 +37,16 @@ class DefaultLoggerFactory(LoggerFactory):
     def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
         writer = SummaryWriter(self.log_dir)
-        writer.add_text("args", str(dict(log_dir=self.log_dir, logger_type=self.logger_type, wandb_project=self.wandb_project)))
+        writer.add_text(
+            "args",
+            str(
+                dict(
+                    log_dir=self.log_dir,
+                    logger_type=self.logger_type,
+                    wandb_project=self.wandb_project,
+                ),
+            ),
+        )
         if self.logger_type == "wandb":
             logger = WandbLogger(
                 save_interval=1,