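"""Temporal U-Net denoising networks and context encoders.

TemporalUnet is a 1D U-Net over trajectories of shape [batch, horizon, state_dim];
PointUnet is an MLP-based variant for single configurations. EnvModel, TaskModel,
TaskModelNew and ContextModel are placeholder encoders producing the conditioning
context consumed by the U-Nets.
"""
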
import einops
import torch
import torch.nn as nn

from mpd.models.layers.layers import GaussianFourierProjection, Downsample1d, Conv1dBlock, Upsample1d, \
    ResidualTemporalBlock, TimeEncoder, MLP, group_norm_n_groups, LinearAttention, PreNorm, Residual, TemporalBlockMLP
from mpd.models.layers.layers_attention import SpatialTransformer
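

# Channel multipliers per U-Net resolution level; e.g. with unet_input_dim=32,
# UNET_DIM_MULTS[0] = (1, 2, 4) gives channel widths (32, 64, 128).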
UNET_DIM_MULTS = {
    0: (1, 2, 4),
    1: (1, 2, 4, 8)
}


class TemporalUnet(nn.Module):
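    """1D temporal U-Net used as a diffusion denoiser over trajectories.

    Treats a trajectory as a 1D signal with state_dim channels and
    n_support_points support points (the horizon). Conditioning on a context
    vector supports three modes:
      - 'concatenate': repeat the context over the horizon and concatenate it
        to the (optionally embedded) state at the input.
      - 'attention': cross-attend to the context with SpatialTransformer
        blocks at every resolution.
      - 'default': concatenate the context to the time embedding fed to every
        ResidualTemporalBlock.
    """
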
    def __init__(
            self,
            n_support_points=None,
            state_dim=None,
            unet_input_dim=32,
            dim_mults=(1, 2, 4, 8),
            time_emb_dim=32,
            self_attention=False,
            conditioning_embed_dim=4,
            conditioning_type=None,
            attention_num_heads=2,
            attention_dim_head=32,
            **kwargs
    ):
        super().__init__()

        self.state_dim = state_dim
        input_dim = state_dim

        # Conditioning
        if conditioning_type is None or conditioning_type == 'None':
            conditioning_type = None
        elif conditioning_type == 'concatenate':
            if self.state_dim < conditioning_embed_dim // 4:
                # Embed the state in a larger latent space if the conditioning
                # embedding is much larger than the state
                state_emb_dim = conditioning_embed_dim // 4
                self.state_encoder = MLP(state_dim, state_emb_dim, hidden_dim=state_emb_dim // 2, n_layers=1, act='mish')
            else:
                state_emb_dim = state_dim
                self.state_encoder = nn.Identity()
            input_dim = state_emb_dim + conditioning_embed_dim
        elif conditioning_type == 'attention':
            pass
        elif conditioning_type == 'default':
            pass
        else:
            raise NotImplementedError(f'conditioning_type not supported: {conditioning_type}')
        self.conditioning_type = conditioning_type

        dims = [input_dim, *map(lambda m: unet_input_dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        print(f'[ models/temporal ] Channel dimensions: {in_out}')

        # Networks
        self.time_mlp = TimeEncoder(32, time_emb_dim)

        # conditioning dimension (time + context)
        cond_dim = time_emb_dim + (conditioning_embed_dim if conditioning_type == 'default' else 0)

        # Unet
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

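        # Encoder levels: two residual temporal blocks, optional self-attention,
        # optional cross-attention to the context, then Downsample1d (halving
        # the horizon at every level except the last).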
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                ResidualTemporalBlock(dim_in, dim_out, cond_dim, n_support_points=n_support_points),
                ResidualTemporalBlock(dim_out, dim_out, cond_dim, n_support_points=n_support_points),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))) if self_attention else nn.Identity(),
                # cross-attention to the context (None entries are allowed in nn.ModuleList)
                SpatialTransformer(dim_out, attention_num_heads, attention_dim_head, depth=1,
                                   context_dim=conditioning_embed_dim) if conditioning_type == 'attention' else None,
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

            if not is_last:
                n_support_points = n_support_points // 2

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, cond_dim, n_support_points=n_support_points)
        self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim))) if self_attention else nn.Identity()
        self.mid_attention = SpatialTransformer(mid_dim, attention_num_heads, attention_dim_head, depth=1,
                                                context_dim=conditioning_embed_dim) if conditioning_type == 'attention' else nn.Identity()
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, cond_dim, n_support_points=n_support_points)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                ResidualTemporalBlock(dim_out * 2, dim_in, cond_dim, n_support_points=n_support_points),
                ResidualTemporalBlock(dim_in, dim_in, cond_dim, n_support_points=n_support_points),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))) if self_attention else nn.Identity(),
                SpatialTransformer(dim_in, attention_num_heads, attention_dim_head, depth=1,
                                   context_dim=conditioning_embed_dim) if conditioning_type == 'attention' else None,
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))

            if not is_last:
                n_support_points = n_support_points * 2

        self.final_conv = nn.Sequential(
            Conv1dBlock(unet_input_dim, unet_input_dim, kernel_size=5, n_groups=group_norm_n_groups(unet_input_dim)),
            nn.Conv1d(unet_input_dim, state_dim, 1),
        )

    def forward(self, x, time, context):
        """
        x : [ batch x horizon x state_dim ]
        time : [ batch ]
        context : [ batch x context_dim ]
        returns : [ batch x horizon x state_dim ]
        """
        b, h, d = x.shape

        t_emb = self.time_mlp(time)
        c_emb = t_emb
        if self.conditioning_type == 'concatenate':
            # repeat the context over the horizon and concatenate it to the state
            x_emb = self.state_encoder(x)
            context = einops.repeat(context, 'm n -> m h n', h=h)
            x = torch.cat((x_emb, context), dim=-1)
        elif self.conditioning_type == 'attention':
            # reshape to a one-token sequence to keep the cross-attention interface
            context = einops.rearrange(context, 'b d -> b 1 d')
        elif self.conditioning_type == 'default':
            # concatenate the context to the time embedding
            c_emb = torch.cat((t_emb, context), dim=-1)

        # swap horizon and channels (state_dim)
        x = einops.rearrange(x, 'b h c -> b c h')

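        # Encoder pass; from here on `h` is reused as the stack of skip
        # activations (the horizon value unpacked above is no longer needed).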
        h = []
        for resnet, resnet2, attn_self, attn_conditioning, downsample in self.downs:
            x = resnet(x, c_emb)
            x = resnet2(x, c_emb)
            x = attn_self(x)
            if self.conditioning_type == 'attention':
                x = attn_conditioning(x, context=context)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, c_emb)
        x = self.mid_attn(x)
        if self.conditioning_type == 'attention':
            x = self.mid_attention(x, context=context)
        x = self.mid_block2(x, c_emb)

        for resnet, resnet2, attn_self, attn_conditioning, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, c_emb)
            x = resnet2(x, c_emb)
            x = attn_self(x)
            if self.conditioning_type == 'attention':
                x = attn_conditioning(x, context=context)
            x = upsample(x)

        x = self.final_conv(x)

        x = einops.rearrange(x, 'b c h -> b h c')

        return x
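

# Minimal usage sketch (illustrative; assumes TimeEncoder accepts a [batch]
# tensor of integer diffusion timesteps, which is not guaranteed here):
#   model = TemporalUnet(n_support_points=64, state_dim=4,
#                        unet_input_dim=32, dim_mults=UNET_DIM_MULTS[1])
#   x = torch.randn(8, 64, 4)           # [batch, horizon, state_dim]
#   t = torch.randint(0, 1000, (8,))    # diffusion timesteps
#   out = model(x, t, context=None)     # -> [8, 64, 4]

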
class EnvModel(nn.Module):

    def __init__(
            self,
            in_dim=16,
            out_dim=16,
            **kwargs
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.net = nn.Identity()

    def forward(self, input_d):
        env = input_d['env']
        env_emb = self.net(env)
        return env_emb


class TaskModel(nn.Module):

    def __init__(
            self,
            in_dim=16,
            out_dim=32,
            **kwargs
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.net = nn.Identity()

    def forward(self, input_d):
        task = input_d['tasks']
        task_emb = self.net(task)
        return task_emb


class TaskModelNew(nn.Module):

    def __init__(
            self,
            in_dim=16,
            out_dim=32,
            **kwargs
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.net = nn.Identity()

    def forward(self, task):
        task_emb = self.net(task)
        return task_emb


class ContextModel(nn.Module):

    def __init__(
            self,
            env_model=None,
            task_model=None,
            out_dim=32,
            **kwargs
    ):
        super().__init__()

        self.env_model = env_model
        self.task_model = task_model

        self.in_dim = self.env_model.out_dim + self.task_model.out_dim

        # The concatenated context is passed through unchanged; alternatively it
        # could be projected, e.g. with
        # MLP(self.in_dim, out_dim, hidden_dim=out_dim, n_layers=1, act='mish')
        self.out_dim = self.in_dim
        self.net = nn.Identity()

    def forward(self, input_d=None):
        if input_d is None:
            return None
        env_emb = self.env_model(input_d)
        task_emb = self.task_model(input_d)
        context = torch.cat((env_emb, task_emb), dim=-1)
        context_emb = self.net(context)
        return context_emb
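

# Minimal usage sketch (illustrative; the shapes are assumptions for this example):
#   context_model = ContextModel(env_model=EnvModel(out_dim=16),
#                                task_model=TaskModel(out_dim=32))
#   input_d = {'env': torch.randn(8, 16), 'tasks': torch.randn(8, 32)}
#   context = context_model(input_d)    # -> [8, 48]

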
class PointUnet(nn.Module):
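    """MLP-based counterpart of TemporalUnet for single configurations.

    Operates on inputs of shape [batch, 1, state_dim] with TemporalBlockMLP
    encoder/decoder blocks and skip connections; there is no temporal
    up/downsampling and no attention.
    """
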
    def __init__(
            self,
            n_support_points=None,
            state_dim=None,
            dim=32,
            dim_mults=(1, 2, 4),
            time_emb_dim=32,
            conditioning_embed_dim=4,
            conditioning_type=None,
            **kwargs
    ):
        super().__init__()

        self.dim_mults = dim_mults

        self.state_dim = state_dim
        input_dim = state_dim

        # Conditioning
        if conditioning_type is None or conditioning_type == 'None':
            conditioning_type = None
        elif conditioning_type == 'concatenate':
            if self.state_dim < conditioning_embed_dim // 4:
                # Embed the state in a larger latent space if the conditioning
                # embedding is much larger than the state
                state_emb_dim = conditioning_embed_dim // 4
                self.state_encoder = MLP(state_dim, state_emb_dim, hidden_dim=state_emb_dim // 2, n_layers=1, act='mish')
            else:
                state_emb_dim = state_dim
                self.state_encoder = nn.Identity()
            input_dim = state_emb_dim + conditioning_embed_dim
        elif conditioning_type == 'default':
            pass
        else:
            raise NotImplementedError(f'conditioning_type not supported: {conditioning_type}')
        self.conditioning_type = conditioning_type

        dims = [input_dim, *map(lambda m: dim * m, self.dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        print(f'[ models/temporal ] Channel dimensions: {in_out}')

        # Networks
        self.time_mlp = TimeEncoder(32, time_emb_dim)

        # conditioning dimension (time + context)
        cond_dim = time_emb_dim + (conditioning_embed_dim if conditioning_type == 'default' else 0)

        # Unet
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

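        # Encoder/decoder over feature widths only; since the input is a single
        # point, there is no Downsample1d/Upsample1d between levels.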
        for ind, (dim_in, dim_out) in enumerate(in_out):
            self.downs.append(nn.ModuleList([
                TemporalBlockMLP(dim_in, dim_out, cond_dim)
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = TemporalBlockMLP(mid_dim, mid_dim, cond_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            self.ups.append(nn.ModuleList([
                TemporalBlockMLP(dim_out * 2, dim_in, cond_dim)
            ]))

        self.final_layer = nn.Sequential(
            MLP(dim, state_dim, hidden_dim=dim, n_layers=0, act='identity')
        )

    def forward(self, x, time, context):
        """
        x : [ batch x 1 x state_dim ]
        time : [ batch ]
        context : [ batch x context_dim ]
        returns : [ batch x 1 x state_dim ]
        """
        x = einops.rearrange(x, 'b 1 d -> b d')

        t_emb = self.time_mlp(time)
        c_emb = t_emb
        if self.conditioning_type == 'concatenate':
            x_emb = self.state_encoder(x)
            x = torch.cat((x_emb, context), dim=-1)
        elif self.conditioning_type == 'default':
            c_emb = torch.cat((t_emb, context), dim=-1)

        # encoder/decoder pass with skip connections
        h = []
        for (resnet,) in self.downs:
            x = resnet(x, c_emb)
            h.append(x)

        x = self.mid_block1(x, c_emb)

        for (resnet,) in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, c_emb)

        x = self.final_layer(x)

        x = einops.rearrange(x, 'b d -> b 1 d')

        return x
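

# Minimal usage sketch (illustrative; same TimeEncoder assumption as above):
#   model = PointUnet(state_dim=4, dim=32, dim_mults=UNET_DIM_MULTS[0])
#   x = torch.randn(8, 1, 4)            # [batch, 1, state_dim]
#   t = torch.randint(0, 1000, (8,))
#   out = model(x, t, context=None)     # -> [8, 1, 4]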