Add SAC high-level interface

Dominik Jain 2023-09-20 09:29:34 +02:00
parent 2a1cc6bb55
commit 316eb3c579
14 changed files with 377 additions and 533 deletions
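At a glance, the new high-level interface lets an SAC experiment be assembled from a handful of factories that are handed to an RLExperiment. The following is a condensed sketch based on the example script added in this commit; it assumes the config dataclasses can be constructed with their defaults, and the log name is merely illustrative:

import datetime
import os

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.config import BasicExperimentConfig, LoggerConfig, RLSamplingConfig
from tianshou.highlevel.agent import SACAgentFactory
from tianshou.highlevel.experiment import RLExperiment
from tianshou.highlevel.logger import DefaultLoggerFactory
from tianshou.highlevel.module import ContinuousActorProbFactory, ContinuousNetCriticFactory
from tianshou.highlevel.optim import AdamOptimizerFactory

experiment_config = BasicExperimentConfig(task="Ant-v4")
sampling_config = RLSamplingConfig()
hidden_sizes = (256, 256)

env_factory = MujocoEnvFactory(experiment_config, sampling_config)
logger_factory = DefaultLoggerFactory(LoggerConfig())
actor_factory = ContinuousActorProbFactory(hidden_sizes, conditioned_sigma=True)
critic_factory = ContinuousNetCriticFactory(hidden_sizes)  # reused for both Q-critics
agent_factory = SACAgentFactory(
    SACAgentFactory.Config(),  # tau, gamma, alpha, learning rates, ...
    sampling_config,
    actor_factory,
    critic_factory,  # critic 1
    critic_factory,  # critic 2
    AdamOptimizerFactory(),
)
experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory)
log_name = os.path.join(experiment_config.task, "sac", str(experiment_config.seed),
                        datetime.datetime.now().strftime("%y%m%d-%H%M%S"))
experiment.run(log_name)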

View File

@ -2,9 +2,9 @@ import warnings
import gymnasium as gym
from tianshou.config import RLSamplingConfig, BasicExperimentConfig
from tianshou.config import BasicExperimentConfig, RLSamplingConfig
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
from tianshou.highlevel.env import EnvFactory, Environments, ContinuousEnvironments
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
try:
import envpool
@ -12,9 +12,7 @@ except ImportError:
envpool = None
def make_mujoco_env(
task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool
):
def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool):
"""Wrapper function for Mujoco env.
If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env.

View File

@ -1,359 +0,0 @@
#!/usr/bin/env python3
import argparse
import datetime
import os
import pprint
from collections.abc import Sequence
from typing import Literal, Optional, Tuple, Union
import gymnasium as gym
import numpy as np
import torch
from jsonargparse import CLI
from torch import nn
from torch.distributions import Independent, Normal
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
from mujoco_env import make_mujoco_env
from tianshou.config import (
BasicExperimentConfig,
LoggerConfig,
NNConfig,
PGConfig,
PPOConfig,
RLAgentConfig,
RLSamplingConfig,
)
from tianshou.config.utils import collect_configs
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.env import VectorEnvNormObs
from tianshou.policy import BasePolicy, PPOPolicy
from tianshou.trainer import OnpolicyTrainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
def set_seed(seed=42):
np.random.seed(seed)
torch.manual_seed(seed)
def get_logger_for_run(
algo_name: str,
task: str,
logger_config: LoggerConfig,
config: dict,
seed: int,
resume_id: Optional[Union[str, int]],
) -> Tuple[str, Union[WandbLogger, TensorboardLogger]]:
"""
:param algo_name:
:param task:
:param logger_config:
:param config: the experiment config
:param seed:
:param resume_id: used as run_id by wandb, unused for tensorboard
:return:
"""
"""Returns the log_path and logger."""
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, algo_name, str(seed), now)
log_path = os.path.join(logger_config.logdir, log_name)
logger = get_logger(
logger_config.logger,
log_path,
log_name=log_name,
run_id=resume_id,
config=config,
wandb_project=logger_config.wandb_project,
)
return log_path, logger
def get_continuous_env_info(
env: gym.Env,
) -> Tuple[Tuple[int, ...], Tuple[int, ...], float]:
if not isinstance(env.action_space, gym.spaces.Box):
raise ValueError(
"Only environments with continuous action space are supported here. "
f"But got env with action space: {env.action_space.__class__}."
)
state_shape = env.observation_space.shape or env.observation_space.n
if not state_shape:
raise ValueError("Observation space shape is not defined")
action_shape = env.action_space.shape
max_action = env.action_space.high[0]
return state_shape, action_shape, max_action
def resume_from_checkpoint(
path: str,
policy: BasePolicy,
train_envs: VectorEnvNormObs | None = None,
test_envs: VectorEnvNormObs | None = None,
device: str | int | torch.device | None = None,
):
ckpt = torch.load(path, map_location=device)
policy.load_state_dict(ckpt["model"])
if train_envs:
train_envs.set_obs_rms(ckpt["obs_rms"])
if test_envs:
test_envs.set_obs_rms(ckpt["obs_rms"])
print("Loaded agent and obs. running means from: ", path)
def watch_agent(n_episode, policy: BasePolicy, test_collector: Collector, render=0.0):
policy.eval()
test_collector.reset()
result = test_collector.collect(n_episode=n_episode, render=render)
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
def get_train_test_collector(
buffer_size: int,
policy: BasePolicy,
train_envs: VectorEnvNormObs,
test_envs: VectorEnvNormObs,
):
if len(train_envs) > 1:
buffer = VectorReplayBuffer(buffer_size, len(train_envs))
else:
buffer = ReplayBuffer(buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
return test_collector, train_collector
TShape = Union[int, Sequence[int]]
def get_actor_critic(
state_shape: TShape,
hidden_sizes: Sequence[int],
action_shape: TShape,
device: str | int | torch.device = "cpu",
):
net_a = Net(
state_shape, hidden_sizes=hidden_sizes, activation=nn.Tanh, device=device
)
actor = ActorProb(net_a, action_shape, unbounded=True, device=device).to(device)
net_c = Net(
state_shape, hidden_sizes=hidden_sizes, activation=nn.Tanh, device=device
)
# TODO: twice device?
critic = Critic(net_c, device=device).to(device)
return actor, critic
def get_logger(
kind: Literal["wandb", "tensorboard"],
log_path: str,
log_name="",
run_id: Optional[Union[str, int]] = None,
config: Optional[Union[dict, argparse.Namespace]] = None,
wandb_project: Optional[str] = None,
):
writer = SummaryWriter(log_path)
writer.add_text("args", str(config))
if kind == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=run_id,
config=config,
project=wandb_project,
)
logger.load(writer)
elif kind == "tensorboard":
logger = TensorboardLogger(writer)
else:
raise ValueError(f"Unknown logger: {kind}")
return logger
def get_lr_scheduler(optim, step_per_epoch: int, step_per_collect: int, epochs: int):
"""Decay learning rate to 0 linearly."""
max_update_num = np.ceil(step_per_epoch / step_per_collect) * epochs
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
return lr_scheduler
def init_and_get_optim(actor: nn.Module, critic: nn.Module, lr: float):
"""Initializes layers of actor and critic.
:param actor:
:param critic:
:param lr:
:return:
"""
actor_critic = ActorCritic(actor, critic)
torch.nn.init.constant_(actor.sigma_param, -0.5)
for m in actor_critic.modules():
if isinstance(m, torch.nn.Linear):
# orthogonal initialization
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
torch.nn.init.zeros_(m.bias)
if hasattr(actor, "mu"):
# For continuous action spaces with Gaussian policies
# do last policy layer scaling, this will make initial actions have (close to)
# 0 mean and std, and will help boost performances,
# see https://arxiv.org/abs/2006.05990, Fig.24 for details
for m in actor.mu.modules():
# TODO: seems like biases are initialized twice for the actor
if isinstance(m, torch.nn.Linear):
torch.nn.init.zeros_(m.bias)
m.weight.data.copy_(0.01 * m.weight.data)
optim = torch.optim.Adam(actor_critic.parameters(), lr=lr)
return optim
def main(
experiment_config: BasicExperimentConfig,
logger_config: LoggerConfig,
sampling_config: RLSamplingConfig,
general_config: RLAgentConfig,
pg_config: PGConfig,
ppo_config: PPOConfig,
nn_config: NNConfig,
):
"""
Run the PPO test on the provided parameters.
:param experiment_config: BasicExperimentConfig - not ML or RL specific
:param logger_config: LoggerConfig
:param sampling_config: SamplingConfig -
sampling, epochs, parallelization, buffers, collectors, and batching.
:param general_config: RLAgentConfig - general RL agent config
:param pg_config: PGConfig: common to most policy gradient algorithms
:param ppo_config: PPOConfig - PPO specific config
:param nn_config: NNConfig - NN-training specific config
:return: None
"""
full_config = collect_configs(*locals().values())
set_seed(experiment_config.seed)
# create test and train envs, add env info to config
env, train_envs, test_envs = make_mujoco_env(
task=experiment_config.task,
seed=experiment_config.seed,
num_train_envs=sampling_config.num_train_envs,
num_test_envs=sampling_config.num_test_envs,
obs_norm=True,
)
# adding env_info to logged config
state_shape, action_shape, max_action = get_continuous_env_info(env)
full_config["env_info"] = {
"state_shape": state_shape,
"action_shape": action_shape,
"max_action": max_action,
}
log_path, logger = get_logger_for_run(
"ppo",
experiment_config.task,
logger_config,
full_config,
experiment_config.seed,
experiment_config.resume_id,
)
# Setup NNs
actor, critic = get_actor_critic(
state_shape, nn_config.hidden_sizes, action_shape, experiment_config.device
)
optim = init_and_get_optim(actor, critic, nn_config.lr)
lr_scheduler = None
if nn_config.lr_decay:
lr_scheduler = get_lr_scheduler(
optim,
sampling_config.step_per_epoch,
sampling_config.step_per_collect,
sampling_config.num_epochs,
)
# Create policy
def dist_fn(*logits):
return Independent(Normal(*logits), 1)
policy = PPOPolicy(
# nn-stuff
actor,
critic,
optim,
dist_fn=dist_fn,
lr_scheduler=lr_scheduler,
# env-stuff
action_space=train_envs.action_space,
action_scaling=True,
# general_config
discount_factor=general_config.gamma,
gae_lambda=general_config.gae_lambda,
reward_normalization=general_config.rew_norm,
action_bound_method=general_config.action_bound_method,
# pg_config
max_grad_norm=pg_config.max_grad_norm,
vf_coef=pg_config.vf_coef,
ent_coef=pg_config.ent_coef,
# ppo_config
eps_clip=ppo_config.eps_clip,
value_clip=ppo_config.value_clip,
dual_clip=ppo_config.dual_clip,
advantage_normalization=ppo_config.norm_adv,
recompute_advantage=ppo_config.recompute_adv,
)
if experiment_config.resume_path:
resume_from_checkpoint(
experiment_config.resume_path,
policy,
train_envs=train_envs,
test_envs=test_envs,
device=experiment_config.device,
)
test_collector, train_collector = get_train_test_collector(
sampling_config.buffer_size, policy, test_envs, train_envs
)
# TODO: test num is the number of test envs but used as episode_per_test
# here and in watch_agent
if not experiment_config.watch:
# RL training
def save_best_fn(pol: nn.Module):
state = {"model": pol.state_dict(), "obs_rms": train_envs.get_obs_rms()}
torch.save(state, os.path.join(log_path, "policy.pth"))
trainer = OnpolicyTrainer(
policy=policy,
train_collector=train_collector,
test_collector=test_collector,
max_epoch=sampling_config.num_epochs,
step_per_epoch=sampling_config.step_per_epoch,
repeat_per_collect=sampling_config.repeat_per_collect,
episode_per_test=sampling_config.num_test_envs,
batch_size=sampling_config.batch_size,
step_per_collect=sampling_config.step_per_collect,
save_best_fn=save_best_fn,
logger=logger,
test_in_train=False,
)
result = trainer.run()
pprint.pprint(result)
watch_agent(
sampling_config.num_test_envs,
policy,
test_collector,
render=experiment_config.render,
)
if __name__ == "__main__":
CLI(main)

View File

@ -2,6 +2,8 @@
import datetime
import os
from collections.abc import Sequence
from dataclasses import dataclass
from jsonargparse import CLI
from torch.distributions import Independent, Normal
@ -10,17 +12,26 @@ from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.config import (
BasicExperimentConfig,
LoggerConfig,
NNConfig,
PGConfig,
PPOConfig,
RLAgentConfig,
RLSamplingConfig,
)
from tianshou.highlevel.agent import PPOAgentFactory
from tianshou.highlevel.logger import DefaultLoggerFactory
from tianshou.highlevel.module import ContinuousActorProbFactory, ContinuousNetCriticFactory
from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory
from tianshou.highlevel.experiment import RLExperiment
from tianshou.highlevel.logger import DefaultLoggerFactory
from tianshou.highlevel.module import (
ContinuousActorProbFactory,
ContinuousNetCriticFactory,
)
from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory
@dataclass
class NNConfig:
hidden_sizes: Sequence[int] = (64, 64)
lr: float = 3e-4
lr_decay: bool = True
def main(
@ -43,15 +54,22 @@ def main(
actor_factory = ContinuousActorProbFactory(nn_config.hidden_sizes)
critic_factory = ContinuousNetCriticFactory(nn_config.hidden_sizes)
optim_factory = AdamOptimizerFactory(lr=nn_config.lr)
lr_scheduler_factory = LinearLRSchedulerFactory(nn_config, sampling_config)
agent_factory = PPOAgentFactory(general_config, pg_config, ppo_config, sampling_config, nn_config,
actor_factory, critic_factory, optim_factory, dist_fn, lr_scheduler_factory)
optim_factory = AdamOptimizerFactory()
lr_scheduler_factory = LinearLRSchedulerFactory(sampling_config) if nn_config.lr_decay else None
agent_factory = PPOAgentFactory(
general_config,
pg_config,
ppo_config,
sampling_config,
actor_factory,
critic_factory,
optim_factory,
dist_fn,
nn_config.lr,
lr_scheduler_factory,
)
experiment = RLExperiment(experiment_config, logger_config, general_config, sampling_config,
env_factory,
logger_factory,
agent_factory)
experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory)
experiment.run(log_name)

View File

@ -0,0 +1,56 @@
#!/usr/bin/env python3
import datetime
import os
from collections.abc import Sequence
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.config import (
BasicExperimentConfig,
LoggerConfig,
RLSamplingConfig,
)
from tianshou.highlevel.agent import SACAgentFactory
from tianshou.highlevel.experiment import RLExperiment
from tianshou.highlevel.logger import DefaultLoggerFactory
from tianshou.highlevel.module import (
ContinuousActorProbFactory,
ContinuousNetCriticFactory,
)
from tianshou.highlevel.optim import AdamOptimizerFactory
def main(
experiment_config: BasicExperimentConfig,
logger_config: LoggerConfig,
sampling_config: RLSamplingConfig,
sac_config: SACAgentFactory.Config,
hidden_sizes: Sequence[int] = (256, 256),
):
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(experiment_config.task, "sac", str(experiment_config.seed), now)
logger_factory = DefaultLoggerFactory(logger_config)
env_factory = MujocoEnvFactory(experiment_config, sampling_config)
actor_factory = ContinuousActorProbFactory(hidden_sizes, conditioned_sigma=True)
critic_factory = ContinuousNetCriticFactory(hidden_sizes)
optim_factory = AdamOptimizerFactory()
agent_factory = SACAgentFactory(
sac_config,
sampling_config,
actor_factory,
critic_factory,
critic_factory,
optim_factory,
)
experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory)
experiment.run(log_name)
if __name__ == "__main__":
CLI(main)
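Because the script is wired through jsonargparse's CLI, the fields of the config dataclasses become command-line options. A hypothetical invocation (the script path and the dot-notation option names are assumptions based on jsonargparse's usual handling of nested dataclass parameters):

python examples/mujoco/mujoco_sac_hl.py --experiment_config.task HalfCheetah-v4 --sac_config.actor_lr 0.0003 --hidden_sizes "[256, 256]"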

View File

@ -63,7 +63,9 @@ envpool = ["envpool"]
optional = true
[tool.poetry.group.dev.dependencies]
black = "^23.7.0"
docstring-parser = "^0.15"
jinja2 = "*"
jsonargparse = "^4.24.1"
mypy = "^1.4.1"
# networkx is used in a test
networkx = "*"
@ -83,8 +85,6 @@ sphinx_rtd_theme = "*"
sphinxcontrib-bibtex = "*"
sphinxcontrib-spelling = "^8.0.0"
wandb = "^0.12.0"
jsonargparse = "^4.24.1"
docstring-parser = "^0.15"
[tool.mypy]
allow_redefinition = true

View File

@ -1 +1,10 @@
from .config import *
__all__ = ["PGConfig", "PPOConfig", "RLAgentConfig", "RLSamplingConfig", "BasicExperimentConfig", "LoggerConfig"]
from .config import (
BasicExperimentConfig,
PGConfig,
PPOConfig,
RLAgentConfig,
RLSamplingConfig,
LoggerConfig,
)

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Literal, Optional, Sequence
from typing import Literal
import torch
from jsonargparse import set_docstring_parse_options
@ -14,12 +14,12 @@ class BasicExperimentConfig:
seed: int = 42
task: str = "Ant-v4"
"""Mujoco specific"""
render: Optional[float] = 0.0
render: float | None = 0.0
"""Milliseconds between rendered frames; if None, no rendering"""
device: str = "cuda" if torch.cuda.is_available() else "cpu"
resume_id: Optional[int] = None
resume_id: str | None = None
"""For restoring a model and running means of env-specifics from a checkpoint"""
resume_path: str = None
resume_path: str | None = None
"""For restoring a model and running means of env-specifics from a checkpoint"""
watch: bool = False
"""If True, will not perform training and only watch the restored policy"""
@ -28,7 +28,7 @@ class BasicExperimentConfig:
@dataclass
class LoggerConfig:
"""Logging config"""
"""Logging config."""
logdir: str = "log"
logger: Literal["tensorboard", "wandb"] = "tensorboard"
@ -48,17 +48,18 @@ class RLSamplingConfig:
buffer_size: int = 4096
step_per_collect: int = 2048
repeat_per_collect: int = 10
update_per_step: int = 1
@dataclass
class RLAgentConfig:
"""Config common to most RL algorithms"""
"""Config common to most RL algorithms."""
gamma: float = 0.99
"""Discount factor"""
gae_lambda: float = 0.95
"""For Generalized Advantage Estimate (equivalent to TD(lambda))"""
action_bound_method: Optional[Literal["clip", "tanh"]] = "clip"
action_bound_method: Literal["clip", "tanh"] | None = "clip"
"""How to map original actions in range (-inf, inf) to [-1, 1]"""
rew_norm: bool = True
"""Whether to normalize rewards"""
@ -66,7 +67,7 @@ class RLAgentConfig:
@dataclass
class PGConfig:
"""Config of general policy-gradient algorithms"""
"""Config of general policy-gradient algorithms."""
ent_coef: float = 0.0
vf_coef: float = 0.25
@ -75,18 +76,11 @@ class PGConfig:
@dataclass
class PPOConfig:
"""PPO specific config"""
"""PPO specific config."""
value_clip: bool = False
norm_adv: bool = False
"""Whether to normalize advantages"""
eps_clip: float = 0.2
dual_clip: Optional[float] = None
dual_clip: float | None = None
recompute_adv: bool = True
@dataclass
class NNConfig:
hidden_sizes: Sequence[int] = (64, 64)
lr: float = 3e-4
lr_decay: bool = True

View File

@ -2,11 +2,10 @@ from dataclasses import asdict, is_dataclass
def collect_configs(*confs):
"""
Collect instances of dataclasses to a single dict mapping the
classname to the values. If any of the passed objects is not a
dataclass or if two instances of the same config class are passed,
an error will be raised.
"""Collect instances of dataclasses to a single dict mapping the classname to the values.
If any of the passed objects is not a dataclass or if two instances
of the same config class are passed, an error will be raised.
:param confs: dataclasses
:return: Dictionary mapping class names to their instances.

View File

@ -1,33 +1,51 @@
import os
from abc import abstractmethod, ABC
from typing import Callable
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
import torch
from tianshou.config import RLSamplingConfig, PGConfig, PPOConfig, RLAgentConfig, NNConfig
from tianshou.data import VectorReplayBuffer, ReplayBuffer, Collector
from tianshou.config import PGConfig, PPOConfig, RLAgentConfig, RLSamplingConfig
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.exploration import BaseNoise
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
from tianshou.highlevel.optim import OptimizerFactory, LRSchedulerFactory
from tianshou.policy import BasePolicy, PPOPolicy
from tianshou.trainer import BaseTrainer, OnpolicyTrainer
from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils.net.common import ActorCritic
CHECKPOINT_DICT_KEY_MODEL = "model"
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
class AgentFactory(ABC):
def __init__(self, sampling_config: RLSamplingConfig):
self.sampling_config = sampling_config
def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
buffer_size = self.sampling_config.buffer_size
train_envs = envs.train_envs
if len(train_envs) > 1:
buffer = VectorReplayBuffer(buffer_size, len(train_envs))
else:
buffer = ReplayBuffer(buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, envs.test_envs)
return train_collector, test_collector
@abstractmethod
def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
pass
@staticmethod
def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
def save_best_fn(pol: torch.nn.Module):
state = {"model": pol.state_dict(), "obs_rms": envs.train_envs.get_obs_rms()}
def save_best_fn(pol: torch.nn.Module) -> None:
state = {
CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
}
torch.save(state, os.path.join(log_path, "policy.pth"))
return save_best_fn
@ -43,36 +61,26 @@ class AgentFactory(ABC):
print("Loaded agent and obs. running means from: ", path) # TODO logging
@abstractmethod
def create_train_test_collector(self,
policy: BasePolicy,
envs: Environments):
pass
@abstractmethod
def create_trainer(self, policy: BasePolicy, train_collector: Collector, test_collector: Collector,
envs: Environments, logger: Logger) -> BaseTrainer:
def create_trainer(
self,
policy: BasePolicy,
train_collector: Collector,
test_collector: Collector,
envs: Environments,
logger: Logger,
) -> BaseTrainer:
pass
class OnpolicyAgentFactory(AgentFactory, ABC):
def __init__(self, sampling_config: RLSamplingConfig):
self.sampling_config = sampling_config
def create_train_test_collector(self,
policy: BasePolicy,
envs: Environments):
buffer_size = self.sampling_config.buffer_size
train_envs = envs.train_envs
if len(train_envs) > 1:
buffer = VectorReplayBuffer(buffer_size, len(train_envs))
else:
buffer = ReplayBuffer(buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, envs.test_envs)
return train_collector, test_collector
def create_trainer(self, policy: BasePolicy, train_collector: Collector, test_collector: Collector,
envs: Environments, logger: Logger) -> OnpolicyTrainer:
def create_trainer(
self,
policy: BasePolicy,
train_collector: Collector,
test_collector: Collector,
envs: Environments,
logger: Logger,
) -> OnpolicyTrainer:
sampling_config = self.sampling_config
return OnpolicyTrainer(
policy=policy,
@ -90,17 +98,46 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
)
class OffpolicyAgentFactory(AgentFactory, ABC):
def create_trainer(
self,
policy: BasePolicy,
train_collector: Collector,
test_collector: Collector,
envs: Environments,
logger: Logger,
) -> OffpolicyTrainer:
sampling_config = self.sampling_config
return OffpolicyTrainer(
policy=policy,
train_collector=train_collector,
test_collector=test_collector,
max_epoch=sampling_config.num_epochs,
step_per_epoch=sampling_config.step_per_epoch,
step_per_collect=sampling_config.step_per_collect,
episode_per_test=sampling_config.num_test_envs,
batch_size=sampling_config.batch_size,
save_best_fn=self._create_save_best_fn(envs, logger.log_path),
logger=logger.logger,
update_per_step=sampling_config.update_per_step,
test_in_train=False,
)
class PPOAgentFactory(OnpolicyAgentFactory):
def __init__(self, general_config: RLAgentConfig,
pg_config: PGConfig,
ppo_config: PPOConfig,
sampling_config: RLSamplingConfig,
nn_config: NNConfig,
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
dist_fn,
lr_scheduler_factory: LRSchedulerFactory):
def __init__(
self,
general_config: RLAgentConfig,
pg_config: PGConfig,
ppo_config: PPOConfig,
sampling_config: RLSamplingConfig,
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optimizer_factory: OptimizerFactory,
dist_fn,
lr: float,
lr_scheduler_factory: LRSchedulerFactory | None = None,
):
super().__init__(sampling_config)
self.optimizer_factory = optimizer_factory
self.critic_factory = critic_factory
@ -108,16 +145,19 @@ class PPOAgentFactory(OnpolicyAgentFactory):
self.ppo_config = ppo_config
self.pg_config = pg_config
self.general_config = general_config
self.lr = lr
self.lr_scheduler_factory = lr_scheduler_factory
self.dist_fn = dist_fn
self.nn_config = nn_config
def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
actor = self.actor_factory.create_module(envs, device)
critic = self.critic_factory.create_module(envs, device)
critic = self.critic_factory.create_module(envs, device, use_action=False)
actor_critic = ActorCritic(actor, critic)
optim = self.optimizer_factory.create_optimizer(actor_critic)
lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim)
optim = self.optimizer_factory.create_optimizer(actor_critic, self.lr)
if self.lr_scheduler_factory is not None:
lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim)
else:
lr_scheduler = None
return PPOPolicy(
# nn-stuff
actor,
@ -144,3 +184,60 @@ class PPOAgentFactory(OnpolicyAgentFactory):
advantage_normalization=self.ppo_config.norm_adv,
recompute_advantage=self.ppo_config.recompute_adv,
)
class SACAgentFactory(OffpolicyAgentFactory):
def __init__(
self,
config: "SACAgentFactory.Config",
sampling_config: RLSamplingConfig,
actor_factory: ActorFactory,
critic1_factory: CriticFactory,
critic2_factory: CriticFactory,
optim_factory: OptimizerFactory,
exploration_noise: BaseNoise | None = None,
):
super().__init__(sampling_config)
self.critic2_factory = critic2_factory
self.critic1_factory = critic1_factory
self.actor_factory = actor_factory
self.exploration_noise = exploration_noise
self.optim_factory = optim_factory
self.config = config
def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
actor = self.actor_factory.create_module(envs, device)
critic1 = self.critic1_factory.create_module(envs, device, use_action=True)
critic2 = self.critic2_factory.create_module(envs, device, use_action=True)
actor_optim = self.optim_factory.create_optimizer(actor, lr=self.config.actor_lr)
critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.config.critic1_lr)
critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.config.critic2_lr)
return SACPolicy(
actor,
actor_optim,
critic1,
critic1_optim,
critic2,
critic2_optim,
tau=self.config.tau,
gamma=self.config.gamma,
alpha=self.config.alpha,
estimation_step=self.config.estimation_step,
action_space=envs.get_action_space(),
deterministic_eval=self.config.deterministic_eval,
exploration_noise=self.exploration_noise,
)
@dataclass
class Config:
"""SAC configuration."""
tau: float = 0.005
gamma: float = 0.99
alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2
reward_normalization: bool = False
estimation_step: int = 1
deterministic_eval: bool = True
actor_lr: float = 1e-3
critic1_lr: float = 1e-3
critic2_lr: float = 1e-3
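The Config.alpha field mirrors SACPolicy's convention: either a fixed entropy coefficient or a (target_entropy, log_alpha, alpha_optim) tuple for automatic entropy tuning. A hedged sketch of the latter, where action_shape and the learning rate are illustrative placeholders:

import numpy as np
import torch

# target entropy is commonly set to -dim(action space)
target_entropy = -float(np.prod(action_shape))
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
sac_config = SACAgentFactory.Config(alpha=(target_entropy, log_alpha, alpha_optim))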

View File

@ -1,24 +1,22 @@
from abc import ABC, abstractmethod
from typing import Tuple, Optional, Dict, Any, Union, Sequence
from collections.abc import Sequence
from typing import Any
import gymnasium as gym
from tianshou.env import BaseVectorEnv
TShape = Union[int, Sequence[int]]
TShape = int | Sequence[int]
class Environments(ABC):
def __init__(self, env: Optional[gym.Env], train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
self.env = env
self.train_envs = train_envs
self.test_envs = test_envs
def info(self) -> Dict[str, Any]:
return {
"action_shape": self.get_action_shape(),
"state_shape": self.get_state_shape()
}
def info(self) -> dict[str, Any]:
return {"action_shape": self.get_action_shape(), "state_shape": self.get_state_shape()}
@abstractmethod
def get_action_shape(self) -> TShape:
@ -33,7 +31,7 @@ class Environments(ABC):
class ContinuousEnvironments(Environments):
def __init__(self, env: Optional[gym.Env], train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
super().__init__(env, train_envs, test_envs)
self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
@ -44,12 +42,12 @@ class ContinuousEnvironments(Environments):
@staticmethod
def _get_continuous_env_info(
env: gym.Env,
) -> Tuple[Tuple[int, ...], Tuple[int, ...], float]:
env: gym.Env,
) -> tuple[tuple[int, ...], tuple[int, ...], float]:
if not isinstance(env.action_space, gym.spaces.Box):
raise ValueError(
"Only environments with continuous action space are supported here. "
f"But got env with action space: {env.action_space.__class__}."
f"But got env with action space: {env.action_space.__class__}.",
)
state_shape = env.observation_space.shape or env.observation_space.n
if not state_shape:
@ -68,4 +66,4 @@ class ContinuousEnvironments(Environments):
class EnvFactory(ABC):
@abstractmethod
def create_envs(self) -> Environments:
pass
pass

View File

@ -4,7 +4,9 @@ from typing import Generic, TypeVar
import numpy as np
import torch
from tianshou.config import BasicExperimentConfig, LoggerConfig, RLAgentConfig, RLSamplingConfig
from tianshou.config import (
BasicExperimentConfig,
)
from tianshou.data import Collector
from tianshou.highlevel.agent import AgentFactory
from tianshou.highlevel.env import EnvFactory
@ -17,23 +19,19 @@ TTrainer = TypeVar("TTrainer", bound=BaseTrainer)
class RLExperiment(Generic[TPolicy, TTrainer]):
def __init__(self,
config: BasicExperimentConfig,
logger_config: LoggerConfig,
general_config: RLAgentConfig,
sampling_config: RLSamplingConfig,
env_factory: EnvFactory,
logger_factory: LoggerFactory,
agent_factory: AgentFactory):
def __init__(
self,
config: BasicExperimentConfig,
env_factory: EnvFactory,
logger_factory: LoggerFactory,
agent_factory: AgentFactory,
):
self.config = config
self.logger_config = logger_config
self.general_config = general_config
self.sampling_config = sampling_config
self.env_factory = env_factory
self.logger_factory = logger_factory
self.agent_factory = agent_factory
def _set_seed(self):
def _set_seed(self) -> None:
seed = self.config.seed
np.random.seed(seed)
torch.manual_seed(seed)
@ -43,7 +41,7 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
# TODO
}
def run(self, log_name: str):
def run(self, log_name: str) -> None:
self._set_seed()
envs = self.env_factory.create_envs()
@ -52,25 +50,47 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
full_config.update(envs.info())
run_id = self.config.resume_id
logger = self.logger_factory.create_logger(log_name=log_name, run_id=run_id, config_dict=full_config)
logger = self.logger_factory.create_logger(
log_name=log_name,
run_id=run_id,
config_dict=full_config,
)
policy = self.agent_factory.create_policy(envs, self.config.device)
if self.config.resume_path:
self.agent_factory.load_checkpoint(policy, self.config.resume_path, envs, self.config.device)
self.agent_factory.load_checkpoint(
policy,
self.config.resume_path,
envs,
self.config.device,
)
train_collector, test_collector = self.agent_factory.create_train_test_collector(policy, envs)
train_collector, test_collector = self.agent_factory.create_train_test_collector(
policy,
envs,
)
if not self.config.watch:
trainer = self.agent_factory.create_trainer(policy, train_collector, test_collector, envs, logger)
trainer = self.agent_factory.create_trainer(
policy,
train_collector,
test_collector,
envs,
logger,
)
result = trainer.run()
pprint(result) # TODO logging
self._watch_agent(self.config.watch_num_episodes, policy, test_collector, self.config.render)
self._watch_agent(
self.config.watch_num_episodes,
policy,
test_collector,
self.config.render,
)
@staticmethod
def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render):
def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render) -> None:
policy.eval()
test_collector.reset()
result = test_collector.collect(n_episode=num_episodes, render=render)
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')

View File

@ -1,15 +1,13 @@
from abc import ABC, abstractmethod
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Union, Optional
from torch.utils.tensorboard import SummaryWriter
from tianshou.config import LoggerConfig
from tianshou.utils import TensorboardLogger, WandbLogger
TLogger = Union[TensorboardLogger, WandbLogger]
TLogger = TensorboardLogger | WandbLogger
@dataclass
@ -20,7 +18,7 @@ class Logger:
class LoggerFactory(ABC):
@abstractmethod
def create_logger(self, log_name: str, run_id: Optional[int], config_dict: dict) -> Logger:
def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
pass
@ -28,7 +26,7 @@ class DefaultLoggerFactory(LoggerFactory):
def __init__(self, config: LoggerConfig):
self.config = config
def create_logger(self, log_name: str, run_id: Optional[int], config_dict: dict) -> Logger:
def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
writer = SummaryWriter(self.config.logdir)
writer.add_text("args", str(self.config))
if self.config.logger == "wandb":

View File

@ -1,24 +1,24 @@
from abc import abstractmethod, ABC
from typing import Sequence
from abc import ABC, abstractmethod
from collections.abc import Sequence
import numpy as np
import torch
from torch import nn
import numpy as np
from tianshou.highlevel.env import Environments
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic as ContinuousCritic
from tianshou.utils.net.continuous import ActorProb
from tianshou.utils.net.continuous import Critic as ContinuousCritic
TDevice = str | int | torch.device
def init_linear_orthogonal(m: torch.nn.Module):
"""
Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0
def init_linear_orthogonal(module: torch.nn.Module):
"""Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.
:param m: the module whose submodules are to be processed
:param module: the module whose submodules are to be processed
"""
for m in m.modules():
for m in module.modules():
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
torch.nn.init.zeros_(m.bias)
@ -31,9 +31,9 @@ class ActorFactory(ABC):
@staticmethod
def _init_linear(actor: torch.nn.Module):
"""
Initializes linear layers of an actor module using default mechanisms
:param module: the actor module
"""Initializes linear layers of an actor module using default mechanisms.
:param actor: the actor module.
"""
init_linear_orthogonal(actor)
if hasattr(actor, "mu"):
@ -51,17 +51,29 @@ class ContinuousActorFactory(ActorFactory, ABC):
class ContinuousActorProbFactory(ContinuousActorFactory):
def __init__(self, hidden_sizes: Sequence[int]):
def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
self.hidden_sizes = hidden_sizes
self.unbounded = unbounded
self.conditioned_sigma = conditioned_sigma
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
net_a = Net(
envs.get_state_shape(), hidden_sizes=self.hidden_sizes, activation=nn.Tanh, device=device
envs.get_state_shape(),
hidden_sizes=self.hidden_sizes,
activation=nn.Tanh,
device=device,
)
actor = ActorProb(net_a, envs.get_action_shape(), unbounded=True, device=device).to(device)
actor = ActorProb(
net_a,
envs.get_action_shape(),
unbounded=True,
device=device,
conditioned_sigma=self.conditioned_sigma,
).to(device)
# init params
torch.nn.init.constant_(actor.sigma_param, -0.5)
if not self.conditioned_sigma:
torch.nn.init.constant_(actor.sigma_param, -0.5)
self._init_linear(actor)
return actor
@ -69,7 +81,7 @@ class ContinuousActorProbFactory(ContinuousActorFactory):
class CriticFactory(ABC):
@abstractmethod
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
pass
@ -78,12 +90,19 @@ class ContinuousCriticFactory(CriticFactory, ABC):
class ContinuousNetCriticFactory(ContinuousCriticFactory):
def __init__(self, hidden_sizes: Sequence[int]):
def __init__(self, hidden_sizes: Sequence[int], action_shape=0):
self.action_shape = action_shape
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
action_shape = envs.get_action_shape() if use_action else 0
net_c = Net(
envs.get_state_shape(), hidden_sizes=self.hidden_sizes, activation=nn.Tanh, device=device
envs.get_state_shape(),
action_shape=action_shape,
hidden_sizes=self.hidden_sizes,
concat=use_action,
activation=nn.Tanh,
device=device,
)
critic = ContinuousCritic(net_c, device=device).to(device)
init_linear_orthogonal(critic)
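The new use_action flag lets the same critic factory produce either a state-value critic (as used by PPO) or a state-action critic (as needed twice by SAC). A brief sketch, where envs and device stand in for the Environments instance and torch device produced elsewhere:

critic_factory = ContinuousNetCriticFactory(hidden_sizes=(256, 256))
# state-value critic V(s): no action input is concatenated
v_critic = critic_factory.create_module(envs, device, use_action=False)
# action-value critic Q(s, a): the action shape is concatenated into the net
q_critic = critic_factory.create_module(envs, device, use_action=True)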

View File

@ -1,54 +1,51 @@
from abc import ABC, abstractmethod
from typing import Union, Iterable, Dict, Any, Optional
from collections.abc import Iterable
from typing import Any, Type
import numpy as np
import torch
from torch import Tensor
from torch.optim import Adam
from torch.optim.lr_scheduler import LRScheduler, LambdaLR
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from tianshou.config import RLSamplingConfig, NNConfig
from tianshou.config import RLSamplingConfig
TParams = Union[Iterable[Tensor], Iterable[Dict[str, Any]]]
TParams = Iterable[Tensor] | Iterable[dict[str, Any]]
class OptimizerFactory(ABC):
@abstractmethod
def create_optimizer(self, module: torch.nn.Module) -> torch.optim.Optimizer:
def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
pass
class TorchOptimizerFactory(OptimizerFactory):
def __init__(self, optim_class, **kwargs):
def __init__(self, optim_class: Any, **kwargs):
self.optim_class = optim_class
self.kwargs = kwargs
def create_optimizer(self, module: torch.nn.Module) -> torch.optim.Optimizer:
return self.optim_class(module.parameters(), **self.kwargs)
def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
return self.optim_class(module.parameters(), lr=lr, **self.kwargs)
class AdamOptimizerFactory(OptimizerFactory):
def __init__(self, lr):
self.lr = lr
def create_optimizer(self, module: torch.nn.Module) -> Adam:
return Adam(module.parameters(), lr=self.lr)
def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
return Adam(module.parameters(), lr=lr)
class LRSchedulerFactory(ABC):
@abstractmethod
def create_scheduler(self, optim: torch.optim.Optimizer) -> Optional[LRScheduler]:
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
pass
class LinearLRSchedulerFactory(LRSchedulerFactory):
def __init__(self, nn_config: NNConfig, sampling_config: RLSamplingConfig):
self.nn_config = nn_config
def __init__(self, sampling_config: RLSamplingConfig):
self.sampling_config = sampling_config
def create_scheduler(self, optim: torch.optim.Optimizer) -> Optional[LRScheduler]:
lr_scheduler = None
if self.nn_config.lr_decay:
max_update_num = np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect) * self.sampling_config.num_epochs
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
return lr_scheduler
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
max_update_num = (
np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect)
* self.sampling_config.num_epochs
)
return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
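With this refactoring, the learning rate travels with the create_optimizer call instead of being baked into the factory, and the linear scheduler only needs the sampling config. A brief sketch, where actor_critic and sampling_config are placeholders for objects constructed elsewhere:

import torch

optim_factory = AdamOptimizerFactory()  # no learning rate stored on the factory
optim = optim_factory.create_optimizer(actor_critic, lr=3e-4)

# any torch optimizer class can be wrapped analogously
sgd_factory = TorchOptimizerFactory(torch.optim.SGD, momentum=0.9)
sgd_optim = sgd_factory.create_optimizer(actor_critic, lr=1e-3)

# linear decay to zero over the total number of policy updates
lr_scheduler = LinearLRSchedulerFactory(sampling_config).create_scheduler(optim)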