Add SAC high-level interface
This commit is contained in:
parent
2a1cc6bb55
commit
316eb3c579
@ -2,9 +2,9 @@ import warnings
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from tianshou.config import RLSamplingConfig, BasicExperimentConfig
|
||||
from tianshou.config import BasicExperimentConfig, RLSamplingConfig
|
||||
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
||||
from tianshou.highlevel.env import EnvFactory, Environments, ContinuousEnvironments
|
||||
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
|
||||
|
||||
try:
|
||||
import envpool
|
||||
@ -12,9 +12,7 @@ except ImportError:
|
||||
envpool = None
|
||||
|
||||
|
||||
def make_mujoco_env(
|
||||
task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool
|
||||
):
|
||||
def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool):
|
||||
"""Wrapper function for Mujoco env.
|
||||
|
||||
If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env.
|
||||
|
||||
@ -1,359 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import pprint
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, Optional, Tuple, Union
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from jsonargparse import CLI
|
||||
from torch import nn
|
||||
from torch.distributions import Independent, Normal
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from mujoco_env import make_mujoco_env
|
||||
from tianshou.config import (
|
||||
BasicExperimentConfig,
|
||||
LoggerConfig,
|
||||
NNConfig,
|
||||
PGConfig,
|
||||
PPOConfig,
|
||||
RLAgentConfig,
|
||||
RLSamplingConfig,
|
||||
)
|
||||
from tianshou.config.utils import collect_configs
|
||||
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
||||
from tianshou.env import VectorEnvNormObs
|
||||
from tianshou.policy import BasePolicy, PPOPolicy
|
||||
from tianshou.trainer import OnpolicyTrainer
|
||||
from tianshou.utils import TensorboardLogger, WandbLogger
|
||||
from tianshou.utils.net.common import ActorCritic, Net
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
|
||||
|
||||
def set_seed(seed=42):
    """Seed the global NumPy and PyTorch random number generators.

    :param seed: the value handed to ``np.random.seed`` and ``torch.manual_seed``.
    """
    # NumPy is seeded first, then torch.
    for seed_fn in (np.random.seed, torch.manual_seed):
        seed_fn(seed)
|
||||
|
||||
|
||||
def get_logger_for_run(
    algo_name: str,
    task: str,
    logger_config: LoggerConfig,
    config: dict,
    seed: int,
    resume_id: Optional[Union[str, int]],
) -> Tuple[str, Union[WandbLogger, TensorboardLogger]]:
    """Return the log path and the logger for a single run.

    The run name is built from ``task``, ``algo_name``, ``seed`` and the current
    timestamp (``%y%m%d-%H%M%S``), joined with the OS path separator. The log path is
    that run name below ``logger_config.logdir``.

    :param algo_name: name of the algorithm, used as a component of the run name.
    :param task: name of the task, used as the first component of the run name.
    :param logger_config: provides ``logdir``, the logger kind (``logger``) and
        ``wandb_project``.
    :param config: the experiment config
    :param seed: seed of the run, used as a component of the run name.
    :param resume_id: used as run_id by wandb, unused for tensorboard
    :return: the tuple ``(log_path, logger)``.
    """
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    log_name = os.path.join(task, algo_name, str(seed), now)
    log_path = os.path.join(logger_config.logdir, log_name)

    logger = get_logger(
        logger_config.logger,
        log_path,
        log_name=log_name,
        run_id=resume_id,
        config=config,
        wandb_project=logger_config.wandb_project,
    )
    return log_path, logger
|
||||
|
||||
|
||||
def get_continuous_env_info(
    env: gym.Env,
) -> Tuple[Tuple[int, ...], Tuple[int, ...], float]:
    """Extract state shape, action shape and maximum action from a continuous-action env.

    :param env: the environment to inspect; its action space must be a ``gym.spaces.Box``.
    :return: the tuple ``(state_shape, action_shape, max_action)`` where ``max_action``
        is the first entry of the action space's ``high`` bound.
    :raises ValueError: if the action space is not a ``Box``, or if the observation
        space provides neither a non-empty ``shape`` nor a non-zero ``n``.
    """
    if not isinstance(env.action_space, gym.spaces.Box):
        raise ValueError(
            "Only environments with continuous action space are supported here. "
            f"But got env with action space: {env.action_space.__class__}.",
        )
    obs_space = env.observation_space
    # Fall back to ``n`` only if the space has it; otherwise a space with an empty shape
    # would fail with AttributeError instead of the descriptive ValueError below.
    state_shape = obs_space.shape or getattr(obs_space, "n", None)
    if not state_shape:
        raise ValueError("Observation space shape is not defined")
    action_shape = env.action_space.shape
    # NOTE(review): only the first component of ``high`` is used -- presumably all action
    # dimensions share the same bound; confirm for the environments in use.
    max_action = env.action_space.high[0]
    return state_shape, action_shape, max_action
|
||||
|
||||
|
||||
def resume_from_checkpoint(
    path: str,
    policy: BasePolicy,
    train_envs: VectorEnvNormObs | None = None,
    test_envs: VectorEnvNormObs | None = None,
    device: str | int | torch.device | None = None,
):
    """Restore policy weights and observation running statistics from a checkpoint file.

    The checkpoint must be a mapping with a ``"model"`` entry; an ``"obs_rms"`` entry is
    read only when at least one of the environments is passed.

    :param path: location of the checkpoint file.
    :param policy: the policy whose state dict is replaced by ``ckpt["model"]``.
    :param train_envs: if given, receives ``ckpt["obs_rms"]`` via ``set_obs_rms``.
    :param test_envs: if given, receives ``ckpt["obs_rms"]`` via ``set_obs_rms``.
    :param device: passed to ``torch.load`` as ``map_location``.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
    ckpt = torch.load(path, map_location=device)
    policy.load_state_dict(ckpt["model"])
    # The envs are optional arguments, so test for None explicitly instead of relying on
    # the truthiness of the environment objects.
    if train_envs is not None:
        train_envs.set_obs_rms(ckpt["obs_rms"])
    if test_envs is not None:
        test_envs.set_obs_rms(ckpt["obs_rms"])
    print("Loaded agent and obs. running means from: ", path)
|
||||
|
||||
|
||||
def watch_agent(n_episode, policy: BasePolicy, test_collector: Collector, render=0.0):
    """Evaluate the policy with the test collector and print mean reward and length.

    :param n_episode: number of episodes to collect.
    :param policy: the policy, switched to evaluation mode before collecting.
    :param test_collector: the collector, reset before collecting.
    :param render: forwarded to ``Collector.collect`` as ``render``.
    """
    policy.eval()
    test_collector.reset()
    stats = test_collector.collect(n_episode=n_episode, render=render)
    mean_reward = stats["rews"].mean()
    mean_length = stats["lens"].mean()
    print(f"Final reward: {mean_reward}, length: {mean_length}")
|
||||
|
||||
|
||||
def get_train_test_collector(
    buffer_size: int,
    policy: BasePolicy,
    train_envs: VectorEnvNormObs,
    test_envs: VectorEnvNormObs,
):
    """Create the collectors for training and testing.

    The training collector gets a replay buffer of size ``buffer_size`` (a
    ``VectorReplayBuffer`` when there is more than one training env, otherwise a plain
    ``ReplayBuffer``) and is created with exploration noise enabled.

    :param buffer_size: total size of the replay buffer.
    :param policy: the policy used by both collectors.
    :param train_envs: environments for the training collector.
    :param test_envs: environments for the test collector.
    :return: the tuple ``(test_collector, train_collector)`` -- note the order.
    """
    num_train_envs = len(train_envs)
    replay_buffer = (
        VectorReplayBuffer(buffer_size, num_train_envs)
        if num_train_envs > 1
        else ReplayBuffer(buffer_size)
    )
    collector_for_training = Collector(policy, train_envs, replay_buffer, exploration_noise=True)
    collector_for_testing = Collector(policy, test_envs)
    return collector_for_testing, collector_for_training
|
||||
|
||||
|
||||
# A shape is either a single dimension or a sequence of dimensions.
TShape = Union[int, Sequence[int]]


def get_actor_critic(
    state_shape: TShape,
    hidden_sizes: Sequence[int],
    action_shape: TShape,
    device: str | int | torch.device = "cpu",
):
    """Build the actor and critic, each on its own ``Net`` backbone with Tanh activations.

    :param state_shape: input shape of both backbones.
    :param hidden_sizes: hidden layer sizes of both backbones.
    :param action_shape: shape of the actor output.
    :param device: device passed to every module and to which actor and critic are moved.
    :return: the tuple ``(actor, critic)``.
    """

    def build_backbone():
        return Net(state_shape, hidden_sizes=hidden_sizes, activation=nn.Tanh, device=device)

    actor_backbone = build_backbone()
    actor = ActorProb(actor_backbone, action_shape, unbounded=True, device=device).to(device)
    critic_backbone = build_backbone()
    # TODO: twice device?
    critic = Critic(critic_backbone, device=device).to(device)
    return actor, critic
|
||||
|
||||
|
||||
def get_logger(
    kind: Literal["wandb", "tensorboard"],
    log_path: str,
    log_name="",
    run_id: Optional[Union[str, int]] = None,
    config: Optional[Union[dict, argparse.Namespace]] = None,
    wandb_project: Optional[str] = None,
):
    """Create a tensorboard or wandb logger that writes to ``log_path``.

    A ``SummaryWriter`` is created for ``log_path`` and the stringified ``config`` is
    written to it under the tag ``"args"``. For ``"wandb"``, the writer is then loaded
    into a ``WandbLogger``; for ``"tensorboard"``, it is wrapped in a ``TensorboardLogger``.

    :param kind: which logger to create.
    :param log_path: directory handed to the ``SummaryWriter``.
    :param log_name: run name for wandb; path separators are replaced by ``"__"``.
    :param run_id: forwarded to ``WandbLogger`` as ``run_id``; unused for tensorboard.
    :param config: logged as text and, for wandb, forwarded as ``config``.
    :param wandb_project: forwarded to ``WandbLogger`` as ``project``; unused for tensorboard.
    :return: the created logger.
    :raises ValueError: if ``kind`` is neither ``"wandb"`` nor ``"tensorboard"``.
    """
    # Validate first so that an unknown kind does not create a writer (and its log
    # directory) before failing.
    if kind not in ("wandb", "tensorboard"):
        raise ValueError(f"Unknown logger: {kind}")

    writer = SummaryWriter(log_path)
    writer.add_text("args", str(config))
    if kind == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=run_id,
            config=config,
            project=wandb_project,
        )
        logger.load(writer)
    else:
        logger = TensorboardLogger(writer)
    return logger
|
||||
|
||||
|
||||
def get_lr_scheduler(optim, step_per_epoch: int, step_per_collect: int, epochs: int):
    """Decay learning rate to 0 linearly.

    The learning rate factor is ``1 - k / max_update_num`` for the ``k``-th scheduler
    step, where ``max_update_num = ceil(step_per_epoch / step_per_collect) * epochs``.

    :param optim: the optimizer whose learning rate is scheduled.
    :param step_per_epoch: number of environment steps per epoch.
    :param step_per_collect: number of environment steps per collect.
    :param epochs: number of epochs.
    :return: the ``LambdaLR`` scheduler.
    """
    updates_per_epoch = np.ceil(step_per_epoch / step_per_collect)
    max_update_num = updates_per_epoch * epochs

    def linear_decay(update_idx):
        return 1 - update_idx / max_update_num

    return LambdaLR(optim, lr_lambda=linear_decay)
|
||||
|
||||
|
||||
def init_and_get_optim(actor: nn.Module, critic: nn.Module, lr: float):
    """Initialize the layers of actor and critic and create a joint Adam optimizer.

    :param actor: the actor module; must expose ``sigma_param``. If it has a ``mu``
        attribute, the linear layers below ``mu`` are additionally rescaled.
    :param critic: the critic module.
    :param lr: learning rate of the optimizer.
    :return: an Adam optimizer over the parameters of the combined actor-critic.
    """
    joint_model = ActorCritic(actor, critic)
    torch.nn.init.constant_(actor.sigma_param, -0.5)

    # orthogonal initialization of every linear layer, biases set to zero
    linear_layers = [mod for mod in joint_model.modules() if isinstance(mod, torch.nn.Linear)]
    for layer in linear_layers:
        torch.nn.init.orthogonal_(layer.weight, gain=np.sqrt(2))
        torch.nn.init.zeros_(layer.bias)

    if hasattr(actor, "mu"):
        # For continuous action spaces with Gaussian policies: scale down the last
        # policy layer so that initial actions are close to zero, which helps boost
        # performance, see https://arxiv.org/abs/2006.05990, Fig.24 for details
        for layer in actor.mu.modules():
            if not isinstance(layer, torch.nn.Linear):
                continue
            # TODO: seems like biases are initialized twice for the actor
            torch.nn.init.zeros_(layer.bias)
            layer.weight.data.mul_(0.01)

    return torch.optim.Adam(joint_model.parameters(), lr=lr)
|
||||
|
||||
|
||||
def main(
    experiment_config: BasicExperimentConfig,
    logger_config: LoggerConfig,
    sampling_config: RLSamplingConfig,
    general_config: RLAgentConfig,
    pg_config: PGConfig,
    ppo_config: PPOConfig,
    nn_config: NNConfig,
):
    """Run the PPO test on the provided parameters.

    Creates the environments, logger, networks, optimizer and policy, optionally
    restores a checkpoint, trains with an ``OnpolicyTrainer`` unless
    ``experiment_config.watch`` is set, and finally evaluates the policy.

    :param experiment_config: BasicExperimentConfig - not ML or RL specific
    :param logger_config: LoggerConfig
    :param sampling_config: RLSamplingConfig -
        sampling, epochs, parallelization, buffers, collectors, and batching.
    :param general_config: RLAgentConfig - general RL agent config
    :param pg_config: PGConfig: common to most policy gradient algorithms
    :param ppo_config: PPOConfig - PPO specific config
    :param nn_config: NNConfig - NN-training specific config

    :return: None
    """
    # Must remain the first statement: here locals() holds exactly the config arguments.
    full_config = collect_configs(*locals().values())
    set_seed(experiment_config.seed)

    # create test and train envs, add env info to config
    env, train_envs, test_envs = make_mujoco_env(
        task=experiment_config.task,
        seed=experiment_config.seed,
        num_train_envs=sampling_config.num_train_envs,
        num_test_envs=sampling_config.num_test_envs,
        obs_norm=True,
    )

    # adding env_info to logged config
    state_shape, action_shape, max_action = get_continuous_env_info(env)
    full_config["env_info"] = {
        "state_shape": state_shape,
        "action_shape": action_shape,
        "max_action": max_action,
    }
    log_path, logger = get_logger_for_run(
        "ppo",
        experiment_config.task,
        logger_config,
        full_config,
        experiment_config.seed,
        experiment_config.resume_id,
    )

    # Setup NNs
    actor, critic = get_actor_critic(
        state_shape,
        nn_config.hidden_sizes,
        action_shape,
        experiment_config.device,
    )
    optim = init_and_get_optim(actor, critic, nn_config.lr)

    lr_scheduler = None
    if nn_config.lr_decay:
        lr_scheduler = get_lr_scheduler(
            optim,
            sampling_config.step_per_epoch,
            sampling_config.step_per_collect,
            sampling_config.num_epochs,
        )

    # Create policy
    def dist_fn(*logits):
        return Independent(Normal(*logits), 1)

    policy = PPOPolicy(
        # nn-stuff
        actor,
        critic,
        optim,
        dist_fn=dist_fn,
        lr_scheduler=lr_scheduler,
        # env-stuff
        action_space=train_envs.action_space,
        action_scaling=True,
        # general_config
        discount_factor=general_config.gamma,
        gae_lambda=general_config.gae_lambda,
        reward_normalization=general_config.rew_norm,
        action_bound_method=general_config.action_bound_method,
        # pg_config
        max_grad_norm=pg_config.max_grad_norm,
        vf_coef=pg_config.vf_coef,
        ent_coef=pg_config.ent_coef,
        # ppo_config
        eps_clip=ppo_config.eps_clip,
        value_clip=ppo_config.value_clip,
        dual_clip=ppo_config.dual_clip,
        advantage_normalization=ppo_config.norm_adv,
        recompute_advantage=ppo_config.recompute_adv,
    )

    if experiment_config.resume_path:
        resume_from_checkpoint(
            experiment_config.resume_path,
            policy,
            train_envs=train_envs,
            test_envs=test_envs,
            device=experiment_config.device,
        )

    # The envs are passed by keyword: the helper's signature is
    # (buffer_size, policy, train_envs, test_envs), and passing them positionally in the
    # opposite order built the training collector and its buffer on the test envs.
    # The helper returns the collectors in the order (test, train).
    test_collector, train_collector = get_train_test_collector(
        sampling_config.buffer_size,
        policy,
        train_envs=train_envs,
        test_envs=test_envs,
    )

    # TODO: test num is the number of test envs but used as episode_per_test
    #  here and in watch_agent
    if not experiment_config.watch:
        # RL training
        def save_best_fn(pol: nn.Module):
            state = {"model": pol.state_dict(), "obs_rms": train_envs.get_obs_rms()}
            torch.save(state, os.path.join(log_path, "policy.pth"))

        trainer = OnpolicyTrainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=sampling_config.num_epochs,
            step_per_epoch=sampling_config.step_per_epoch,
            repeat_per_collect=sampling_config.repeat_per_collect,
            episode_per_test=sampling_config.num_test_envs,
            batch_size=sampling_config.batch_size,
            step_per_collect=sampling_config.step_per_collect,
            save_best_fn=save_best_fn,
            logger=logger,
            test_in_train=False,
        )
        result = trainer.run()
        pprint.pprint(result)

    # The final evaluation runs in both modes (after training or watch-only).
    watch_agent(
        sampling_config.num_test_envs,
        policy,
        test_collector,
        render=experiment_config.render,
    )


if __name__ == "__main__":
    CLI(main)
|
||||
@ -2,6 +2,8 @@
|
||||
|
||||
import datetime
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
from jsonargparse import CLI
|
||||
from torch.distributions import Independent, Normal
|
||||
@ -10,17 +12,26 @@ from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||
from tianshou.config import (
|
||||
BasicExperimentConfig,
|
||||
LoggerConfig,
|
||||
NNConfig,
|
||||
PGConfig,
|
||||
PPOConfig,
|
||||
RLAgentConfig,
|
||||
RLSamplingConfig,
|
||||
)
|
||||
from tianshou.highlevel.agent import PPOAgentFactory
|
||||
from tianshou.highlevel.logger import DefaultLoggerFactory
|
||||
from tianshou.highlevel.module import ContinuousActorProbFactory, ContinuousNetCriticFactory
|
||||
from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory
|
||||
from tianshou.highlevel.experiment import RLExperiment
|
||||
from tianshou.highlevel.logger import DefaultLoggerFactory
|
||||
from tianshou.highlevel.module import (
|
||||
ContinuousActorProbFactory,
|
||||
ContinuousNetCriticFactory,
|
||||
)
|
||||
from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory
|
||||
|
||||
|
||||
@dataclass
class NNConfig:
    """Network and optimizer settings of the example."""

    # sizes of the hidden layers handed to the actor and critic factories
    hidden_sizes: Sequence[int] = (64, 64)
    # learning rate
    lr: float = 3e-4
    # whether a learning rate scheduler is created
    lr_decay: bool = True
|
||||
|
||||
|
||||
def main(
|
||||
@ -43,15 +54,22 @@ def main(
|
||||
|
||||
actor_factory = ContinuousActorProbFactory(nn_config.hidden_sizes)
|
||||
critic_factory = ContinuousNetCriticFactory(nn_config.hidden_sizes)
|
||||
optim_factory = AdamOptimizerFactory(lr=nn_config.lr)
|
||||
lr_scheduler_factory = LinearLRSchedulerFactory(nn_config, sampling_config)
|
||||
agent_factory = PPOAgentFactory(general_config, pg_config, ppo_config, sampling_config, nn_config,
|
||||
actor_factory, critic_factory, optim_factory, dist_fn, lr_scheduler_factory)
|
||||
optim_factory = AdamOptimizerFactory()
|
||||
lr_scheduler_factory = LinearLRSchedulerFactory(sampling_config) if nn_config.lr_decay else None
|
||||
agent_factory = PPOAgentFactory(
|
||||
general_config,
|
||||
pg_config,
|
||||
ppo_config,
|
||||
sampling_config,
|
||||
actor_factory,
|
||||
critic_factory,
|
||||
optim_factory,
|
||||
dist_fn,
|
||||
nn_config.lr,
|
||||
lr_scheduler_factory,
|
||||
)
|
||||
|
||||
experiment = RLExperiment(experiment_config, logger_config, general_config, sampling_config,
|
||||
env_factory,
|
||||
logger_factory,
|
||||
agent_factory)
|
||||
experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory)
|
||||
|
||||
experiment.run(log_name)
|
||||
|
||||
|
||||
56
examples/mujoco/mujoco_sac_hl.py
Normal file
56
examples/mujoco/mujoco_sac_hl.py
Normal file
@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import datetime
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
|
||||
from jsonargparse import CLI
|
||||
|
||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||
from tianshou.config import (
|
||||
BasicExperimentConfig,
|
||||
LoggerConfig,
|
||||
RLSamplingConfig,
|
||||
)
|
||||
from tianshou.highlevel.agent import SACAgentFactory
|
||||
from tianshou.highlevel.experiment import RLExperiment
|
||||
from tianshou.highlevel.logger import DefaultLoggerFactory
|
||||
from tianshou.highlevel.module import (
|
||||
ContinuousActorProbFactory,
|
||||
ContinuousNetCriticFactory,
|
||||
)
|
||||
from tianshou.highlevel.optim import AdamOptimizerFactory
|
||||
|
||||
|
||||
def main(
    experiment_config: BasicExperimentConfig,
    logger_config: LoggerConfig,
    sampling_config: RLSamplingConfig,
    sac_config: SACAgentFactory.Config,
    hidden_sizes: Sequence[int] = (256, 256),
):
    """Assemble a SAC experiment from the high-level factories and run it.

    :param experiment_config: general experiment settings (task, seed, ...).
    :param logger_config: settings for the logger factory.
    :param sampling_config: sampling settings shared by the env and agent factories.
    :param sac_config: SAC-specific settings.
    :param hidden_sizes: hidden layer sizes of the actor and both critics.
    """
    timestamp = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    run_name_parts = (experiment_config.task, "sac", str(experiment_config.seed), timestamp)
    log_name = os.path.join(*run_name_parts)

    logger_factory = DefaultLoggerFactory(logger_config)
    env_factory = MujocoEnvFactory(experiment_config, sampling_config)

    actor_factory = ContinuousActorProbFactory(hidden_sizes, conditioned_sigma=True)
    # The same factory instance is used for both critics.
    critic_factory = ContinuousNetCriticFactory(hidden_sizes)
    optim_factory = AdamOptimizerFactory()

    agent_factory = SACAgentFactory(
        config=sac_config,
        sampling_config=sampling_config,
        actor_factory=actor_factory,
        critic1_factory=critic_factory,
        critic2_factory=critic_factory,
        optim_factory=optim_factory,
    )

    experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory)
    experiment.run(log_name)


if __name__ == "__main__":
    CLI(main)
|
||||
@ -63,7 +63,9 @@ envpool = ["envpool"]
|
||||
optional = true
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^23.7.0"
|
||||
docstring-parser = "^0.15"
|
||||
jinja2 = "*"
|
||||
jsonargparse = "^4.24.1"
|
||||
mypy = "^1.4.1"
|
||||
# networkx is used in a test
|
||||
networkx = "*"
|
||||
@ -83,8 +85,6 @@ sphinx_rtd_theme = "*"
|
||||
sphinxcontrib-bibtex = "*"
|
||||
sphinxcontrib-spelling = "^8.0.0"
|
||||
wandb = "^0.12.0"
|
||||
jsonargparse = "^4.24.1"
|
||||
docstring-parser = "^0.15"
|
||||
|
||||
[tool.mypy]
|
||||
allow_redefinition = true
|
||||
|
||||
@ -1 +1,10 @@
|
||||
from .config import *
|
||||
__all__ = ["PGConfig", "PPOConfig", "RLAgentConfig", "RLSamplingConfig", "BasicExperimentConfig", "LoggerConfig"]
|
||||
|
||||
from .config import (
|
||||
BasicExperimentConfig,
|
||||
PGConfig,
|
||||
PPOConfig,
|
||||
RLAgentConfig,
|
||||
RLSamplingConfig,
|
||||
LoggerConfig,
|
||||
)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional, Sequence
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from jsonargparse import set_docstring_parse_options
|
||||
@ -14,12 +14,12 @@ class BasicExperimentConfig:
|
||||
seed: int = 42
|
||||
task: str = "Ant-v4"
|
||||
"""Mujoco specific"""
|
||||
render: Optional[float] = 0.0
|
||||
render: float | None = 0.0
|
||||
"""Milliseconds between rendered frames; if None, no rendering"""
|
||||
device: str = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
resume_id: Optional[int] = None
|
||||
resume_id: str | None = None
|
||||
"""For restoring a model and running means of env-specifics from a checkpoint"""
|
||||
resume_path: str = None
|
||||
resume_path: str | None = None
|
||||
"""For restoring a model and running means of env-specifics from a checkpoint"""
|
||||
watch: bool = False
|
||||
"""If True, will not perform training and only watch the restored policy"""
|
||||
@ -28,7 +28,7 @@ class BasicExperimentConfig:
|
||||
|
||||
@dataclass
|
||||
class LoggerConfig:
|
||||
"""Logging config"""
|
||||
"""Logging config."""
|
||||
|
||||
logdir: str = "log"
|
||||
logger: Literal["tensorboard", "wandb"] = "tensorboard"
|
||||
@ -48,17 +48,18 @@ class RLSamplingConfig:
|
||||
buffer_size: int = 4096
|
||||
step_per_collect: int = 2048
|
||||
repeat_per_collect: int = 10
|
||||
update_per_step: int = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLAgentConfig:
|
||||
"""Config common to most RL algorithms"""
|
||||
"""Config common to most RL algorithms."""
|
||||
|
||||
gamma: float = 0.99
|
||||
"""Discount factor"""
|
||||
gae_lambda: float = 0.95
|
||||
"""For Generalized Advantage Estimate (equivalent to TD(lambda))"""
|
||||
action_bound_method: Optional[Literal["clip", "tanh"]] = "clip"
|
||||
action_bound_method: Literal["clip", "tanh"] | None = "clip"
|
||||
"""How to map original actions in range (-inf, inf) to [-1, 1]"""
|
||||
rew_norm: bool = True
|
||||
"""Whether to normalize rewards"""
|
||||
@ -66,7 +67,7 @@ class RLAgentConfig:
|
||||
|
||||
@dataclass
|
||||
class PGConfig:
|
||||
"""Config of general policy-gradient algorithms"""
|
||||
"""Config of general policy-gradient algorithms."""
|
||||
|
||||
ent_coef: float = 0.0
|
||||
vf_coef: float = 0.25
|
||||
@ -75,18 +76,11 @@ class PGConfig:
|
||||
|
||||
@dataclass
|
||||
class PPOConfig:
|
||||
"""PPO specific config"""
|
||||
"""PPO specific config."""
|
||||
|
||||
value_clip: bool = False
|
||||
norm_adv: bool = False
|
||||
"""Whether to normalize advantages"""
|
||||
eps_clip: float = 0.2
|
||||
dual_clip: Optional[float] = None
|
||||
dual_clip: float | None = None
|
||||
recompute_adv: bool = True
|
||||
|
||||
|
||||
@dataclass
class NNConfig:
    """Network and optimizer settings."""

    # sizes of the hidden layers
    hidden_sizes: Sequence[int] = (64, 64)
    # learning rate
    lr: float = 3e-4
    # whether the learning rate is decayed
    lr_decay: bool = True
|
||||
|
||||
@ -2,11 +2,10 @@ from dataclasses import asdict, is_dataclass
|
||||
|
||||
|
||||
def collect_configs(*confs):
|
||||
"""
|
||||
Collect instances of dataclasses to a single dict mapping the
|
||||
classname to the values. If any of the passed objects is not a
|
||||
dataclass or if two instances of the same config class are passed,
|
||||
an error will be raised.
|
||||
"""Collect instances of dataclasses to a single dict mapping the classname to the values.
|
||||
|
||||
If any of the passed objects is not a dataclass or if two instances
|
||||
of the same config class are passed, an error will be raised.
|
||||
|
||||
:param confs: dataclasses
|
||||
:return: Dictionary mapping class names to their instances.
|
||||
|
||||
@ -1,33 +1,51 @@
|
||||
import os
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Callable
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from tianshou.config import RLSamplingConfig, PGConfig, PPOConfig, RLAgentConfig, NNConfig
|
||||
from tianshou.data import VectorReplayBuffer, ReplayBuffer, Collector
|
||||
from tianshou.config import PGConfig, PPOConfig, RLAgentConfig, RLSamplingConfig
|
||||
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
||||
from tianshou.exploration import BaseNoise
|
||||
from tianshou.highlevel.env import Environments
|
||||
from tianshou.highlevel.logger import Logger
|
||||
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
|
||||
from tianshou.highlevel.optim import OptimizerFactory, LRSchedulerFactory
|
||||
from tianshou.policy import BasePolicy, PPOPolicy
|
||||
from tianshou.trainer import BaseTrainer, OnpolicyTrainer
|
||||
from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory
|
||||
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy
|
||||
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
|
||||
from tianshou.utils.net.common import ActorCritic
|
||||
|
||||
|
||||
CHECKPOINT_DICT_KEY_MODEL = "model"
|
||||
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
|
||||
|
||||
|
||||
class AgentFactory(ABC):
|
||||
def __init__(self, sampling_config: RLSamplingConfig):
|
||||
self.sampling_config = sampling_config
|
||||
|
||||
def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
|
||||
buffer_size = self.sampling_config.buffer_size
|
||||
train_envs = envs.train_envs
|
||||
if len(train_envs) > 1:
|
||||
buffer = VectorReplayBuffer(buffer_size, len(train_envs))
|
||||
else:
|
||||
buffer = ReplayBuffer(buffer_size)
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, envs.test_envs)
|
||||
return train_collector, test_collector
|
||||
|
||||
@abstractmethod
|
||||
def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
|
||||
def save_best_fn(pol: torch.nn.Module):
|
||||
state = {"model": pol.state_dict(), "obs_rms": envs.train_envs.get_obs_rms()}
|
||||
def save_best_fn(pol: torch.nn.Module) -> None:
|
||||
state = {
|
||||
CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
|
||||
CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
|
||||
}
|
||||
torch.save(state, os.path.join(log_path, "policy.pth"))
|
||||
|
||||
return save_best_fn
|
||||
@ -43,36 +61,26 @@ class AgentFactory(ABC):
|
||||
print("Loaded agent and obs. running means from: ", path) # TODO logging
|
||||
|
||||
@abstractmethod
|
||||
def create_train_test_collector(self,
|
||||
policy: BasePolicy,
|
||||
envs: Environments):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_trainer(self, policy: BasePolicy, train_collector: Collector, test_collector: Collector,
|
||||
envs: Environments, logger: Logger) -> BaseTrainer:
|
||||
def create_trainer(
|
||||
self,
|
||||
policy: BasePolicy,
|
||||
train_collector: Collector,
|
||||
test_collector: Collector,
|
||||
envs: Environments,
|
||||
logger: Logger,
|
||||
) -> BaseTrainer:
|
||||
pass
|
||||
|
||||
|
||||
class OnpolicyAgentFactory(AgentFactory, ABC):
|
||||
def __init__(self, sampling_config: RLSamplingConfig):
|
||||
self.sampling_config = sampling_config
|
||||
|
||||
def create_train_test_collector(self,
|
||||
policy: BasePolicy,
|
||||
envs: Environments):
|
||||
buffer_size = self.sampling_config.buffer_size
|
||||
train_envs = envs.train_envs
|
||||
if len(train_envs) > 1:
|
||||
buffer = VectorReplayBuffer(buffer_size, len(train_envs))
|
||||
else:
|
||||
buffer = ReplayBuffer(buffer_size)
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, envs.test_envs)
|
||||
return train_collector, test_collector
|
||||
|
||||
def create_trainer(self, policy: BasePolicy, train_collector: Collector, test_collector: Collector,
|
||||
envs: Environments, logger: Logger) -> OnpolicyTrainer:
|
||||
def create_trainer(
|
||||
self,
|
||||
policy: BasePolicy,
|
||||
train_collector: Collector,
|
||||
test_collector: Collector,
|
||||
envs: Environments,
|
||||
logger: Logger,
|
||||
) -> OnpolicyTrainer:
|
||||
sampling_config = self.sampling_config
|
||||
return OnpolicyTrainer(
|
||||
policy=policy,
|
||||
@ -90,17 +98,46 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
|
||||
)
|
||||
|
||||
|
||||
class OffpolicyAgentFactory(AgentFactory, ABC):
    """Base class of agent factories whose policies are trained by an ``OffpolicyTrainer``."""

    def create_trainer(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        envs: Environments,
        logger: Logger,
    ) -> OffpolicyTrainer:
        """Create the off-policy trainer from the factory's sampling config.

        :param policy: the policy to train.
        :param train_collector: collector for training data.
        :param test_collector: collector for evaluation.
        :param envs: the environments; used by the save-best callback.
        :param logger: provides the log path for the save-best callback and the
            underlying logger for the trainer.
        :return: the trainer; the number of test envs is used as ``episode_per_test``
            and ``test_in_train`` is disabled.
        """
        cfg = self.sampling_config
        store_best_policy = self._create_save_best_fn(envs, logger.log_path)
        return OffpolicyTrainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=cfg.num_epochs,
            step_per_epoch=cfg.step_per_epoch,
            step_per_collect=cfg.step_per_collect,
            episode_per_test=cfg.num_test_envs,
            batch_size=cfg.batch_size,
            save_best_fn=store_best_policy,
            logger=logger.logger,
            update_per_step=cfg.update_per_step,
            test_in_train=False,
        )
|
||||
|
||||
|
||||
class PPOAgentFactory(OnpolicyAgentFactory):
|
||||
def __init__(self, general_config: RLAgentConfig,
|
||||
pg_config: PGConfig,
|
||||
ppo_config: PPOConfig,
|
||||
sampling_config: RLSamplingConfig,
|
||||
nn_config: NNConfig,
|
||||
actor_factory: ActorFactory,
|
||||
critic_factory: CriticFactory,
|
||||
optimizer_factory: OptimizerFactory,
|
||||
dist_fn,
|
||||
lr_scheduler_factory: LRSchedulerFactory):
|
||||
def __init__(
|
||||
self,
|
||||
general_config: RLAgentConfig,
|
||||
pg_config: PGConfig,
|
||||
ppo_config: PPOConfig,
|
||||
sampling_config: RLSamplingConfig,
|
||||
actor_factory: ActorFactory,
|
||||
critic_factory: CriticFactory,
|
||||
optimizer_factory: OptimizerFactory,
|
||||
dist_fn,
|
||||
lr: float,
|
||||
lr_scheduler_factory: LRSchedulerFactory | None = None,
|
||||
):
|
||||
super().__init__(sampling_config)
|
||||
self.optimizer_factory = optimizer_factory
|
||||
self.critic_factory = critic_factory
|
||||
@ -108,16 +145,19 @@ class PPOAgentFactory(OnpolicyAgentFactory):
|
||||
self.ppo_config = ppo_config
|
||||
self.pg_config = pg_config
|
||||
self.general_config = general_config
|
||||
self.lr = lr
|
||||
self.lr_scheduler_factory = lr_scheduler_factory
|
||||
self.dist_fn = dist_fn
|
||||
self.nn_config = nn_config
|
||||
|
||||
def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
|
||||
actor = self.actor_factory.create_module(envs, device)
|
||||
critic = self.critic_factory.create_module(envs, device)
|
||||
critic = self.critic_factory.create_module(envs, device, use_action=False)
|
||||
actor_critic = ActorCritic(actor, critic)
|
||||
optim = self.optimizer_factory.create_optimizer(actor_critic)
|
||||
lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim)
|
||||
optim = self.optimizer_factory.create_optimizer(actor_critic, self.lr)
|
||||
if self.lr_scheduler_factory is not None:
|
||||
lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim)
|
||||
else:
|
||||
lr_scheduler = None
|
||||
return PPOPolicy(
|
||||
# nn-stuff
|
||||
actor,
|
||||
@ -144,3 +184,60 @@ class PPOAgentFactory(OnpolicyAgentFactory):
|
||||
advantage_normalization=self.ppo_config.norm_adv,
|
||||
recompute_advantage=self.ppo_config.recompute_adv,
|
||||
)
|
||||
|
||||
|
||||
class SACAgentFactory(OffpolicyAgentFactory):
    """Factory that creates a ``SACPolicy`` with one actor and two critics."""

    def __init__(
        self,
        config: "SACAgentFactory.Config",
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic1_factory: CriticFactory,
        critic2_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        exploration_noise: BaseNoise | None = None,
    ):
        """Store the configuration and the factories used to build the policy.

        :param config: the SAC-specific configuration.
        :param sampling_config: the sampling configuration, handed to the base class.
        :param actor_factory: creates the actor module.
        :param critic1_factory: creates the first critic module.
        :param critic2_factory: creates the second critic module.
        :param optim_factory: creates one optimizer each for the actor and the critics.
        :param exploration_noise: forwarded to ``SACPolicy``.
        """
        super().__init__(sampling_config)
        self.critic2_factory = critic2_factory
        self.critic1_factory = critic1_factory
        self.actor_factory = actor_factory
        self.exploration_noise = exploration_noise
        self.optim_factory = optim_factory
        self.config = config

    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        """Create the SAC policy for the given environments.

        :param envs: the environments; used to create the modules and to obtain the
            action space.
        :param device: the device on which the modules are created.
        :return: the ``SACPolicy`` instance.
        """
        actor = self.actor_factory.create_module(envs, device)
        # The critics are created with ``use_action=True``.
        critic1 = self.critic1_factory.create_module(envs, device, use_action=True)
        critic2 = self.critic2_factory.create_module(envs, device, use_action=True)
        # Each module gets its own optimizer with its own learning rate.
        actor_optim = self.optim_factory.create_optimizer(actor, lr=self.config.actor_lr)
        critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.config.critic1_lr)
        critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.config.critic2_lr)
        return SACPolicy(
            actor,
            actor_optim,
            critic1,
            critic1_optim,
            critic2,
            critic2_optim,
            tau=self.config.tau,
            gamma=self.config.gamma,
            alpha=self.config.alpha,
            # Forward the configured value; it was previously declared in Config but
            # never passed on, so the setting was silently ignored.
            reward_normalization=self.config.reward_normalization,
            estimation_step=self.config.estimation_step,
            action_space=envs.get_action_space(),
            deterministic_eval=self.config.deterministic_eval,
            exploration_noise=self.exploration_noise,
        )

    @dataclass
    class Config:
        """SAC configuration."""

        # The following six values are forwarded to SACPolicy under the same names.
        tau: float = 0.005
        gamma: float = 0.99
        alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2
        reward_normalization: bool = False
        estimation_step: int = 1
        deterministic_eval: bool = True
        # learning rates of the actor, first critic and second critic optimizers
        actor_lr: float = 1e-3
        critic1_lr: float = 1e-3
        critic2_lr: float = 1e-3
|
||||
|
||||
@ -1,24 +1,22 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple, Optional, Dict, Any, Union, Sequence
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from tianshou.env import BaseVectorEnv
|
||||
|
||||
TShape = Union[int, Sequence[int]]
|
||||
TShape = int | Sequence[int]
|
||||
|
||||
|
||||
class Environments(ABC):
|
||||
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
    """Keep references to the single environment and the train/test vector environments.

    :param env: a single (non-vectorized) environment instance, or None
    :param train_envs: the vectorized environments used for training
    :param test_envs: the vectorized environments used for testing
    """
    self.test_envs = test_envs
    self.train_envs = train_envs
    self.env = env
|
||||
|
||||
def info(self) -> dict[str, Any]:
    """Summarize the environments as a dictionary.

    :return: a dict with the keys "action_shape" and "state_shape"
    """
    action_shape = self.get_action_shape()
    state_shape = self.get_state_shape()
    return {"action_shape": action_shape, "state_shape": state_shape}
|
||||
|
||||
@abstractmethod
|
||||
def get_action_shape(self) -> TShape:
|
||||
@ -33,7 +31,7 @@ class Environments(ABC):
|
||||
|
||||
|
||||
class ContinuousEnvironments(Environments):
|
||||
def __init__(self, env: Optional[gym.Env], train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||
super().__init__(env, train_envs, test_envs)
|
||||
self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
|
||||
|
||||
@ -44,12 +42,12 @@ class ContinuousEnvironments(Environments):
|
||||
|
||||
@staticmethod
|
||||
def _get_continuous_env_info(
|
||||
env: gym.Env,
|
||||
) -> Tuple[Tuple[int, ...], Tuple[int, ...], float]:
|
||||
env: gym.Env,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], float]:
|
||||
if not isinstance(env.action_space, gym.spaces.Box):
|
||||
raise ValueError(
|
||||
"Only environments with continuous action space are supported here. "
|
||||
f"But got env with action space: {env.action_space.__class__}."
|
||||
f"But got env with action space: {env.action_space.__class__}.",
|
||||
)
|
||||
state_shape = env.observation_space.shape or env.observation_space.n
|
||||
if not state_shape:
|
||||
@ -68,4 +66,4 @@ class ContinuousEnvironments(Environments):
|
||||
class EnvFactory(ABC):
    """Abstract factory for the environments used in an experiment."""

    @abstractmethod
    def create_envs(self) -> Environments:
        """Create the environments.

        :return: the environments container
        """
|
||||
|
||||
@ -4,7 +4,9 @@ from typing import Generic, TypeVar
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from tianshou.config import BasicExperimentConfig, LoggerConfig, RLAgentConfig, RLSamplingConfig
|
||||
from tianshou.config import (
|
||||
BasicExperimentConfig,
|
||||
)
|
||||
from tianshou.data import Collector
|
||||
from tianshou.highlevel.agent import AgentFactory
|
||||
from tianshou.highlevel.env import EnvFactory
|
||||
@ -17,23 +19,19 @@ TTrainer = TypeVar("TTrainer", bound=BaseTrainer)
|
||||
|
||||
|
||||
class RLExperiment(Generic[TPolicy, TTrainer]):
|
||||
def __init__(
    self,
    config: BasicExperimentConfig,
    env_factory: EnvFactory,
    logger_factory: LoggerFactory,
    agent_factory: AgentFactory,
):
    """Store the experiment configuration and the factories the experiment is built from.

    :param config: the basic experiment configuration
    :param env_factory: the factory which creates the environments
    :param logger_factory: the factory which creates the logger
    :param agent_factory: the factory which creates the agent-related objects
    """
    self.agent_factory = agent_factory
    self.logger_factory = logger_factory
    self.env_factory = env_factory
    self.config = config
|
||||
|
||||
def _set_seed(self):
|
||||
def _set_seed(self) -> None:
|
||||
seed = self.config.seed
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
@ -43,7 +41,7 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
|
||||
# TODO
|
||||
}
|
||||
|
||||
def run(self, log_name: str):
|
||||
def run(self, log_name: str) -> None:
|
||||
self._set_seed()
|
||||
|
||||
envs = self.env_factory.create_envs()
|
||||
@ -52,25 +50,47 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
|
||||
full_config.update(envs.info())
|
||||
|
||||
run_id = self.config.resume_id
|
||||
logger = self.logger_factory.create_logger(log_name=log_name, run_id=run_id, config_dict=full_config)
|
||||
logger = self.logger_factory.create_logger(
|
||||
log_name=log_name,
|
||||
run_id=run_id,
|
||||
config_dict=full_config,
|
||||
)
|
||||
|
||||
policy = self.agent_factory.create_policy(envs, self.config.device)
|
||||
if self.config.resume_path:
|
||||
self.agent_factory.load_checkpoint(policy, self.config.resume_path, envs, self.config.device)
|
||||
self.agent_factory.load_checkpoint(
|
||||
policy,
|
||||
self.config.resume_path,
|
||||
envs,
|
||||
self.config.device,
|
||||
)
|
||||
|
||||
train_collector, test_collector = self.agent_factory.create_train_test_collector(policy, envs)
|
||||
train_collector, test_collector = self.agent_factory.create_train_test_collector(
|
||||
policy,
|
||||
envs,
|
||||
)
|
||||
|
||||
if not self.config.watch:
|
||||
trainer = self.agent_factory.create_trainer(policy, train_collector, test_collector, envs, logger)
|
||||
trainer = self.agent_factory.create_trainer(
|
||||
policy,
|
||||
train_collector,
|
||||
test_collector,
|
||||
envs,
|
||||
logger,
|
||||
)
|
||||
result = trainer.run()
|
||||
pprint(result) # TODO logging
|
||||
|
||||
self._watch_agent(self.config.watch_num_episodes, policy, test_collector, self.config.render)
|
||||
self._watch_agent(
|
||||
self.config.watch_num_episodes,
|
||||
policy,
|
||||
test_collector,
|
||||
self.config.render,
|
||||
)
|
||||
|
||||
@staticmethod
def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render) -> None:
    """Run the policy in evaluation mode and print the mean reward and episode length.

    :param num_episodes: the number of episodes to collect
    :param policy: the policy to evaluate; it is switched to eval mode
    :param test_collector: the collector to use; it is reset before collecting
    :param render: passed through to the collector's ``render`` argument
    """
    # Switch to evaluation mode and start from a freshly reset collector.
    policy.eval()
    test_collector.reset()

    result = test_collector.collect(n_episode=num_episodes, render=render)
    print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
|
||||
|
||||
|
||||
@ -1,15 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Union, Optional
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.config import LoggerConfig
|
||||
from tianshou.utils import TensorboardLogger, WandbLogger
|
||||
|
||||
|
||||
TLogger = Union[TensorboardLogger, WandbLogger]
|
||||
TLogger = TensorboardLogger | WandbLogger
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -20,7 +18,7 @@ class Logger:
|
||||
|
||||
class LoggerFactory(ABC):
    """Abstract factory for loggers."""

    @abstractmethod
    def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
        """Create the logger for an experiment run.

        :param log_name: the name under which the run is logged
        :param run_id: the identifier of a run, or None
        :param config_dict: a dictionary with the configuration of the run
        :return: the logger
        """
|
||||
|
||||
|
||||
@ -28,7 +26,7 @@ class DefaultLoggerFactory(LoggerFactory):
|
||||
def __init__(self, config: LoggerConfig):
    """Store the logger configuration.

    :param config: the logger configuration read when a logger is created
    """
    self.config = config
|
||||
|
||||
def create_logger(self, log_name: str, run_id: Optional[int], config_dict: dict) -> Logger:
|
||||
def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
|
||||
writer = SummaryWriter(self.config.logdir)
|
||||
writer.add_text("args", str(self.config))
|
||||
if self.config.logger == "wandb":
|
||||
|
||||
@ -1,24 +1,24 @@
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Sequence
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
|
||||
from tianshou.highlevel.env import Environments
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic as ContinuousCritic
|
||||
from tianshou.utils.net.continuous import ActorProb
|
||||
from tianshou.utils.net.continuous import Critic as ContinuousCritic
|
||||
|
||||
TDevice = str | int | torch.device
|
||||
|
||||
|
||||
def init_linear_orthogonal(module: torch.nn.Module):
    """Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.

    Linear layers created without a bias (``bias=False``) only have their weights initialized.

    :param module: the module whose submodules are to be processed
    """
    for m in module.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            # nn.Linear(bias=False) has bias None, on which zeros_ would fail.
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
|
||||
@ -31,9 +31,9 @@ class ActorFactory(ABC):
|
||||
|
||||
@staticmethod
|
||||
def _init_linear(actor: torch.nn.Module):
|
||||
"""
|
||||
Initializes linear layers of an actor module using default mechanisms
|
||||
:param module: the actor module
|
||||
"""Initializes linear layers of an actor module using default mechanisms.
|
||||
|
||||
:param module: the actor module.
|
||||
"""
|
||||
init_linear_orthogonal(actor)
|
||||
if hasattr(actor, "mu"):
|
||||
@ -51,17 +51,29 @@ class ContinuousActorFactory(ActorFactory, ABC):
|
||||
|
||||
|
||||
class ContinuousActorProbFactory(ContinuousActorFactory):
    """Factory for a probabilistic actor (:class:`ActorProb`) on top of an MLP with Tanh activations."""

    def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
        """Store the parameters used when creating the actor.

        :param hidden_sizes: the hidden layer sizes of the preprocessing network
        :param unbounded: forwarded to :class:`ActorProb` as ``unbounded``
        :param conditioned_sigma: forwarded to :class:`ActorProb` as ``conditioned_sigma``
        """
        self.hidden_sizes = hidden_sizes
        self.unbounded = unbounded
        self.conditioned_sigma = conditioned_sigma

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        """Create and initialize the actor module.

        :param envs: the environments, which determine the state and action shapes
        :param device: the device on which the module is created
        :return: the actor module
        """
        net_a = Net(
            envs.get_state_shape(),
            hidden_sizes=self.hidden_sizes,
            activation=nn.Tanh,
            device=device,
        )
        actor = ActorProb(
            net_a,
            envs.get_action_shape(),
            # Respect the constructor argument; it was previously hard-coded to True.
            unbounded=self.unbounded,
            device=device,
            conditioned_sigma=self.conditioned_sigma,
        ).to(device)

        # init params
        # sigma_param is only initialized here when sigma is not conditioned on the input.
        if not self.conditioned_sigma:
            torch.nn.init.constant_(actor.sigma_param, -0.5)
        self._init_linear(actor)

        return actor
|
||||
@ -69,7 +81,7 @@ class ContinuousActorProbFactory(ContinuousActorFactory):
|
||||
|
||||
class CriticFactory(ABC):
    """Abstract factory for critic modules."""

    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        """Create the critic module.

        :param envs: the environments for which the critic is created
        :param device: the device on which the module is created
        :param use_action: whether the critic shall take the action as an input
        :return: the critic module
        """
|
||||
|
||||
|
||||
@ -78,12 +90,19 @@ class ContinuousCriticFactory(CriticFactory, ABC):
|
||||
|
||||
|
||||
class ContinuousNetCriticFactory(ContinuousCriticFactory):
|
||||
def __init__(self, hidden_sizes: Sequence[int]):
|
||||
def __init__(self, hidden_sizes: Sequence[int], action_shape=0):
|
||||
self.action_shape = action_shape
|
||||
self.hidden_sizes = hidden_sizes
|
||||
|
||||
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
|
||||
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
|
||||
action_shape = envs.get_action_shape() if use_action else 0
|
||||
net_c = Net(
|
||||
envs.get_state_shape(), hidden_sizes=self.hidden_sizes, activation=nn.Tanh, device=device
|
||||
envs.get_state_shape(),
|
||||
action_shape=action_shape,
|
||||
hidden_sizes=self.hidden_sizes,
|
||||
concat=use_action,
|
||||
activation=nn.Tanh,
|
||||
device=device,
|
||||
)
|
||||
critic = ContinuousCritic(net_c, device=device).to(device)
|
||||
init_linear_orthogonal(critic)
|
||||
|
||||
@ -1,54 +1,51 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Union, Iterable, Dict, Any, Optional
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Type
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.optim import Adam
|
||||
from torch.optim.lr_scheduler import LRScheduler, LambdaLR
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
from tianshou.config import RLSamplingConfig, NNConfig
|
||||
from tianshou.config import RLSamplingConfig
|
||||
|
||||
TParams = Union[Iterable[Tensor], Iterable[Dict[str, Any]]]
|
||||
TParams = Iterable[Tensor] | Iterable[dict[str, Any]]
|
||||
|
||||
|
||||
class OptimizerFactory(ABC):
    """Abstract factory for optimizers of torch modules."""

    @abstractmethod
    def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
        """Create an optimizer for the given module.

        :param module: the module to create the optimizer for
        :param lr: the learning rate
        :return: the optimizer
        """
|
||||
|
||||
|
||||
class TorchOptimizerFactory(OptimizerFactory):
    """Optimizer factory which instantiates a given optimizer class with fixed keyword arguments."""

    def __init__(self, optim_class: Any, **kwargs):
        """Store the optimizer class and its extra keyword arguments.

        :param optim_class: the optimizer class (or callable); it is called with the
            module's parameters, the learning rate as ``lr`` and ``kwargs``
        :param kwargs: additional keyword arguments passed on every creation
        """
        self.kwargs = kwargs
        self.optim_class = optim_class

    def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
        """Create an optimizer for the parameters of the given module.

        :param module: the module whose parameters are to be optimized
        :param lr: the learning rate
        :return: the optimizer
        """
        params = module.parameters()
        return self.optim_class(params, lr=lr, **self.kwargs)
|
||||
|
||||
|
||||
class AdamOptimizerFactory(OptimizerFactory):
    """Optimizer factory which creates Adam optimizers."""

    def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
        """Create an Adam optimizer for the parameters of the given module.

        :param module: the module whose parameters are to be optimized
        :param lr: the learning rate
        :return: the Adam optimizer
        """
        params = module.parameters()
        return Adam(params, lr=lr)
|
||||
|
||||
|
||||
class LRSchedulerFactory(ABC):
    """Abstract factory for learning rate schedulers."""

    @abstractmethod
    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        """Create a learning rate scheduler for the given optimizer.

        :param optim: the optimizer whose learning rate is to be scheduled
        :return: the scheduler
        """
|
||||
|
||||
|
||||
class LinearLRSchedulerFactory(LRSchedulerFactory):
    """Factory for a scheduler which decays the learning rate factor linearly."""

    def __init__(self, sampling_config: RLSamplingConfig):
        """Store the sampling configuration from which the number of updates is derived.

        :param sampling_config: the sampling configuration; ``step_per_epoch``,
            ``step_per_collect`` and ``num_epochs`` are read from it
        """
        self.sampling_config = sampling_config

    def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
        """Create a ``LambdaLR`` scheduler whose factor is ``1 - step / max_update_num``.

        :param optim: the optimizer whose learning rate is to be scheduled
        :return: the scheduler
        """
        cfg = self.sampling_config
        # Number of updates per epoch (rounded up) times the number of epochs.
        updates_per_epoch = np.ceil(cfg.step_per_epoch / cfg.step_per_collect)
        max_update_num = updates_per_epoch * cfg.num_epochs

        def linear_decay(epoch):
            return 1 - epoch / max_update_num

        return LambdaLR(optim, lr_lambda=linear_decay)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user