import math
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange


# Fully connected Neural Network block - Multi Layer Perceptron
class MLP(nn.Module):

    def __init__(self, in_dim, out_dim, hidden_dim=16, n_layers=1, act='relu', batch_norm=True):
        super(MLP, self).__init__()

        activations = {'relu': nn.ReLU, 'sigmoid': nn.Sigmoid, 'tanh': nn.Tanh, 'leaky_relu': nn.LeakyReLU,
                       'elu': nn.ELU, 'prelu': nn.PReLU, 'softplus': nn.Softplus, 'mish': nn.Mish,
                       'identity': nn.Identity
                       }
        act_func = activations[act]

        layers = [nn.Linear(in_dim, hidden_dim), act_func()]
        for i in range(n_layers):
            layers += [
                nn.Linear(hidden_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim) if batch_norm else nn.Identity(),
                act_func(),
            ]
        layers.append(nn.Linear(hidden_dim, out_dim))

        self._network = nn.Sequential(*layers)

    def forward(self, x):
        return self._network(x)
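
# Example usage (sketch; the dimensions below are illustrative, not taken from the original code):
#   mlp = MLP(in_dim=7, out_dim=2, hidden_dim=32, n_layers=2, act='mish')
#   y = mlp(torch.randn(8, 7))  # -> shape (8, 2); note that BatchNorm1d needs batch_size > 1 in training mode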


# Resnet Blocks
class ResnetBlockFC(nn.Module):
    '''
    Fully connected ResNet Block class.

    Args:
        size_in (int): input dimension
        size_out (int): output dimension
        size_h (int): hidden dimension
    '''

    def __init__(self, size_in, size_out=None, size_h=None):
        super().__init__()
        # Attributes
        if size_out is None:
            size_out = size_in
        if size_h is None:
            size_h = min(size_in, size_out)

        self.size_in = size_in
        self.size_h = size_h
        self.size_out = size_out

        # Submodules
        self.fc_0 = nn.Linear(size_in, size_h)
        self.fc_1 = nn.Linear(size_h, size_out)
        self.actvn = nn.ReLU()

        if size_in == size_out:
            self.shortcut = None
        else:
            self.shortcut = nn.Linear(size_in, size_out, bias=False)

        # Initialization
        nn.init.zeros_(self.fc_1.weight)

    def forward(self, x):
        net = self.fc_0(self.actvn(x))
        dx = self.fc_1(self.actvn(net))

        if self.shortcut is not None:
            x_s = self.shortcut(x)
        else:
            x_s = x

        return x_s + dx
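
# Example usage (sketch; sizes are illustrative):
#   block = ResnetBlockFC(size_in=128, size_out=64)
#   y = block(torch.randn(8, 128))  # -> shape (8, 64); the shortcut projects 128 -> 64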


class GaussianFourierProjection(nn.Module):
    """Gaussian random features for encoding time steps."""

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Randomly sample weights during initialization. These weights are fixed
        # during optimization and are not trainable.
        self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)

    def forward(self, x):
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
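
# Example usage (sketch): encode a batch of scalar diffusion times as random Fourier features.
#   time_embed = GaussianFourierProjection(embed_dim=64)
#   emb = time_embed(torch.rand(16))  # -> shape (16, 64): concatenated sin/cos of the random projections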


# https://gist.github.com/kevinzakka/dd9fa5177cda13593524f4d71eb38ad5
class SpatialSoftArgmax(nn.Module):
    """Spatial softmax as defined in [1].

    Concretely, the spatial softmax of each feature
    map is used to compute a weighted mean of the pixel
    locations, effectively performing a soft arg-max
    over the feature dimension.

    References:
        [1]: End-to-End Training of Deep Visuomotor Policies,
        https://arxiv.org/abs/1504.00702
    """

    def __init__(self, normalize=False):
        """Constructor.

        Args:
            normalize (bool): Whether to use normalized
                image coordinates, i.e. coordinates in
                the range `[-1, 1]`.
        """
        super().__init__()
        self.normalize = normalize
        self.temperature = nn.Parameter(torch.ones(1), requires_grad=True)

    def _coord_grid(self, h, w, device):
        if self.normalize:
            return torch.stack(
                torch.meshgrid(
                    torch.linspace(-1, 1, w, device=device),
                    torch.linspace(-1, 1, h, device=device),
                    indexing='ij',
                )
            )
        return torch.stack(
            torch.meshgrid(
                torch.arange(0, w, device=device),
                torch.arange(0, h, device=device),
                indexing='ij',
            )
        )

    def forward(self, x):
        assert x.ndim == 4, "Expecting a tensor of shape (B, C, H, W)."

        # compute a spatial softmax over the input:
        # given an input of shape (B, C, H, W),
        # reshape it to (B*C, H*W) then apply
        # the softmax operator over the last dimension
        b, c, h, w = x.shape
        x = x * (h * w / self.temperature)
        softmax = F.softmax(x.view(-1, h * w), dim=-1)

        # create a meshgrid of pixel coordinates
        # both in the x and y axes
        xc, yc = self._coord_grid(h, w, x.device)

        # element-wise multiply the x and y coordinates
        # with the softmax, then sum over the h*w dimension
        # this effectively computes the weighted mean of x
        # and y locations
        y_mean = (softmax * xc.flatten()).sum(dim=1, keepdim=True)
        x_mean = (softmax * yc.flatten()).sum(dim=1, keepdim=True)

        # concatenate and reshape the result
        # to (B, C*2) where for every feature
        # we have the expected x and y pixel
        # locations
        return torch.cat([x_mean, y_mean], dim=1).view(-1, c * 2)
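
# Example usage (sketch): extract two expected pixel coordinates per feature channel.
#   ssam = SpatialSoftArgmax(normalize=True)
#   keypoints = ssam(torch.randn(4, 32, 24, 24))  # -> shape (4, 64), i.e. (B, 2 * C)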


########################################################################################################################
# Modules Temporal Unet
########################################################################################################################
class Residual(nn.Module):

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


class PreNorm(nn.Module):

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)


class LayerNorm(nn.Module):

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1))

    def forward(self, x):
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b


class LinearAttention(nn.Module):

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv1d(hidden_dim, dim, 1)

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(lambda t: einops.rearrange(t, 'b (h c) d -> b h c d', h=self.heads), qkv)
        q = q * self.scale
        k = k.softmax(dim=-1)
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = einops.rearrange(out, 'b h c d -> b (h c) d')
        return self.to_out(out)
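
# Example usage (sketch; the Residual/PreNorm wrapping is an assumption about how these blocks are composed):
#   attn = Residual(PreNorm(32, LinearAttention(32)))
#   y = attn(torch.randn(4, 32, 64))  # -> shape (4, 32, 64), channels-first 1D sequence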


class TimeEncoder(nn.Module):

    def __init__(self, dim, dim_out):
        super().__init__()
        self.encoder = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim_out)
        )

    def forward(self, x):
        return self.encoder(x)


class SinusoidalPosEmb(nn.Module):

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
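
# Example usage (sketch): embed a batch of diffusion timesteps for conditioning.
#   time_encoder = TimeEncoder(dim=32, dim_out=64)
#   t_emb = time_encoder(torch.arange(16).float())  # -> shape (16, 64)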


class Downsample1d(nn.Module):

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)


class Upsample1d(nn.Module):

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(dim, dim, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)
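
# Example usage (sketch): the strided conv halves the temporal length, the transposed conv doubles it.
#   x = torch.randn(4, 32, 64)
#   x_down = Downsample1d(32)(x)   # -> shape (4, 32, 32)
#   x_up = Upsample1d(32)(x_down)  # -> shape (4, 32, 64)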


class Conv1dBlock(nn.Module):
    '''
    Conv1d --> GroupNorm --> Mish
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, padding=None, n_groups=8):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, stride=1,
                      padding=padding if padding is not None else kernel_size // 2),
            Rearrange('batch channels n_support_points -> batch channels 1 n_support_points'),
            nn.GroupNorm(n_groups, out_channels),
            Rearrange('batch channels 1 n_support_points -> batch channels n_support_points'),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


class ResidualBlock(nn.Module):
    ################
    # Janner code
    ################

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels))
        self.downsample = downsample
        self.relu = nn.ReLU()
        self.out_channels = out_channels

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class ResidualTemporalBlock(nn.Module):

    def __init__(self, inp_channels, out_channels, cond_embed_dim, n_support_points, kernel_size=5):
        super().__init__()

        self.blocks = nn.ModuleList([
            Conv1dBlock(inp_channels, out_channels, kernel_size, n_groups=group_norm_n_groups(out_channels)),
            Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=group_norm_n_groups(out_channels)),
        ])

        # Without context conditioning, cond_mlp handles only time embeddings
        self.cond_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_embed_dim, out_channels),
            Rearrange('batch t -> batch t 1'),
        )

        self.residual_conv = nn.Conv1d(inp_channels, out_channels, kernel_size=1, stride=1, padding=0) \
            if inp_channels != out_channels else nn.Identity()

    def forward(self, x, c):
        '''
        x : [ batch_size x inp_channels x n_support_points ]
        c : [ batch_size x embed_dim ]
        returns:
            out : [ batch_size x out_channels x n_support_points ]
        '''
        h = self.blocks[0](x) + self.cond_mlp(c)
        h = self.blocks[1](h)
        res = self.residual_conv(x)
        out = h + res
        return out
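
# Example usage (sketch; channel sizes are illustrative):
#   block = ResidualTemporalBlock(inp_channels=4, out_channels=32, cond_embed_dim=64, n_support_points=64)
#   out = block(torch.randn(8, 4, 64), torch.randn(8, 64))  # -> shape (8, 32, 64)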


class TemporalBlockMLP(nn.Module):

    def __init__(self, inp_channels, out_channels, cond_embed_dim):
        super().__init__()

        self.blocks = nn.ModuleList([
            MLP(inp_channels, out_channels, hidden_dim=out_channels, n_layers=0, act='mish')
        ])

        # Without context conditioning, cond_mlp handles only time embeddings
        self.cond_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_embed_dim, out_channels),
            # Rearrange('batch t -> batch t 1'),
        )

        self.last_act = nn.Mish()

    def forward(self, x, c):
        '''
        x : [ batch_size x inp_channels ]
        c : [ batch_size x embed_dim ]
        returns:
            out : [ batch_size x out_channels ]
        '''
        h = self.blocks[0](x) + self.cond_mlp(c)
        out = self.last_act(h)
        return out


def group_norm_n_groups(n_channels, target_n_groups=8):
    # Return a number of groups for nn.GroupNorm that divides n_channels:
    # prefer target_n_groups, otherwise search target_n_groups..target_n_groups+9
    # for the first divisor, and fall back to 1.
    if n_channels < target_n_groups:
        return 1
    for n_groups in range(target_n_groups, target_n_groups + 10):
        if n_channels % n_groups == 0:
            return n_groups
    return 1
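
# Examples (sketch):
#   group_norm_n_groups(32)  # -> 8  (32 is divisible by the target of 8 groups)
#   group_norm_n_groups(12)  # -> 12 (first divisor of 12 in the range 8..17)
#   group_norm_n_groups(7)   # -> 1  (fewer channels than the target number of groups)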


def compute_padding_conv1d(L, KSZ, S, D, deconv=False):
    '''
    https://gist.github.com/AhmadMoussa/d32c41c11366440bc5eaf4efb48a2e73
    :param L: Input length (or width)
    :param KSZ: Kernel size (or width)
    :param S: Stride
    :param D: Dilation factor
    :param deconv: True if ConvTranspose1d
    :return: Padding such that the output length is exactly half the input length
             (or double the input length, in the ConvTranspose1d case)
    '''
    print(f"INPUT SIZE {L}")
    if not deconv:
        return math.ceil((S * (L / 2) - L + D * (KSZ - 1) - 1) / 2)
    else:
        print(L, S, D, KSZ)
        pad = math.ceil(((L - 1) * S + D * (KSZ - 1) + 1 - L * 2) / 2)
        print("PAD", pad)
        output_size = (L - 1) * S - 2 * pad + D * (KSZ - 1) + 1
        print("OUTPUT SIZE", output_size)
        return pad


def compute_output_length_maxpool1d(L, KSZ, S, D, P):
    '''
    https://pytorch.org/docs/stable/generated/torch.nn.MaxPool1d.html
    :param L: Input length (or width)
    :param KSZ: Kernel size (or width)
    :param S: Stride
    :param D: Dilation factor
    :param P: Padding
    '''
    return math.floor((L + 2 * P - D * (KSZ - 1) - 1) / S + 1)
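
# Example (sketch): MaxPool1d with kernel_size=3, stride=2, dilation=1, padding=1 on a length-64 input.
#   compute_output_length_maxpool1d(64, 3, 2, 1, 1)  # -> 32, i.e. floor((64 + 2 - 2 - 1) / 2 + 1)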


if __name__ == "__main__":
    b, c, h, w = 1, 64, 12, 12
    x = torch.full((b, c, h, w), 0.00)
    i_max = 4
    true_max = torch.randint(0, 10, size=(b, c, 2))
    for i in range(b):
        for j in range(c):
            x[i, j, true_max[i, j, 0], true_max[i, j, 1]] = 1000
            # x[i, j, i_max, true_max] = 1
    # x[0,0,0,0] = 1000

    soft_max = SpatialSoftArgmax(normalize=True)(x)
    soft_max2 = SpatialSoftArgmax(normalize=False)(x)
    diff = soft_max - (soft_max2 / 12) * 2 - 1
    resh = soft_max.reshape(b, c, 2)
    assert torch.allclose(true_max.float(), resh)
    exit()

    test_scales = [1, 5, 10, 30, 50, 70, 100]
    for scale in test_scales[::-1]:
        i_max = 4
        true_max = torch.randint(0, 10, size=(b, c, 2))
        for i in range(b):
            for j in range(c):
                x[i, j, true_max[i, j, 0], true_max[i, j, 1]] = scale
                # x[i, j, i_max, true_max] = 1
        # x[0,0,0,0] = 1000

        soft_max = SpatialSoftArgmax(normalize=False)(x)
        resh = soft_max.reshape(b, c, 2)
        assert torch.allclose(true_max.float(), resh), scale