Use experiment-specific config in mujoco_sac_hl, adding auto-alpha

parent adc324038a
commit d26b8cb40c
@@ -7,7 +7,7 @@ from collections.abc import Sequence
 from jsonargparse import CLI

 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.agent import SACAgentFactory, SACConfig
+from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACAgentFactory, SACConfig
 from tianshou.highlevel.experiment import (
     RLExperiment,
     RLExperimentConfig,
@@ -23,17 +23,56 @@ from tianshou.highlevel.optim import AdamOptimizerFactory

 def main(
     experiment_config: RLExperimentConfig,
-    sampling_config: RLSamplingConfig,
-    sac_config: SACConfig,
+    task: str = "Ant-v3",
+    buffer_size: int = 1000000,
     hidden_sizes: Sequence[int] = (256, 256),
-    task: str = "Ant-v4",
+    actor_lr: float = 1e-3,
+    critic_lr: float = 1e-3,
+    gamma: float = 0.99,
+    tau: float = 0.005,
+    alpha: float = 0.2,
+    auto_alpha: bool = False,
+    alpha_lr: float = 3e-4,
+    start_timesteps: int = 10000,
+    epoch: int = 200,
+    step_per_epoch: int = 5000,
+    step_per_collect: int = 1,
+    update_per_step: int = 1,
+    n_step: int = 1,
+    batch_size: int = 256,
+    training_num: int = 1,
+    test_num: int = 10,
 ):
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "sac", str(experiment_config.seed), now)
     logger_factory = DefaultLoggerFactory()

+    sampling_config = RLSamplingConfig(
+        num_epochs=epoch,
+        step_per_epoch=step_per_epoch,
+        num_train_envs=training_num,
+        num_test_envs=test_num,
+        buffer_size=buffer_size,
+        batch_size=batch_size,
+        step_per_collect=step_per_collect,
+        update_per_step=update_per_step,
+        start_timesteps=start_timesteps,
+        start_timesteps_random=True,
+    )
+
     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)

+    if auto_alpha:
+        alpha = DefaultAutoAlphaFactory(lr=alpha_lr)
+    sac_config = SACConfig(
+        tau=tau,
+        gamma=gamma,
+        alpha=alpha,
+        estimation_step=n_step,
+        actor_lr=actor_lr,
+        critic1_lr=critic_lr,
+        critic2_lr=critic_lr,
+    )
     actor_factory = ContinuousActorProbFactory(hidden_sizes, conditioned_sigma=True)
     critic_factory = ContinuousNetCriticFactory(hidden_sizes)
     optim_factory = AdamOptimizerFactory()
@@ -4,6 +4,7 @@ from collections.abc import Callable
 from dataclasses import dataclass
 from typing import Literal

+import numpy as np
 import torch

 from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
@@ -34,6 +35,8 @@ class AgentFactory(ABC):
         buffer = ReplayBuffer(buffer_size)
         train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
         test_collector = Collector(policy, envs.test_envs)
+        if self.sampling_config.start_timesteps > 0:
+            train_collector.collect(n_step=self.sampling_config.start_timesteps, random=True)
         return train_collector, test_collector

     @abstractmethod
@@ -222,10 +225,32 @@ class PPOAgentFactory(OnpolicyAgentFactory):
         )


+class AutoAlphaFactory(ABC):
+    @abstractmethod
+    def create_auto_alpha(
+        self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice,
+    ):
+        pass
+
+
+class DefaultAutoAlphaFactory(AutoAlphaFactory):  # TODO better name?
+    def __init__(self, lr: float = 3e-4):
+        self.lr = lr
+
+    def create_auto_alpha(
+        self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice,
+    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
+        target_entropy = -np.prod(envs.get_action_shape())
+        log_alpha = torch.zeros(1, requires_grad=True, device=device)
+        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
+        return target_entropy, log_alpha, alpha_optim
+
+
+@dataclass
 class SACConfig:
     tau: float = 0.005
     gamma: float = 0.99
-    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2
+    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
     reward_normalization: bool = False
     estimation_step: int = 1
     deterministic_eval: bool = True
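The tuple returned by DefaultAutoAlphaFactory.create_auto_alpha() follows the (target_entropy, log_alpha, alpha_optim) convention that tianshou's SACPolicy accepts for automatic entropy tuning. A standalone sketch of the same construction against a plain Gymnasium environment (environment id and learning rate are illustrative only):

    import gymnasium as gym
    import numpy as np
    import torch

    env = gym.make("Ant-v4")
    # Common SAC heuristic: target entropy = -dim(action space).
    target_entropy = -np.prod(env.action_space.shape)
    # Optimize log(alpha) so that alpha itself stays positive.
    log_alpha = torch.zeros(1, requires_grad=True)
    alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
    alpha = (target_entropy, log_alpha, alpha_optim)  # usable as SACPolicy(..., alpha=alpha)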
@@ -260,6 +285,10 @@ class SACAgentFactory(OffpolicyAgentFactory):
         actor_optim = self.optim_factory.create_optimizer(actor, lr=self.config.actor_lr)
         critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.config.critic1_lr)
         critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.config.critic2_lr)
+        if isinstance(self.config.alpha, AutoAlphaFactory):
+            alpha = self.config.alpha.create_auto_alpha(envs, self.optim_factory, device)
+        else:
+            alpha = self.config.alpha
         return SACPolicy(
             actor,
             actor_optim,
@@ -269,7 +298,7 @@ class SACAgentFactory(OffpolicyAgentFactory):
             critic2_optim,
             tau=self.config.tau,
             gamma=self.config.gamma,
-            alpha=self.config.alpha,
+            alpha=alpha,
             estimation_step=self.config.estimation_step,
             action_space=envs.get_action_space(),
             deterministic_eval=self.config.deterministic_eval,
@@ -46,6 +46,8 @@ class RLSamplingConfig:
     step_per_collect: int = 2048
     repeat_per_collect: int = 10
     update_per_step: int = 1
+    start_timesteps: int = 0
+    start_timesteps_random: bool = False


 class RLExperiment(Generic[TPolicy, TTrainer]):
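Both new RLSamplingConfig fields default to the previous behaviour (no warm-up collection), so existing scripts are unaffected unless they opt in. A small sketch, assuming the remaining fields keep their dataclass defaults:

    # Hypothetical partial construction; only the warm-up fields are overridden.
    sampling_config = RLSamplingConfig(
        num_epochs=200,
        step_per_epoch=5000,
        start_timesteps=10000,        # pre-fill the replay buffer with 10k transitions
        start_timesteps_random=True,  # collect them with random actions
    )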
@@ -23,7 +23,12 @@ class LoggerFactory(ABC):


 class DefaultLoggerFactory(LoggerFactory):
-    def __init__(self, log_dir: str = "log", logger_type: Literal["tensorboard", "wandb"] = "tensorboard", wandb_project: str | None = None):
+    def __init__(
+        self,
+        log_dir: str = "log",
+        logger_type: Literal["tensorboard", "wandb"] = "tensorboard",
+        wandb_project: str | None = None,
+    ):
         if logger_type == "wandb" and wandb_project is None:
             raise ValueError("Must provide 'wand_project'")
         self.log_dir = log_dir
@@ -32,7 +37,16 @@ class DefaultLoggerFactory(LoggerFactory):

     def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
         writer = SummaryWriter(self.log_dir)
-        writer.add_text("args", str(dict(log_dir=self.log_dir, logger_type=self.logger_type, wandb_project=self.wandb_project)))
+        writer.add_text(
+            "args",
+            str(
+                dict(
+                    log_dir=self.log_dir,
+                    logger_type=self.logger_type,
+                    wandb_project=self.wandb_project,
+                ),
+            ),
+        )
         if self.logger_type == "wandb":
             logger = WandbLogger(
                 save_interval=1,
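For completeness, the reformatted constructor is called the same way as before; the wandb branch still requires a project name (the project string below is made up):

    # Hypothetical usage; omitting wandb_project while logger_type="wandb" raises
    # ValueError("Must provide 'wand_project'").
    logger_factory = DefaultLoggerFactory(
        log_dir="log",
        logger_type="wandb",
        wandb_project="mujoco-sac-hl",
    )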