Add SAC high-level interface

parent: 2a1cc6bb55
commit: 316eb3c579
@ -2,9 +2,9 @@ import warnings
|
|||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
from tianshou.config import RLSamplingConfig, BasicExperimentConfig
|
from tianshou.config import BasicExperimentConfig, RLSamplingConfig
|
||||||
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
||||||
from tianshou.highlevel.env import EnvFactory, Environments, ContinuousEnvironments
|
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import envpool
|
import envpool
|
||||||
@ -12,9 +12,7 @@ except ImportError:
|
|||||||
envpool = None
|
envpool = None
|
||||||
|
|
||||||
|
|
||||||
def make_mujoco_env(
|
def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool):
|
||||||
task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool
|
|
||||||
):
|
|
||||||
"""Wrapper function for Mujoco env.
|
"""Wrapper function for Mujoco env.
|
||||||
|
|
||||||
If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env.
|
If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env.
|
||||||
|
|||||||
@ -1,359 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import datetime
|
|
||||||
import os
|
|
||||||
import pprint
|
|
||||||
from collections.abc import Sequence
|
|
||||||
from typing import Literal, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import gymnasium as gym
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from jsonargparse import CLI
|
|
||||||
from torch import nn
|
|
||||||
from torch.distributions import Independent, Normal
|
|
||||||
from torch.optim.lr_scheduler import LambdaLR
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
|
|
||||||
from mujoco_env import make_mujoco_env
|
|
||||||
from tianshou.config import (
|
|
||||||
BasicExperimentConfig,
|
|
||||||
LoggerConfig,
|
|
||||||
NNConfig,
|
|
||||||
PGConfig,
|
|
||||||
PPOConfig,
|
|
||||||
RLAgentConfig,
|
|
||||||
RLSamplingConfig,
|
|
||||||
)
|
|
||||||
from tianshou.config.utils import collect_configs
|
|
||||||
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
|
||||||
from tianshou.env import VectorEnvNormObs
|
|
||||||
from tianshou.policy import BasePolicy, PPOPolicy
|
|
||||||
from tianshou.trainer import OnpolicyTrainer
|
|
||||||
from tianshou.utils import TensorboardLogger, WandbLogger
|
|
||||||
from tianshou.utils.net.common import ActorCritic, Net
|
|
||||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed=42):
|
|
||||||
np.random.seed(seed)
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
|
|
||||||
|
|
||||||
def get_logger_for_run(
|
|
||||||
algo_name: str,
|
|
||||||
task: str,
|
|
||||||
logger_config: LoggerConfig,
|
|
||||||
config: dict,
|
|
||||||
seed: int,
|
|
||||||
resume_id: Optional[Union[str, int]],
|
|
||||||
) -> Tuple[str, Union[WandbLogger, TensorboardLogger]]:
|
|
||||||
"""
|
|
||||||
|
|
||||||
:param algo_name:
|
|
||||||
:param task:
|
|
||||||
:param logger_config:
|
|
||||||
:param config: the experiment config
|
|
||||||
:param seed:
|
|
||||||
:param resume_id: used as run_id by wandb, unused for tensorboard
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
"""Returns the log_path and logger."""
|
|
||||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
|
||||||
log_name = os.path.join(task, algo_name, str(seed), now)
|
|
||||||
log_path = os.path.join(logger_config.logdir, log_name)
|
|
||||||
|
|
||||||
logger = get_logger(
|
|
||||||
logger_config.logger,
|
|
||||||
log_path,
|
|
||||||
log_name=log_name,
|
|
||||||
run_id=resume_id,
|
|
||||||
config=config,
|
|
||||||
wandb_project=logger_config.wandb_project,
|
|
||||||
)
|
|
||||||
return log_path, logger
|
|
||||||
|
|
||||||
|
|
||||||
def get_continuous_env_info(
|
|
||||||
env: gym.Env,
|
|
||||||
) -> Tuple[Tuple[int, ...], Tuple[int, ...], float]:
|
|
||||||
if not isinstance(env.action_space, gym.spaces.Box):
|
|
||||||
raise ValueError(
|
|
||||||
"Only environments with continuous action space are supported here. "
|
|
||||||
f"But got env with action space: {env.action_space.__class__}."
|
|
||||||
)
|
|
||||||
state_shape = env.observation_space.shape or env.observation_space.n
|
|
||||||
if not state_shape:
|
|
||||||
raise ValueError("Observation space shape is not defined")
|
|
||||||
action_shape = env.action_space.shape
|
|
||||||
max_action = env.action_space.high[0]
|
|
||||||
return state_shape, action_shape, max_action
|
|
||||||
|
|
||||||
|
|
||||||
def resume_from_checkpoint(
|
|
||||||
path: str,
|
|
||||||
policy: BasePolicy,
|
|
||||||
train_envs: VectorEnvNormObs | None = None,
|
|
||||||
test_envs: VectorEnvNormObs | None = None,
|
|
||||||
device: str | int | torch.device | None = None,
|
|
||||||
):
|
|
||||||
ckpt = torch.load(path, map_location=device)
|
|
||||||
policy.load_state_dict(ckpt["model"])
|
|
||||||
if train_envs:
|
|
||||||
train_envs.set_obs_rms(ckpt["obs_rms"])
|
|
||||||
if test_envs:
|
|
||||||
test_envs.set_obs_rms(ckpt["obs_rms"])
|
|
||||||
print("Loaded agent and obs. running means from: ", path)
|
|
||||||
|
|
||||||
|
|
||||||
def watch_agent(n_episode, policy: BasePolicy, test_collector: Collector, render=0.0):
|
|
||||||
policy.eval()
|
|
||||||
test_collector.reset()
|
|
||||||
result = test_collector.collect(n_episode=n_episode, render=render)
|
|
||||||
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
|
|
||||||
|
|
||||||
|
|
||||||
def get_train_test_collector(
|
|
||||||
buffer_size: int,
|
|
||||||
policy: BasePolicy,
|
|
||||||
train_envs: VectorEnvNormObs,
|
|
||||||
test_envs: VectorEnvNormObs,
|
|
||||||
):
|
|
||||||
if len(train_envs) > 1:
|
|
||||||
buffer = VectorReplayBuffer(buffer_size, len(train_envs))
|
|
||||||
else:
|
|
||||||
buffer = ReplayBuffer(buffer_size)
|
|
||||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
|
||||||
test_collector = Collector(policy, test_envs)
|
|
||||||
return test_collector, train_collector
|
|
||||||
|
|
||||||
|
|
||||||
TShape = Union[int, Sequence[int]]
|
|
||||||
|
|
||||||
|
|
||||||
def get_actor_critic(
|
|
||||||
state_shape: TShape,
|
|
||||||
hidden_sizes: Sequence[int],
|
|
||||||
action_shape: TShape,
|
|
||||||
device: str | int | torch.device = "cpu",
|
|
||||||
):
|
|
||||||
net_a = Net(
|
|
||||||
state_shape, hidden_sizes=hidden_sizes, activation=nn.Tanh, device=device
|
|
||||||
)
|
|
||||||
actor = ActorProb(net_a, action_shape, unbounded=True, device=device).to(device)
|
|
||||||
net_c = Net(
|
|
||||||
state_shape, hidden_sizes=hidden_sizes, activation=nn.Tanh, device=device
|
|
||||||
)
|
|
||||||
# TODO: twice device?
|
|
||||||
critic = Critic(net_c, device=device).to(device)
|
|
||||||
return actor, critic
|
|
||||||
|
|
||||||
|
|
||||||
def get_logger(
|
|
||||||
kind: Literal["wandb", "tensorboard"],
|
|
||||||
log_path: str,
|
|
||||||
log_name="",
|
|
||||||
run_id: Optional[Union[str, int]] = None,
|
|
||||||
config: Optional[Union[dict, argparse.Namespace]] = None,
|
|
||||||
wandb_project: Optional[str] = None,
|
|
||||||
):
|
|
||||||
writer = SummaryWriter(log_path)
|
|
||||||
writer.add_text("args", str(config))
|
|
||||||
if kind == "wandb":
|
|
||||||
logger = WandbLogger(
|
|
||||||
save_interval=1,
|
|
||||||
name=log_name.replace(os.path.sep, "__"),
|
|
||||||
run_id=run_id,
|
|
||||||
config=config,
|
|
||||||
project=wandb_project,
|
|
||||||
)
|
|
||||||
logger.load(writer)
|
|
||||||
elif kind == "tensorboard":
|
|
||||||
logger = TensorboardLogger(writer)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown logger: {kind}")
|
|
||||||
return logger
|
|
||||||
|
|
||||||
|
|
||||||
def get_lr_scheduler(optim, step_per_epoch: int, step_per_collect: int, epochs: int):
|
|
||||||
"""Decay learning rate to 0 linearly."""
|
|
||||||
max_update_num = np.ceil(step_per_epoch / step_per_collect) * epochs
|
|
||||||
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
|
|
||||||
return lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
def init_and_get_optim(actor: nn.Module, critic: nn.Module, lr: float):
|
|
||||||
"""Initializes layers of actor and critic.
|
|
||||||
|
|
||||||
:param actor:
|
|
||||||
:param critic:
|
|
||||||
:param lr:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
actor_critic = ActorCritic(actor, critic)
|
|
||||||
torch.nn.init.constant_(actor.sigma_param, -0.5)
|
|
||||||
for m in actor_critic.modules():
|
|
||||||
if isinstance(m, torch.nn.Linear):
|
|
||||||
# orthogonal initialization
|
|
||||||
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
|
|
||||||
torch.nn.init.zeros_(m.bias)
|
|
||||||
if hasattr(actor, "mu"):
|
|
||||||
# For continuous action spaces with Gaussian policies
|
|
||||||
# do last policy layer scaling, this will make initial actions have (close to)
|
|
||||||
# 0 mean and std, and will help boost performances,
|
|
||||||
# see https://arxiv.org/abs/2006.05990, Fig.24 for details
|
|
||||||
for m in actor.mu.modules():
|
|
||||||
# TODO: seems like biases are initialized twice for the actor
|
|
||||||
if isinstance(m, torch.nn.Linear):
|
|
||||||
torch.nn.init.zeros_(m.bias)
|
|
||||||
m.weight.data.copy_(0.01 * m.weight.data)
|
|
||||||
optim = torch.optim.Adam(actor_critic.parameters(), lr=lr)
|
|
||||||
return optim
|
|
||||||
|
|
||||||
|
|
||||||
def main(
|
|
||||||
experiment_config: BasicExperimentConfig,
|
|
||||||
logger_config: LoggerConfig,
|
|
||||||
sampling_config: RLSamplingConfig,
|
|
||||||
general_config: RLAgentConfig,
|
|
||||||
pg_config: PGConfig,
|
|
||||||
ppo_config: PPOConfig,
|
|
||||||
nn_config: NNConfig,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Run the PPO test on the provided parameters.
|
|
||||||
|
|
||||||
:param experiment_config: BasicExperimentConfig - not ML or RL specific
|
|
||||||
:param logger_config: LoggerConfig
|
|
||||||
:param sampling_config: SamplingConfig -
|
|
||||||
sampling, epochs, parallelization, buffers, collectors, and batching.
|
|
||||||
:param general_config: RLAgentConfig - general RL agent config
|
|
||||||
:param pg_config: PGConfig: common to most policy gradient algorithms
|
|
||||||
:param ppo_config: PPOConfig - PPO specific config
|
|
||||||
:param nn_config: NNConfig - NN-training specific config
|
|
||||||
|
|
||||||
:return: None
|
|
||||||
"""
|
|
||||||
full_config = collect_configs(*locals().values())
|
|
||||||
set_seed(experiment_config.seed)
|
|
||||||
|
|
||||||
# create test and train envs, add env info to config
|
|
||||||
env, train_envs, test_envs = make_mujoco_env(
|
|
||||||
task=experiment_config.task,
|
|
||||||
seed=experiment_config.seed,
|
|
||||||
num_train_envs=sampling_config.num_train_envs,
|
|
||||||
num_test_envs=sampling_config.num_test_envs,
|
|
||||||
obs_norm=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# adding env_info to logged config
|
|
||||||
state_shape, action_shape, max_action = get_continuous_env_info(env)
|
|
||||||
full_config["env_info"] = {
|
|
||||||
"state_shape": state_shape,
|
|
||||||
"action_shape": action_shape,
|
|
||||||
"max_action": max_action,
|
|
||||||
}
|
|
||||||
log_path, logger = get_logger_for_run(
|
|
||||||
"ppo",
|
|
||||||
experiment_config.task,
|
|
||||||
logger_config,
|
|
||||||
full_config,
|
|
||||||
experiment_config.seed,
|
|
||||||
experiment_config.resume_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Setup NNs
|
|
||||||
actor, critic = get_actor_critic(
|
|
||||||
state_shape, nn_config.hidden_sizes, action_shape, experiment_config.device
|
|
||||||
)
|
|
||||||
optim = init_and_get_optim(actor, critic, nn_config.lr)
|
|
||||||
|
|
||||||
lr_scheduler = None
|
|
||||||
if nn_config.lr_decay:
|
|
||||||
lr_scheduler = get_lr_scheduler(
|
|
||||||
optim,
|
|
||||||
sampling_config.step_per_epoch,
|
|
||||||
sampling_config.step_per_collect,
|
|
||||||
sampling_config.num_epochs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create policy
|
|
||||||
def dist_fn(*logits):
|
|
||||||
return Independent(Normal(*logits), 1)
|
|
||||||
|
|
||||||
policy = PPOPolicy(
|
|
||||||
# nn-stuff
|
|
||||||
actor,
|
|
||||||
critic,
|
|
||||||
optim,
|
|
||||||
dist_fn=dist_fn,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
# env-stuff
|
|
||||||
action_space=train_envs.action_space,
|
|
||||||
action_scaling=True,
|
|
||||||
# general_config
|
|
||||||
discount_factor=general_config.gamma,
|
|
||||||
gae_lambda=general_config.gae_lambda,
|
|
||||||
reward_normalization=general_config.rew_norm,
|
|
||||||
action_bound_method=general_config.action_bound_method,
|
|
||||||
# pg_config
|
|
||||||
max_grad_norm=pg_config.max_grad_norm,
|
|
||||||
vf_coef=pg_config.vf_coef,
|
|
||||||
ent_coef=pg_config.ent_coef,
|
|
||||||
# ppo_config
|
|
||||||
eps_clip=ppo_config.eps_clip,
|
|
||||||
value_clip=ppo_config.value_clip,
|
|
||||||
dual_clip=ppo_config.dual_clip,
|
|
||||||
advantage_normalization=ppo_config.norm_adv,
|
|
||||||
recompute_advantage=ppo_config.recompute_adv,
|
|
||||||
)
|
|
||||||
|
|
||||||
if experiment_config.resume_path:
|
|
||||||
resume_from_checkpoint(
|
|
||||||
experiment_config.resume_path,
|
|
||||||
policy,
|
|
||||||
train_envs=train_envs,
|
|
||||||
test_envs=test_envs,
|
|
||||||
device=experiment_config.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
test_collector, train_collector = get_train_test_collector(
|
|
||||||
sampling_config.buffer_size, policy, test_envs, train_envs
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: test num is the number of test envs but used as episode_per_test
|
|
||||||
# here and in watch_agent
|
|
||||||
if not experiment_config.watch:
|
|
||||||
# RL training
|
|
||||||
def save_best_fn(pol: nn.Module):
|
|
||||||
state = {"model": pol.state_dict(), "obs_rms": train_envs.get_obs_rms()}
|
|
||||||
torch.save(state, os.path.join(log_path, "policy.pth"))
|
|
||||||
|
|
||||||
trainer = OnpolicyTrainer(
|
|
||||||
policy=policy,
|
|
||||||
train_collector=train_collector,
|
|
||||||
test_collector=test_collector,
|
|
||||||
max_epoch=sampling_config.num_epochs,
|
|
||||||
step_per_epoch=sampling_config.step_per_epoch,
|
|
||||||
repeat_per_collect=sampling_config.repeat_per_collect,
|
|
||||||
episode_per_test=sampling_config.num_test_envs,
|
|
||||||
batch_size=sampling_config.batch_size,
|
|
||||||
step_per_collect=sampling_config.step_per_collect,
|
|
||||||
save_best_fn=save_best_fn,
|
|
||||||
logger=logger,
|
|
||||||
test_in_train=False,
|
|
||||||
)
|
|
||||||
result = trainer.run()
|
|
||||||
pprint.pprint(result)
|
|
||||||
|
|
||||||
watch_agent(
|
|
||||||
sampling_config.num_test_envs,
|
|
||||||
policy,
|
|
||||||
test_collector,
|
|
||||||
render=experiment_config.render,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
CLI(main)
|
|
||||||
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import os
|
import os
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from jsonargparse import CLI
|
from jsonargparse import CLI
|
||||||
from torch.distributions import Independent, Normal
|
from torch.distributions import Independent, Normal
|
||||||
@ -10,17 +12,26 @@ from examples.mujoco.mujoco_env import MujocoEnvFactory
|
|||||||
from tianshou.config import (
|
from tianshou.config import (
|
||||||
BasicExperimentConfig,
|
BasicExperimentConfig,
|
||||||
LoggerConfig,
|
LoggerConfig,
|
||||||
NNConfig,
|
|
||||||
PGConfig,
|
PGConfig,
|
||||||
PPOConfig,
|
PPOConfig,
|
||||||
RLAgentConfig,
|
RLAgentConfig,
|
||||||
RLSamplingConfig,
|
RLSamplingConfig,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.agent import PPOAgentFactory
|
from tianshou.highlevel.agent import PPOAgentFactory
|
||||||
from tianshou.highlevel.logger import DefaultLoggerFactory
|
|
||||||
from tianshou.highlevel.module import ContinuousActorProbFactory, ContinuousNetCriticFactory
|
|
||||||
from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory
|
|
||||||
from tianshou.highlevel.experiment import RLExperiment
|
from tianshou.highlevel.experiment import RLExperiment
|
||||||
|
from tianshou.highlevel.logger import DefaultLoggerFactory
|
||||||
|
from tianshou.highlevel.module import (
|
||||||
|
ContinuousActorProbFactory,
|
||||||
|
ContinuousNetCriticFactory,
|
||||||
|
)
|
||||||
|
from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NNConfig:
|
||||||
|
hidden_sizes: Sequence[int] = (64, 64)
|
||||||
|
lr: float = 3e-4
|
||||||
|
lr_decay: bool = True
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
@ -43,15 +54,22 @@ def main(
|
|||||||
|
|
||||||
actor_factory = ContinuousActorProbFactory(nn_config.hidden_sizes)
|
actor_factory = ContinuousActorProbFactory(nn_config.hidden_sizes)
|
||||||
critic_factory = ContinuousNetCriticFactory(nn_config.hidden_sizes)
|
critic_factory = ContinuousNetCriticFactory(nn_config.hidden_sizes)
|
||||||
optim_factory = AdamOptimizerFactory(lr=nn_config.lr)
|
optim_factory = AdamOptimizerFactory()
|
||||||
lr_scheduler_factory = LinearLRSchedulerFactory(nn_config, sampling_config)
|
lr_scheduler_factory = LinearLRSchedulerFactory(sampling_config) if nn_config.lr_decay else None
|
||||||
agent_factory = PPOAgentFactory(general_config, pg_config, ppo_config, sampling_config, nn_config,
|
agent_factory = PPOAgentFactory(
|
||||||
actor_factory, critic_factory, optim_factory, dist_fn, lr_scheduler_factory)
|
general_config,
|
||||||
|
pg_config,
|
||||||
|
ppo_config,
|
||||||
|
sampling_config,
|
||||||
|
actor_factory,
|
||||||
|
critic_factory,
|
||||||
|
optim_factory,
|
||||||
|
dist_fn,
|
||||||
|
nn_config.lr,
|
||||||
|
lr_scheduler_factory,
|
||||||
|
)
|
||||||
|
|
||||||
experiment = RLExperiment(experiment_config, logger_config, general_config, sampling_config,
|
experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory)
|
||||||
env_factory,
|
|
||||||
logger_factory,
|
|
||||||
agent_factory)
|
|
||||||
|
|
||||||
experiment.run(log_name)
|
experiment.run(log_name)
|
||||||
|
|
||||||
|
|||||||
examples/mujoco/mujoco_sac_hl.py (new file, 56 lines)
@@ -0,0 +1,56 @@
+#!/usr/bin/env python3
+
+import datetime
+import os
+from collections.abc import Sequence
+
+from jsonargparse import CLI
+
+from examples.mujoco.mujoco_env import MujocoEnvFactory
+from tianshou.config import (
+    BasicExperimentConfig,
+    LoggerConfig,
+    RLSamplingConfig,
+)
+from tianshou.highlevel.agent import SACAgentFactory
+from tianshou.highlevel.experiment import RLExperiment
+from tianshou.highlevel.logger import DefaultLoggerFactory
+from tianshou.highlevel.module import (
+    ContinuousActorProbFactory,
+    ContinuousNetCriticFactory,
+)
+from tianshou.highlevel.optim import AdamOptimizerFactory
+
+
+def main(
+    experiment_config: BasicExperimentConfig,
+    logger_config: LoggerConfig,
+    sampling_config: RLSamplingConfig,
+    sac_config: SACAgentFactory.Config,
+    hidden_sizes: Sequence[int] = (256, 256),
+):
+    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
+    log_name = os.path.join(experiment_config.task, "sac", str(experiment_config.seed), now)
+    logger_factory = DefaultLoggerFactory(logger_config)
+
+    env_factory = MujocoEnvFactory(experiment_config, sampling_config)
+
+    actor_factory = ContinuousActorProbFactory(hidden_sizes, conditioned_sigma=True)
+    critic_factory = ContinuousNetCriticFactory(hidden_sizes)
+    optim_factory = AdamOptimizerFactory()
+    agent_factory = SACAgentFactory(
+        sac_config,
+        sampling_config,
+        actor_factory,
+        critic_factory,
+        critic_factory,
+        optim_factory,
+    )
+
+    experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory)
+
+    experiment.run(log_name)
+
+
+if __name__ == "__main__":
+    CLI(main)
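Usage note (editorial, not part of the diff): because the script hands main to jsonargparse's CLI, the nested configs should be settable from the command line with dotted options such as --sac_config.tau 0.01 or --experiment_config.task HalfCheetah-v4, assuming jsonargparse's usual dataclass handling. The same wiring can also be driven programmatically; a hedged sketch, assuming the config dataclasses ship with usable defaults:

# Hedged sketch: drive the new high-level SAC pipeline without the CLI layer.
# Assumes BasicExperimentConfig, LoggerConfig and RLSamplingConfig defaults suffice.
from tianshou.config import BasicExperimentConfig, LoggerConfig, RLSamplingConfig
from tianshou.highlevel.agent import SACAgentFactory

from examples.mujoco.mujoco_sac_hl import main

main(
    experiment_config=BasicExperimentConfig(task="HalfCheetah-v4"),
    logger_config=LoggerConfig(),
    sampling_config=RLSamplingConfig(),
    sac_config=SACAgentFactory.Config(actor_lr=1e-3, critic1_lr=1e-3, critic2_lr=1e-3),
    hidden_sizes=(256, 256),
)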
@ -63,7 +63,9 @@ envpool = ["envpool"]
|
|||||||
optional = true
|
optional = true
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
black = "^23.7.0"
|
black = "^23.7.0"
|
||||||
|
docstring-parser = "^0.15"
|
||||||
jinja2 = "*"
|
jinja2 = "*"
|
||||||
|
jsonargparse = "^4.24.1"
|
||||||
mypy = "^1.4.1"
|
mypy = "^1.4.1"
|
||||||
# networkx is used in a test
|
# networkx is used in a test
|
||||||
networkx = "*"
|
networkx = "*"
|
||||||
@ -83,8 +85,6 @@ sphinx_rtd_theme = "*"
|
|||||||
sphinxcontrib-bibtex = "*"
|
sphinxcontrib-bibtex = "*"
|
||||||
sphinxcontrib-spelling = "^8.0.0"
|
sphinxcontrib-spelling = "^8.0.0"
|
||||||
wandb = "^0.12.0"
|
wandb = "^0.12.0"
|
||||||
jsonargparse = "^4.24.1"
|
|
||||||
docstring-parser = "^0.15"
|
|
||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
allow_redefinition = true
|
allow_redefinition = true
|
||||||
|
|||||||
@ -1 +1,10 @@
|
|||||||
from .config import *
|
__all__ = ["PGConfig", "PPOConfig", "RLAgentConfig", "RLSamplingConfig", "BasicExperimentConfig", "LoggerConfig"]
|
||||||
|
|
||||||
|
from .config import (
|
||||||
|
BasicExperimentConfig,
|
||||||
|
PGConfig,
|
||||||
|
PPOConfig,
|
||||||
|
RLAgentConfig,
|
||||||
|
RLSamplingConfig,
|
||||||
|
LoggerConfig,
|
||||||
|
)
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Literal, Optional, Sequence
|
from typing import Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from jsonargparse import set_docstring_parse_options
|
from jsonargparse import set_docstring_parse_options
|
||||||
@ -14,12 +14,12 @@ class BasicExperimentConfig:
|
|||||||
seed: int = 42
|
seed: int = 42
|
||||||
task: str = "Ant-v4"
|
task: str = "Ant-v4"
|
||||||
"""Mujoco specific"""
|
"""Mujoco specific"""
|
||||||
render: Optional[float] = 0.0
|
render: float | None = 0.0
|
||||||
"""Milliseconds between rendered frames; if None, no rendering"""
|
"""Milliseconds between rendered frames; if None, no rendering"""
|
||||||
device: str = "cuda" if torch.cuda.is_available() else "cpu"
|
device: str = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
resume_id: Optional[int] = None
|
resume_id: str | None = None
|
||||||
"""For restoring a model and running means of env-specifics from a checkpoint"""
|
"""For restoring a model and running means of env-specifics from a checkpoint"""
|
||||||
resume_path: str = None
|
resume_path: str | None = None
|
||||||
"""For restoring a model and running means of env-specifics from a checkpoint"""
|
"""For restoring a model and running means of env-specifics from a checkpoint"""
|
||||||
watch: bool = False
|
watch: bool = False
|
||||||
"""If True, will not perform training and only watch the restored policy"""
|
"""If True, will not perform training and only watch the restored policy"""
|
||||||
@ -28,7 +28,7 @@ class BasicExperimentConfig:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LoggerConfig:
|
class LoggerConfig:
|
||||||
"""Logging config"""
|
"""Logging config."""
|
||||||
|
|
||||||
logdir: str = "log"
|
logdir: str = "log"
|
||||||
logger: Literal["tensorboard", "wandb"] = "tensorboard"
|
logger: Literal["tensorboard", "wandb"] = "tensorboard"
|
||||||
@ -48,17 +48,18 @@ class RLSamplingConfig:
|
|||||||
buffer_size: int = 4096
|
buffer_size: int = 4096
|
||||||
step_per_collect: int = 2048
|
step_per_collect: int = 2048
|
||||||
repeat_per_collect: int = 10
|
repeat_per_collect: int = 10
|
||||||
|
update_per_step: int = 1
|
||||||
|
|
||||||
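Editorial note: the new update_per_step field is consumed by the OffpolicyTrainer added in tianshou/highlevel/agent.py below. Assuming Tianshou's usual semantics (gradient updates per collect is roughly update_per_step times the number of transitions collected), the defaults above imply:

# Back-of-the-envelope check of the off-policy update schedule implied by these defaults
# (assumption: standard OffpolicyTrainer update_per_step semantics).
step_per_collect = 2048
update_per_step = 1
print(round(update_per_step * step_per_collect))  # about 2048 minibatch updates per collection phase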
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RLAgentConfig:
|
class RLAgentConfig:
|
||||||
"""Config common to most RL algorithms"""
|
"""Config common to most RL algorithms."""
|
||||||
|
|
||||||
gamma: float = 0.99
|
gamma: float = 0.99
|
||||||
"""Discount factor"""
|
"""Discount factor"""
|
||||||
gae_lambda: float = 0.95
|
gae_lambda: float = 0.95
|
||||||
"""For Generalized Advantage Estimate (equivalent to TD(lambda))"""
|
"""For Generalized Advantage Estimate (equivalent to TD(lambda))"""
|
||||||
action_bound_method: Optional[Literal["clip", "tanh"]] = "clip"
|
action_bound_method: Literal["clip", "tanh"] | None = "clip"
|
||||||
"""How to map original actions in range (-inf, inf) to [-1, 1]"""
|
"""How to map original actions in range (-inf, inf) to [-1, 1]"""
|
||||||
rew_norm: bool = True
|
rew_norm: bool = True
|
||||||
"""Whether to normalize rewards"""
|
"""Whether to normalize rewards"""
|
||||||
@ -66,7 +67,7 @@ class RLAgentConfig:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PGConfig:
|
class PGConfig:
|
||||||
"""Config of general policy-gradient algorithms"""
|
"""Config of general policy-gradient algorithms."""
|
||||||
|
|
||||||
ent_coef: float = 0.0
|
ent_coef: float = 0.0
|
||||||
vf_coef: float = 0.25
|
vf_coef: float = 0.25
|
||||||
@ -75,18 +76,11 @@ class PGConfig:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PPOConfig:
|
class PPOConfig:
|
||||||
"""PPO specific config"""
|
"""PPO specific config."""
|
||||||
|
|
||||||
value_clip: bool = False
|
value_clip: bool = False
|
||||||
norm_adv: bool = False
|
norm_adv: bool = False
|
||||||
"""Whether to normalize advantages"""
|
"""Whether to normalize advantages"""
|
||||||
eps_clip: float = 0.2
|
eps_clip: float = 0.2
|
||||||
dual_clip: Optional[float] = None
|
dual_clip: float | None = None
|
||||||
recompute_adv: bool = True
|
recompute_adv: bool = True
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class NNConfig:
|
|
||||||
hidden_sizes: Sequence[int] = (64, 64)
|
|
||||||
lr: float = 3e-4
|
|
||||||
lr_decay: bool = True
|
|
||||||
|
|||||||
@ -2,11 +2,10 @@ from dataclasses import asdict, is_dataclass
|
|||||||
|
|
||||||
|
|
||||||
def collect_configs(*confs):
|
def collect_configs(*confs):
|
||||||
"""
|
"""Collect instances of dataclasses to a single dict mapping the classname to the values.
|
||||||
Collect instances of dataclasses to a single dict mapping the
|
|
||||||
classname to the values. If any of the passed objects is not a
|
If any of the passed objects is not a dataclass or if two instances
|
||||||
dataclass or if two instances of the same config class are passed,
|
of the same config class are passed, an error will be raised.
|
||||||
an error will be raised.
|
|
||||||
|
|
||||||
:param confs: dataclasses
|
:param confs: dataclasses
|
||||||
:return: Dictionary mapping class names to their instances.
|
:return: Dictionary mapping class names to their instances.
|
||||||
|
|||||||
@ -1,33 +1,51 @@
|
|||||||
import os
|
import os
|
||||||
from abc import abstractmethod, ABC
|
from abc import ABC, abstractmethod
|
||||||
from typing import Callable
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tianshou.config import RLSamplingConfig, PGConfig, PPOConfig, RLAgentConfig, NNConfig
|
from tianshou.config import PGConfig, PPOConfig, RLAgentConfig, RLSamplingConfig
|
||||||
from tianshou.data import VectorReplayBuffer, ReplayBuffer, Collector
|
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
||||||
|
from tianshou.exploration import BaseNoise
|
||||||
from tianshou.highlevel.env import Environments
|
from tianshou.highlevel.env import Environments
|
||||||
from tianshou.highlevel.logger import Logger
|
from tianshou.highlevel.logger import Logger
|
||||||
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
|
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
|
||||||
from tianshou.highlevel.optim import OptimizerFactory, LRSchedulerFactory
|
from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory
|
||||||
from tianshou.policy import BasePolicy, PPOPolicy
|
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy
|
||||||
from tianshou.trainer import BaseTrainer, OnpolicyTrainer
|
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
|
||||||
from tianshou.utils.net.common import ActorCritic
|
from tianshou.utils.net.common import ActorCritic
|
||||||
|
|
||||||
|
|
||||||
CHECKPOINT_DICT_KEY_MODEL = "model"
|
CHECKPOINT_DICT_KEY_MODEL = "model"
|
||||||
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
|
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
|
||||||
|
|
||||||
|
|
||||||
class AgentFactory(ABC):
|
class AgentFactory(ABC):
|
||||||
+    def __init__(self, sampling_config: RLSamplingConfig):
+        self.sampling_config = sampling_config
+
+    def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
+        buffer_size = self.sampling_config.buffer_size
+        train_envs = envs.train_envs
+        if len(train_envs) > 1:
+            buffer = VectorReplayBuffer(buffer_size, len(train_envs))
+        else:
+            buffer = ReplayBuffer(buffer_size)
+        train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
+        test_collector = Collector(policy, envs.test_envs)
+        return train_collector, test_collector
+
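Editorial note: this collector setup, previously duplicated in the example scripts, now lives in the shared base factory; the buffer type simply follows the number of training environments. A small sketch of that choice (sizes are illustrative):

# Mirrors the buffer selection in create_train_test_collector above; sizes are illustrative.
from tianshou.data import ReplayBuffer, VectorReplayBuffer

def make_buffer(buffer_size: int, num_train_envs: int):
    if num_train_envs > 1:
        # one vectorized buffer shared across the parallel training envs
        return VectorReplayBuffer(buffer_size, num_train_envs)
    return ReplayBuffer(buffer_size)

buf = make_buffer(4096, 8)  # VectorReplayBuffer split into 8 sub-buffers of 512 transitions each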
@abstractmethod
|
@abstractmethod
|
||||||
def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
|
def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
|
||||||
def save_best_fn(pol: torch.nn.Module):
|
def save_best_fn(pol: torch.nn.Module) -> None:
|
||||||
state = {"model": pol.state_dict(), "obs_rms": envs.train_envs.get_obs_rms()}
|
state = {
|
||||||
|
CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
|
||||||
|
CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
|
||||||
|
}
|
||||||
torch.save(state, os.path.join(log_path, "policy.pth"))
|
torch.save(state, os.path.join(log_path, "policy.pth"))
|
||||||
|
|
||||||
return save_best_fn
|
return save_best_fn
|
||||||
@ -43,36 +61,26 @@ class AgentFactory(ABC):
|
|||||||
print("Loaded agent and obs. running means from: ", path) # TODO logging
|
print("Loaded agent and obs. running means from: ", path) # TODO logging
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_train_test_collector(self,
|
def create_trainer(
|
||||||
policy: BasePolicy,
|
self,
|
||||||
envs: Environments):
|
policy: BasePolicy,
|
||||||
pass
|
train_collector: Collector,
|
||||||
|
test_collector: Collector,
|
||||||
@abstractmethod
|
envs: Environments,
|
||||||
def create_trainer(self, policy: BasePolicy, train_collector: Collector, test_collector: Collector,
|
logger: Logger,
|
||||||
envs: Environments, logger: Logger) -> BaseTrainer:
|
) -> BaseTrainer:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class OnpolicyAgentFactory(AgentFactory, ABC):
|
class OnpolicyAgentFactory(AgentFactory, ABC):
|
||||||
def __init__(self, sampling_config: RLSamplingConfig):
|
def create_trainer(
|
||||||
self.sampling_config = sampling_config
|
self,
|
||||||
|
policy: BasePolicy,
|
||||||
def create_train_test_collector(self,
|
train_collector: Collector,
|
||||||
policy: BasePolicy,
|
test_collector: Collector,
|
||||||
envs: Environments):
|
envs: Environments,
|
||||||
buffer_size = self.sampling_config.buffer_size
|
logger: Logger,
|
||||||
train_envs = envs.train_envs
|
) -> OnpolicyTrainer:
|
||||||
if len(train_envs) > 1:
|
|
||||||
buffer = VectorReplayBuffer(buffer_size, len(train_envs))
|
|
||||||
else:
|
|
||||||
buffer = ReplayBuffer(buffer_size)
|
|
||||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
|
||||||
test_collector = Collector(policy, envs.test_envs)
|
|
||||||
return train_collector, test_collector
|
|
||||||
|
|
||||||
def create_trainer(self, policy: BasePolicy, train_collector: Collector, test_collector: Collector,
|
|
||||||
envs: Environments, logger: Logger) -> OnpolicyTrainer:
|
|
||||||
sampling_config = self.sampling_config
|
sampling_config = self.sampling_config
|
||||||
return OnpolicyTrainer(
|
return OnpolicyTrainer(
|
||||||
policy=policy,
|
policy=policy,
|
||||||
@ -90,17 +98,46 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
|
|||||||
)
|
)
|
||||||
+
+
+class OffpolicyAgentFactory(AgentFactory, ABC):
+    def create_trainer(
+        self,
+        policy: BasePolicy,
+        train_collector: Collector,
+        test_collector: Collector,
+        envs: Environments,
+        logger: Logger,
+    ) -> OffpolicyTrainer:
+        sampling_config = self.sampling_config
+        return OffpolicyTrainer(
+            policy=policy,
+            train_collector=train_collector,
+            test_collector=test_collector,
+            max_epoch=sampling_config.num_epochs,
+            step_per_epoch=sampling_config.step_per_epoch,
+            step_per_collect=sampling_config.step_per_collect,
+            episode_per_test=sampling_config.num_test_envs,
+            batch_size=sampling_config.batch_size,
+            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
+            logger=logger.logger,
+            update_per_step=sampling_config.update_per_step,
+            test_in_train=False,
+        )
+
+
class PPOAgentFactory(OnpolicyAgentFactory):
|
class PPOAgentFactory(OnpolicyAgentFactory):
|
||||||
def __init__(self, general_config: RLAgentConfig,
|
def __init__(
|
||||||
pg_config: PGConfig,
|
self,
|
||||||
ppo_config: PPOConfig,
|
general_config: RLAgentConfig,
|
||||||
sampling_config: RLSamplingConfig,
|
pg_config: PGConfig,
|
||||||
nn_config: NNConfig,
|
ppo_config: PPOConfig,
|
||||||
actor_factory: ActorFactory,
|
sampling_config: RLSamplingConfig,
|
||||||
critic_factory: CriticFactory,
|
actor_factory: ActorFactory,
|
||||||
optimizer_factory: OptimizerFactory,
|
critic_factory: CriticFactory,
|
||||||
dist_fn,
|
optimizer_factory: OptimizerFactory,
|
||||||
lr_scheduler_factory: LRSchedulerFactory):
|
dist_fn,
|
||||||
|
lr: float,
|
||||||
|
lr_scheduler_factory: LRSchedulerFactory | None = None,
|
||||||
|
):
|
||||||
super().__init__(sampling_config)
|
super().__init__(sampling_config)
|
||||||
self.optimizer_factory = optimizer_factory
|
self.optimizer_factory = optimizer_factory
|
||||||
self.critic_factory = critic_factory
|
self.critic_factory = critic_factory
|
||||||
@ -108,16 +145,19 @@ class PPOAgentFactory(OnpolicyAgentFactory):
|
|||||||
self.ppo_config = ppo_config
|
self.ppo_config = ppo_config
|
||||||
self.pg_config = pg_config
|
self.pg_config = pg_config
|
||||||
self.general_config = general_config
|
self.general_config = general_config
|
||||||
|
self.lr = lr
|
||||||
self.lr_scheduler_factory = lr_scheduler_factory
|
self.lr_scheduler_factory = lr_scheduler_factory
|
||||||
self.dist_fn = dist_fn
|
self.dist_fn = dist_fn
|
||||||
self.nn_config = nn_config
|
|
||||||
|
|
||||||
def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
|
def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
|
||||||
actor = self.actor_factory.create_module(envs, device)
|
actor = self.actor_factory.create_module(envs, device)
|
||||||
critic = self.critic_factory.create_module(envs, device)
|
critic = self.critic_factory.create_module(envs, device, use_action=False)
|
||||||
actor_critic = ActorCritic(actor, critic)
|
actor_critic = ActorCritic(actor, critic)
|
||||||
optim = self.optimizer_factory.create_optimizer(actor_critic)
|
optim = self.optimizer_factory.create_optimizer(actor_critic, self.lr)
|
||||||
lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim)
|
if self.lr_scheduler_factory is not None:
|
||||||
|
lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim)
|
||||||
|
else:
|
||||||
|
lr_scheduler = None
|
||||||
return PPOPolicy(
|
return PPOPolicy(
|
||||||
# nn-stuff
|
# nn-stuff
|
||||||
actor,
|
actor,
|
||||||
@ -144,3 +184,60 @@ class PPOAgentFactory(OnpolicyAgentFactory):
|
|||||||
advantage_normalization=self.ppo_config.norm_adv,
|
advantage_normalization=self.ppo_config.norm_adv,
|
||||||
recompute_advantage=self.ppo_config.recompute_adv,
|
recompute_advantage=self.ppo_config.recompute_adv,
|
||||||
)
|
)
|
||||||
+
+
+class SACAgentFactory(OffpolicyAgentFactory):
+    def __init__(
+        self,
+        config: "SACAgentFactory.Config",
+        sampling_config: RLSamplingConfig,
+        actor_factory: ActorFactory,
+        critic1_factory: CriticFactory,
+        critic2_factory: CriticFactory,
+        optim_factory: OptimizerFactory,
+        exploration_noise: BaseNoise | None = None,
+    ):
+        super().__init__(sampling_config)
+        self.critic2_factory = critic2_factory
+        self.critic1_factory = critic1_factory
+        self.actor_factory = actor_factory
+        self.exploration_noise = exploration_noise
+        self.optim_factory = optim_factory
+        self.config = config
+
+    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
+        actor = self.actor_factory.create_module(envs, device)
+        critic1 = self.critic1_factory.create_module(envs, device, use_action=True)
+        critic2 = self.critic2_factory.create_module(envs, device, use_action=True)
+        actor_optim = self.optim_factory.create_optimizer(actor, lr=self.config.actor_lr)
+        critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.config.critic1_lr)
+        critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.config.critic2_lr)
+        return SACPolicy(
+            actor,
+            actor_optim,
+            critic1,
+            critic1_optim,
+            critic2,
+            critic2_optim,
+            tau=self.config.tau,
+            gamma=self.config.gamma,
+            alpha=self.config.alpha,
+            estimation_step=self.config.estimation_step,
+            action_space=envs.get_action_space(),
+            deterministic_eval=self.config.deterministic_eval,
+            exploration_noise=self.exploration_noise,
+        )
+
+    @dataclass
+    class Config:
+        """SAC configuration."""
+
+        tau: float = 0.005
+        gamma: float = 0.99
+        alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2
+        reward_normalization: bool = False
+        estimation_step: int = 1
+        deterministic_eval: bool = True
+        actor_lr: float = 1e-3
+        critic1_lr: float = 1e-3
+        critic2_lr: float = 1e-3
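Editorial note: the annotation of Config.alpha allows automatic entropy tuning in addition to a fixed coefficient; Tianshou's SACPolicy accepts a (target_entropy, log_alpha, alpha_optim) tuple for that purpose. A hedged sketch of building such a tuple (the action dimension and the 3e-4 learning rate are illustrative, not taken from this commit):

# Hedged sketch: automatic entropy-coefficient tuning, matching the
# float | tuple[float, torch.Tensor, torch.optim.Optimizer] annotation of Config.alpha.
import numpy as np
import torch

action_shape = (6,)  # illustrative; the high-level API would take this from envs.get_action_shape()
target_entropy = -float(np.prod(action_shape))
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
auto_alpha = (target_entropy, log_alpha, alpha_optim)  # e.g. SACAgentFactory.Config(alpha=auto_alpha)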
|
|||||||
@ -1,24 +1,22 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Tuple, Optional, Dict, Any, Union, Sequence
|
from collections.abc import Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
from tianshou.env import BaseVectorEnv
|
from tianshou.env import BaseVectorEnv
|
||||||
|
|
||||||
TShape = Union[int, Sequence[int]]
|
TShape = int | Sequence[int]
|
||||||
|
|
||||||
|
|
||||||
class Environments(ABC):
|
class Environments(ABC):
|
||||||
def __init__(self, env: Optional[gym.Env], train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||||
self.env = env
|
self.env = env
|
||||||
self.train_envs = train_envs
|
self.train_envs = train_envs
|
||||||
self.test_envs = test_envs
|
self.test_envs = test_envs
|
||||||
|
|
||||||
def info(self) -> Dict[str, Any]:
|
def info(self) -> dict[str, Any]:
|
||||||
return {
|
return {"action_shape": self.get_action_shape(), "state_shape": self.get_state_shape()}
|
||||||
"action_shape": self.get_action_shape(),
|
|
||||||
"state_shape": self.get_state_shape()
|
|
||||||
}
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_action_shape(self) -> TShape:
|
def get_action_shape(self) -> TShape:
|
||||||
@ -33,7 +31,7 @@ class Environments(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class ContinuousEnvironments(Environments):
|
class ContinuousEnvironments(Environments):
|
||||||
def __init__(self, env: Optional[gym.Env], train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||||
super().__init__(env, train_envs, test_envs)
|
super().__init__(env, train_envs, test_envs)
|
||||||
self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
|
self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
|
||||||
|
|
||||||
@ -44,12 +42,12 @@ class ContinuousEnvironments(Environments):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_continuous_env_info(
|
def _get_continuous_env_info(
|
||||||
env: gym.Env,
|
env: gym.Env,
|
||||||
) -> Tuple[Tuple[int, ...], Tuple[int, ...], float]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], float]:
|
||||||
if not isinstance(env.action_space, gym.spaces.Box):
|
if not isinstance(env.action_space, gym.spaces.Box):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Only environments with continuous action space are supported here. "
|
"Only environments with continuous action space are supported here. "
|
||||||
f"But got env with action space: {env.action_space.__class__}."
|
f"But got env with action space: {env.action_space.__class__}.",
|
||||||
)
|
)
|
||||||
state_shape = env.observation_space.shape or env.observation_space.n
|
state_shape = env.observation_space.shape or env.observation_space.n
|
||||||
if not state_shape:
|
if not state_shape:
|
||||||
|
|||||||
@ -4,7 +4,9 @@ from typing import Generic, TypeVar
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tianshou.config import BasicExperimentConfig, LoggerConfig, RLAgentConfig, RLSamplingConfig
|
from tianshou.config import (
|
||||||
|
BasicExperimentConfig,
|
||||||
|
)
|
||||||
from tianshou.data import Collector
|
from tianshou.data import Collector
|
||||||
from tianshou.highlevel.agent import AgentFactory
|
from tianshou.highlevel.agent import AgentFactory
|
||||||
from tianshou.highlevel.env import EnvFactory
|
from tianshou.highlevel.env import EnvFactory
|
||||||
@ -17,23 +19,19 @@ TTrainer = TypeVar("TTrainer", bound=BaseTrainer)
|
|||||||
|
|
||||||
|
|
||||||
class RLExperiment(Generic[TPolicy, TTrainer]):
|
class RLExperiment(Generic[TPolicy, TTrainer]):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
config: BasicExperimentConfig,
|
self,
|
||||||
logger_config: LoggerConfig,
|
config: BasicExperimentConfig,
|
||||||
general_config: RLAgentConfig,
|
env_factory: EnvFactory,
|
||||||
sampling_config: RLSamplingConfig,
|
logger_factory: LoggerFactory,
|
||||||
env_factory: EnvFactory,
|
agent_factory: AgentFactory,
|
||||||
logger_factory: LoggerFactory,
|
):
|
||||||
agent_factory: AgentFactory):
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.logger_config = logger_config
|
|
||||||
self.general_config = general_config
|
|
||||||
self.sampling_config = sampling_config
|
|
||||||
self.env_factory = env_factory
|
self.env_factory = env_factory
|
||||||
self.logger_factory = logger_factory
|
self.logger_factory = logger_factory
|
||||||
self.agent_factory = agent_factory
|
self.agent_factory = agent_factory
|
||||||
|
|
||||||
def _set_seed(self):
|
def _set_seed(self) -> None:
|
||||||
seed = self.config.seed
|
seed = self.config.seed
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
@ -43,7 +41,7 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
|
|||||||
# TODO
|
# TODO
|
||||||
}
|
}
|
||||||
|
|
||||||
def run(self, log_name: str):
|
def run(self, log_name: str) -> None:
|
||||||
self._set_seed()
|
self._set_seed()
|
||||||
|
|
||||||
envs = self.env_factory.create_envs()
|
envs = self.env_factory.create_envs()
|
||||||
@ -52,25 +50,47 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
|
|||||||
full_config.update(envs.info())
|
full_config.update(envs.info())
|
||||||
|
|
||||||
run_id = self.config.resume_id
|
run_id = self.config.resume_id
|
||||||
logger = self.logger_factory.create_logger(log_name=log_name, run_id=run_id, config_dict=full_config)
|
logger = self.logger_factory.create_logger(
|
||||||
|
log_name=log_name,
|
||||||
|
run_id=run_id,
|
||||||
|
config_dict=full_config,
|
||||||
|
)
|
||||||
|
|
||||||
policy = self.agent_factory.create_policy(envs, self.config.device)
|
policy = self.agent_factory.create_policy(envs, self.config.device)
|
||||||
if self.config.resume_path:
|
if self.config.resume_path:
|
||||||
self.agent_factory.load_checkpoint(policy, self.config.resume_path, envs, self.config.device)
|
self.agent_factory.load_checkpoint(
|
||||||
|
policy,
|
||||||
|
self.config.resume_path,
|
||||||
|
envs,
|
||||||
|
self.config.device,
|
||||||
|
)
|
||||||
|
|
||||||
train_collector, test_collector = self.agent_factory.create_train_test_collector(policy, envs)
|
train_collector, test_collector = self.agent_factory.create_train_test_collector(
|
||||||
|
policy,
|
||||||
|
envs,
|
||||||
|
)
|
||||||
|
|
||||||
if not self.config.watch:
|
if not self.config.watch:
|
||||||
trainer = self.agent_factory.create_trainer(policy, train_collector, test_collector, envs, logger)
|
trainer = self.agent_factory.create_trainer(
|
||||||
|
policy,
|
||||||
|
train_collector,
|
||||||
|
test_collector,
|
||||||
|
envs,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
result = trainer.run()
|
result = trainer.run()
|
||||||
pprint(result) # TODO logging
|
pprint(result) # TODO logging
|
||||||
|
|
||||||
self._watch_agent(self.config.watch_num_episodes, policy, test_collector, self.config.render)
|
self._watch_agent(
|
||||||
|
self.config.watch_num_episodes,
|
||||||
|
policy,
|
||||||
|
test_collector,
|
||||||
|
self.config.render,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render):
|
def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render) -> None:
|
||||||
policy.eval()
|
policy.eval()
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
result = test_collector.collect(n_episode=num_episodes, render=render)
|
result = test_collector.collect(n_episode=num_episodes, render=render)
|
||||||
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
|
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
|
||||||
|
|
||||||
|
|||||||
@ -1,15 +1,13 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
import os
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Union, Optional
|
|
||||||
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.config import LoggerConfig
|
from tianshou.config import LoggerConfig
|
||||||
from tianshou.utils import TensorboardLogger, WandbLogger
|
from tianshou.utils import TensorboardLogger, WandbLogger
|
||||||
|
|
||||||
|
TLogger = TensorboardLogger | WandbLogger
|
||||||
TLogger = Union[TensorboardLogger, WandbLogger]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -20,7 +18,7 @@ class Logger:
|
|||||||
|
|
||||||
class LoggerFactory(ABC):
|
class LoggerFactory(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_logger(self, log_name: str, run_id: Optional[int], config_dict: dict) -> Logger:
|
def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -28,7 +26,7 @@ class DefaultLoggerFactory(LoggerFactory):
|
|||||||
def __init__(self, config: LoggerConfig):
|
def __init__(self, config: LoggerConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
def create_logger(self, log_name: str, run_id: Optional[int], config_dict: dict) -> Logger:
|
def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
|
||||||
writer = SummaryWriter(self.config.logdir)
|
writer = SummaryWriter(self.config.logdir)
|
||||||
writer.add_text("args", str(self.config))
|
writer.add_text("args", str(self.config))
|
||||||
if self.config.logger == "wandb":
|
if self.config.logger == "wandb":
|
||||||
|
|||||||
@ -1,24 +1,24 @@
|
|||||||
from abc import abstractmethod, ABC
|
from abc import ABC, abstractmethod
|
||||||
from typing import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from tianshou.highlevel.env import Environments
|
from tianshou.highlevel.env import Environments
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.utils.net.continuous import ActorProb, Critic as ContinuousCritic
|
from tianshou.utils.net.continuous import ActorProb
|
||||||
|
from tianshou.utils.net.continuous import Critic as ContinuousCritic
|
||||||
|
|
||||||
TDevice = str | int | torch.device
|
TDevice = str | int | torch.device
|
||||||
|
|
||||||
|
|
||||||
def init_linear_orthogonal(m: torch.nn.Module):
|
def init_linear_orthogonal(module: torch.nn.Module):
|
||||||
"""
|
"""Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.
|
||||||
Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0
|
|
||||||
|
|
||||||
:param m: the module whose submodules are to be processed
|
:param module: the module whose submodules are to be processed
|
||||||
"""
|
"""
|
||||||
for m in m.modules():
|
for m in module.modules():
|
||||||
if isinstance(m, torch.nn.Linear):
|
if isinstance(m, torch.nn.Linear):
|
||||||
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
|
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
|
||||||
torch.nn.init.zeros_(m.bias)
|
torch.nn.init.zeros_(m.bias)
|
||||||
@ -31,9 +31,9 @@ class ActorFactory(ABC):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _init_linear(actor: torch.nn.Module):
|
def _init_linear(actor: torch.nn.Module):
|
||||||
"""
|
"""Initializes linear layers of an actor module using default mechanisms.
|
||||||
Initializes linear layers of an actor module using default mechanisms
|
|
||||||
:param module: the actor module
|
:param module: the actor module.
|
||||||
"""
|
"""
|
||||||
init_linear_orthogonal(actor)
|
init_linear_orthogonal(actor)
|
||||||
if hasattr(actor, "mu"):
|
if hasattr(actor, "mu"):
|
||||||
@ -51,17 +51,29 @@ class ContinuousActorFactory(ActorFactory, ABC):
|
|||||||
|
|
||||||
|
|
||||||
class ContinuousActorProbFactory(ContinuousActorFactory):
|
class ContinuousActorProbFactory(ContinuousActorFactory):
|
||||||
def __init__(self, hidden_sizes: Sequence[int]):
|
def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
|
||||||
self.hidden_sizes = hidden_sizes
|
self.hidden_sizes = hidden_sizes
|
||||||
|
self.unbounded = unbounded
|
||||||
|
self.conditioned_sigma = conditioned_sigma
|
||||||
|
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
|
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
|
||||||
net_a = Net(
|
net_a = Net(
|
||||||
envs.get_state_shape(), hidden_sizes=self.hidden_sizes, activation=nn.Tanh, device=device
|
envs.get_state_shape(),
|
||||||
|
hidden_sizes=self.hidden_sizes,
|
||||||
|
activation=nn.Tanh,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
actor = ActorProb(net_a, envs.get_action_shape(), unbounded=True, device=device).to(device)
|
actor = ActorProb(
|
||||||
|
net_a,
|
||||||
|
envs.get_action_shape(),
|
||||||
|
unbounded=True,
|
||||||
|
device=device,
|
||||||
|
conditioned_sigma=self.conditioned_sigma,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
# init params
|
# init params
|
||||||
torch.nn.init.constant_(actor.sigma_param, -0.5)
|
if not self.conditioned_sigma:
|
||||||
|
torch.nn.init.constant_(actor.sigma_param, -0.5)
|
||||||
self._init_linear(actor)
|
self._init_linear(actor)
|
||||||
|
|
||||||
return actor
|
return actor
|
||||||
@ -69,7 +81,7 @@ class ContinuousActorProbFactory(ContinuousActorFactory):
|
|||||||
|
|
||||||
class CriticFactory(ABC):
|
class CriticFactory(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
|
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -78,12 +90,19 @@ class ContinuousCriticFactory(CriticFactory, ABC):
|
|||||||
|
|
||||||
|
|
||||||
class ContinuousNetCriticFactory(ContinuousCriticFactory):
|
class ContinuousNetCriticFactory(ContinuousCriticFactory):
|
||||||
def __init__(self, hidden_sizes: Sequence[int]):
|
def __init__(self, hidden_sizes: Sequence[int], action_shape=0):
|
||||||
|
self.action_shape = action_shape
|
||||||
self.hidden_sizes = hidden_sizes
|
self.hidden_sizes = hidden_sizes
|
||||||
|
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
|
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
|
||||||
|
action_shape = envs.get_action_shape() if use_action else 0
|
||||||
net_c = Net(
|
net_c = Net(
|
||||||
envs.get_state_shape(), hidden_sizes=self.hidden_sizes, activation=nn.Tanh, device=device
|
envs.get_state_shape(),
|
||||||
|
action_shape=action_shape,
|
||||||
|
hidden_sizes=self.hidden_sizes,
|
||||||
|
concat=use_action,
|
||||||
|
activation=nn.Tanh,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
critic = ContinuousCritic(net_c, device=device).to(device)
|
critic = ContinuousCritic(net_c, device=device).to(device)
|
||||||
init_linear_orthogonal(critic)
|
init_linear_orthogonal(critic)
|
||||||
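Editorial note: the new use_action flag is what lets one critic factory serve both algorithm families: PPO builds its value critic with use_action=False, while SAC builds its two Q-critics with use_action=True so the action is concatenated to the observation. A sketch of what the factory constructs in each case (shapes and hidden sizes are illustrative):

# Mirrors create_module above for the two settings of use_action.
from torch import nn
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Critic

state_shape, action_shape = (17,), (6,)

# use_action=False: V(s) critic, as requested by PPOAgentFactory
net_v = Net(state_shape, hidden_sizes=(64, 64), activation=nn.Tanh)
v_critic = Critic(net_v)

# use_action=True: Q(s, a) critic, as requested twice by SACAgentFactory
net_q = Net(state_shape, action_shape=action_shape, hidden_sizes=(256, 256), concat=True, activation=nn.Tanh)
q_critic = Critic(net_q)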
|
|||||||
@ -1,54 +1,51 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Union, Iterable, Dict, Any, Optional
|
from collections.abc import Iterable
|
||||||
|
from typing import Any, Type
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
from torch.optim.lr_scheduler import LRScheduler, LambdaLR
|
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||||
|
|
||||||
from tianshou.config import RLSamplingConfig, NNConfig
|
from tianshou.config import RLSamplingConfig
|
||||||
|
|
||||||
TParams = Union[Iterable[Tensor], Iterable[Dict[str, Any]]]
|
TParams = Iterable[Tensor] | Iterable[dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
class OptimizerFactory(ABC):
|
class OptimizerFactory(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_optimizer(self, module: torch.nn.Module) -> torch.optim.Optimizer:
|
def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TorchOptimizerFactory(OptimizerFactory):
|
class TorchOptimizerFactory(OptimizerFactory):
|
||||||
def __init__(self, optim_class, **kwargs):
|
def __init__(self, optim_class: Any, **kwargs):
|
||||||
self.optim_class = optim_class
|
self.optim_class = optim_class
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
||||||
def create_optimizer(self, module: torch.nn.Module) -> torch.optim.Optimizer:
|
def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
|
||||||
return self.optim_class(module.parameters(), **self.kwargs)
|
return self.optim_class(module.parameters(), lr=lr, **self.kwargs)
|
||||||
|
|
||||||
|
|
||||||
class AdamOptimizerFactory(OptimizerFactory):
|
class AdamOptimizerFactory(OptimizerFactory):
|
||||||
def __init__(self, lr):
|
def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
|
||||||
self.lr = lr
|
return Adam(module.parameters(), lr=lr)
|
||||||
|
|
||||||
def create_optimizer(self, module: torch.nn.Module) -> Adam:
|
|
||||||
return Adam(module.parameters(), lr=self.lr)
|
|
||||||
|
|
||||||
|
|
||||||
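Editorial note: the learning rate moves from the factory constructor to create_optimizer, so a single AdamOptimizerFactory can hand out differently tuned optimizers for the actor and each critic, which is exactly how SACAgentFactory uses it. A short before/after sketch with stand-in modules:

# Before: one factory per learning rate, e.g. AdamOptimizerFactory(lr=3e-4).create_optimizer(module)
# After:  one factory, the lr is chosen per module at creation time.
import torch
from tianshou.highlevel.optim import AdamOptimizerFactory

factory = AdamOptimizerFactory()
actor_stub = torch.nn.Linear(4, 2)
critic_stub = torch.nn.Linear(4, 1)
actor_optim = factory.create_optimizer(actor_stub, lr=1e-3)
critic_optim = factory.create_optimizer(critic_stub, lr=3e-4)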
class LRSchedulerFactory(ABC):
|
class LRSchedulerFactory(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_scheduler(self, optim: torch.optim.Optimizer) -> Optional[LRScheduler]:
|
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LinearLRSchedulerFactory(LRSchedulerFactory):
|
class LinearLRSchedulerFactory(LRSchedulerFactory):
|
||||||
def __init__(self, nn_config: NNConfig, sampling_config: RLSamplingConfig):
|
def __init__(self, sampling_config: RLSamplingConfig):
|
||||||
self.nn_config = nn_config
|
|
||||||
self.sampling_config = sampling_config
|
self.sampling_config = sampling_config
|
||||||
|
|
||||||
def create_scheduler(self, optim: torch.optim.Optimizer) -> Optional[LRScheduler]:
|
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
|
||||||
lr_scheduler = None
|
max_update_num = (
|
||||||
if self.nn_config.lr_decay:
|
np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect)
|
||||||
max_update_num = np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect) * self.sampling_config.num_epochs
|
* self.sampling_config.num_epochs
|
||||||
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
|
)
|
||||||
return lr_scheduler
|
return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
|
||||||
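Editorial note: the factory now always returns a scheduler (the lr_decay switch moved to the caller) that decays the learning rate linearly to zero over the expected number of policy updates. A worked sketch of the arithmetic; step_per_epoch and num_epochs are assumed values, only step_per_collect appears in this diff:

# Worked example of the decay factor computed by LinearLRSchedulerFactory.
import numpy as np

step_per_epoch = 30000   # assumption, not shown in this diff
step_per_collect = 2048  # RLSamplingConfig default shown earlier
num_epochs = 100         # assumption, not shown in this diff

max_update_num = np.ceil(step_per_epoch / step_per_collect) * num_epochs  # 15 * 100 = 1500

def factor(update_index: int) -> float:
    return 1 - update_index / max_update_num

print(factor(0), factor(750), factor(1500))  # 1.0 0.5 0.0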
|
|||||||