import einops
import numpy as np
import torch
import torch.nn as nn
from abc import ABC
from einops import rearrange
from torch.nn import DataParallel
from mpd.models.layers.layers import GaussianFourierProjection, Downsample1d, Conv1dBlock, Upsample1d, \
ResidualTemporalBlock, TimeEncoder, MLP, group_norm_n_groups, LinearAttention, PreNorm, Residual, TemporalBlockMLP
from mpd.models.layers.layers_attention import SpatialTransformer
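

# Preset channel-multiplier tuples for the U-Net down/up-sampling path, keyed by an integer option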
UNET_DIM_MULTS = {
0: (1, 2, 4),
1: (1, 2, 4, 8)
}


class TemporalUnet(nn.Module):
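    """1D U-Net over the horizon dimension of a trajectory [ batch x horizon x state_dim ], with optional context conditioning."""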
def __init__(
self,
n_support_points=None,
state_dim=None,
unet_input_dim=32,
dim_mults=(1, 2, 4, 8),
time_emb_dim=32,
self_attention=False,
conditioning_embed_dim=4,
conditioning_type=None,
attention_num_heads=2,
attention_dim_head=32,
**kwargs
):
super().__init__()
self.state_dim = state_dim
input_dim = state_dim
# Conditioning
if conditioning_type is None or conditioning_type == 'None':
conditioning_type = None
elif conditioning_type == 'concatenate':
if self.state_dim < conditioning_embed_dim // 4:
                # If the conditioning embedding is much larger than the state, lift the state into a larger latent space before concatenating
state_emb_dim = conditioning_embed_dim // 4
self.state_encoder = MLP(state_dim, state_emb_dim, hidden_dim=state_emb_dim//2, n_layers=1, act='mish')
else:
state_emb_dim = state_dim
self.state_encoder = nn.Identity()
input_dim = state_emb_dim + conditioning_embed_dim
elif conditioning_type == 'attention':
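            # context is injected via cross-attention (SpatialTransformer) inside each U-Net block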
pass
elif conditioning_type == 'default':
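            # context is concatenated with the diffusion-time embedding (see cond_dim below)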
pass
else:
raise NotImplementedError
self.conditioning_type = conditioning_type
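        # Channel sizes per resolution level: input_dim followed by unet_input_dim * mult for each multiplier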
dims = [input_dim, *map(lambda m: unet_input_dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
print(f'[ models/temporal ] Channel dimensions: {in_out}')
# Networks
self.time_mlp = TimeEncoder(32, time_emb_dim)
# conditioning dimension (time + context)
cond_dim = time_emb_dim + (conditioning_embed_dim if conditioning_type == 'default' else 0)
# Unet
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
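        # Down path: two residual temporal blocks per level, optional self-/cross-attention, then a downsample (except at the last level)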
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(nn.ModuleList([
ResidualTemporalBlock(dim_in, dim_out, cond_dim, n_support_points=n_support_points),
ResidualTemporalBlock(dim_out, dim_out, cond_dim, n_support_points=n_support_points),
Residual(PreNorm(dim_out, LinearAttention(dim_out))) if self_attention else nn.Identity(),
SpatialTransformer(dim_out, attention_num_heads, attention_dim_head, depth=1,
context_dim=conditioning_embed_dim) if conditioning_type == 'attention' else None,
Downsample1d(dim_out) if not is_last else nn.Identity()
]))
if not is_last:
n_support_points = n_support_points // 2
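        # Bottleneck at the coarsest resolution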
mid_dim = dims[-1]
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, cond_dim, n_support_points=n_support_points)
self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim))) if self_attention else nn.Identity()
self.mid_attention = SpatialTransformer(mid_dim, attention_num_heads, attention_dim_head, depth=1,
context_dim=conditioning_embed_dim) if conditioning_type == 'attention' else nn.Identity()
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, cond_dim, n_support_points=n_support_points)
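        # Up path: skip connections from the down path double the input channels of the first block at each level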
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
self.ups.append(nn.ModuleList([
ResidualTemporalBlock(dim_out * 2, dim_in, cond_dim, n_support_points=n_support_points),
ResidualTemporalBlock(dim_in, dim_in, cond_dim, n_support_points=n_support_points),
Residual(PreNorm(dim_in, LinearAttention(dim_in))) if self_attention else nn.Identity(),
SpatialTransformer(dim_in, attention_num_heads, attention_dim_head, depth=1,
context_dim=conditioning_embed_dim) if conditioning_type == 'attention' else None,
Upsample1d(dim_in) if not is_last else nn.Identity()
]))
if not is_last:
n_support_points = n_support_points * 2
self.final_conv = nn.Sequential(
Conv1dBlock(unet_input_dim, unet_input_dim, kernel_size=5, n_groups=group_norm_n_groups(unet_input_dim)),
nn.Conv1d(unet_input_dim, state_dim, 1),
)
def forward(self, x, time, context):
"""
x : [ batch x horizon x state_dim ]
context: [batch x context_dim]
"""
b, h, d = x.shape
t_emb = self.time_mlp(time)
c_emb = t_emb
if self.conditioning_type == 'concatenate':
x_emb = self.state_encoder(x)
context = einops.repeat(context, 'm n -> m h n', h=h)
x = torch.cat((x_emb, context), dim=-1)
elif self.conditioning_type == 'attention':
# reshape to keep the interface
context = einops.rearrange(context, 'b d -> b 1 d')
elif self.conditioning_type == 'default':
c_emb = torch.cat((t_emb, context), dim=-1)
# swap horizon and channels (state_dim)
        # swap horizon and channels: [ batch x horizon x state_dim ] -> [ batch x state_dim x horizon ]
        x = einops.rearrange(x, 'b h c -> b c h')
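        # collect skip-connection activations (the name h is reused once the horizon value is no longer needed)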
h = []
for resnet, resnet2, attn_self, attn_conditioning, downsample in self.downs:
x = resnet(x, c_emb)
x = resnet2(x, c_emb)
x = attn_self(x)
if self.conditioning_type == 'attention':
x = attn_conditioning(x, context=context)
h.append(x)
x = downsample(x)
x = self.mid_block1(x, c_emb)
x = self.mid_attn(x)
if self.conditioning_type == 'attention':
x = self.mid_attention(x, context=context)
x = self.mid_block2(x, c_emb)
for resnet, resnet2, attn_self, attn_conditioning, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, c_emb)
x = resnet2(x, c_emb)
x = attn_self(x)
if self.conditioning_type == 'attention':
x = attn_conditioning(x, context=context)
x = upsample(x)
x = self.final_conv(x)
x = einops.rearrange(x, 'b c h -> b h c')
        return x


class EnvModel(nn.Module):
def __init__(
self,
in_dim=16,
out_dim=16,
**kwargs
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
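        # placeholder encoder: the environment representation is passed through unchanged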
self.net = nn.Identity()
def forward(self, input_d):
env = input_d['env']
env_emb = self.net(env)
        return env_emb


class TaskModel(nn.Module):
def __init__(
self,
in_dim=16,
out_dim=32,
**kwargs
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
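        # placeholder encoder: the task representation is passed through unchanged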
self.net = nn.Identity()
def forward(self, input_d):
task = input_d['tasks']
task_emb = self.net(task)
        return task_emb


class TaskModelNew(nn.Module):
def __init__(
self,
in_dim=16,
out_dim=32,
**kwargs
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.net = nn.Identity()
def forward(self, task):
task_emb = self.net(task)
        return task_emb


class ContextModel(nn.Module):
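    """Builds the conditioning context by concatenating the environment and task embeddings."""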
def __init__(
self,
env_model=None,
task_model=None,
out_dim=32,
**kwargs
):
super().__init__()
self.env_model = env_model
self.task_model = task_model
self.in_dim = self.env_model.out_dim + self.task_model.out_dim
# self.out_dim = out_dim
# self.net = MLP(self.in_dim, self.out_dim, hidden_dim=out_dim, n_layers=1, act='mish')
self.out_dim = self.in_dim
self.net = nn.Identity()
def forward(self, input_d=None):
if input_d is None:
return None
env_emb = self.env_model(input_d)
task_emb = self.task_model(input_d)
context = torch.cat((env_emb, task_emb), dim=-1)
context_emb = self.net(context)
        return context_emb


class PointUnet(nn.Module):
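    """MLP-based variant of the U-Net that operates on a single support point [ batch x 1 x state_dim ] instead of a full trajectory."""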
def __init__(
self,
n_support_points=None,
state_dim=None,
dim=32,
dim_mults=(1, 2, 4),
time_emb_dim=32,
conditioning_embed_dim=4,
conditioning_type=None,
**kwargs
):
super().__init__()
self.dim_mults = dim_mults
self.state_dim = state_dim
input_dim = state_dim
# Conditioning
if conditioning_type is None or conditioning_type == 'None':
conditioning_type = None
elif conditioning_type == 'concatenate':
if self.state_dim < conditioning_embed_dim // 4:
                # If the conditioning embedding is much larger than the state, lift the state into a larger latent space before concatenating
state_emb_dim = conditioning_embed_dim // 4
self.state_encoder = MLP(state_dim, state_emb_dim, hidden_dim=state_emb_dim//2, n_layers=1, act='mish')
else:
state_emb_dim = state_dim
self.state_encoder = nn.Identity()
input_dim = state_emb_dim + conditioning_embed_dim
elif conditioning_type == 'default':
pass
else:
raise NotImplementedError
self.conditioning_type = conditioning_type
dims = [input_dim, *map(lambda m: dim * m, self.dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
print(f'[ models/temporal ] Channel dimensions: {in_out}')
# Networks
self.time_mlp = TimeEncoder(32, time_emb_dim)
# conditioning dimension (time + context)
cond_dim = time_emb_dim + (conditioning_embed_dim if conditioning_type == 'default' else 0)
# Unet
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(in_out):
self.downs.append(nn.ModuleList([
TemporalBlockMLP(dim_in, dim_out, cond_dim)
]))
mid_dim = dims[-1]
self.mid_block1 = TemporalBlockMLP(mid_dim, mid_dim, cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
self.ups.append(nn.ModuleList([
TemporalBlockMLP(dim_out * 2, dim_in, cond_dim)
]))
self.final_layer = nn.Sequential(
MLP(dim, state_dim, hidden_dim=dim, n_layers=0, act='identity')
)
def forward(self, x, time, context):
"""
x : [ batch x horizon x state_dim ]
context: [batch x context_dim]
"""
x = einops.rearrange(x, 'b 1 d -> b d')
t_emb = self.time_mlp(time)
c_emb = t_emb
if self.conditioning_type == 'concatenate':
x_emb = self.state_encoder(x)
x = torch.cat((x_emb, context), dim=-1)
elif self.conditioning_type == 'default':
c_emb = torch.cat((t_emb, context), dim=-1)
h = []
for resnet, in self.downs:
x = resnet(x, c_emb)
h.append(x)
x = self.mid_block1(x, c_emb)
for resnet, in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, c_emb)
x = self.final_layer(x)
x = einops.rearrange(x, 'b d -> b 1 d')
return x
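

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original module): runs an unconditional forward pass
    # through TemporalUnet to check tensor shapes. It assumes TimeEncoder accepts a [batch] tensor of
    # integer diffusion timesteps; all dimensions below are illustrative only.
    model = TemporalUnet(n_support_points=64, state_dim=4, unet_input_dim=32,
                         dim_mults=(1, 2, 4), time_emb_dim=32, conditioning_type=None)
    x = torch.randn(2, 64, 4)          # [ batch x horizon x state_dim ]
    t = torch.randint(0, 100, (2,))    # diffusion timesteps
    out = model(x, t, context=None)
    print(out.shape)                   # expected: torch.Size([2, 64, 4])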