import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from einops import rearrange
from einops.layers.torch import Rearrange

from Layers import EncoderLayer


def filter(opState, kernelsize=5):
    """Moving-average smoothing along the point dimension.

    Interior points of `opState` (shape: batch x num_points x dim) are replaced by the
    mean of a sliding window of width `kernelsize`, while the first and last points are
    kept unchanged. Assumes an odd `kernelsize` >= 5.
    """
    Bs = opState.shape[0]
    ches = opState.shape[1]
    recL = int((kernelsize - 3) / 2)
    # Replicate the first/last point recL times so every interior point has a full window.
    labelTable = torch.zeros(Bs, int(ches + 2 * recL), opState.shape[2], device=opState.device)
    labelTable[:, :recL, :] = opState[:, 0, :].unsqueeze(dim=1)
    labelTable[:, -recL:, :] = opState[:, -1, :].unsqueeze(dim=1)
    labelTable[:, recL:-recL, :] = opState
    newOpState = torch.zeros_like(opState)
    tmpT = labelTable.unfold(1, kernelsize, 1)
    tmpMeanT = torch.mean(tmpT, dim=-1)
    newOpState[:, 1:-1, :] = tmpMeanT
    # Endpoints are pinned to their original values.
    newOpState[:, 0, :] = opState[:, 0, :]
    newOpState[:, -1, :] = opState[:, -1, :]
    return newOpState


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class block(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convs with a x4 channel expansion."""

    def __init__(self, in_channels, intermediate_channels, identity_downsample=None, stride=1):
        super().__init__()
        self.expansion = 4
        self.conv1 = nn.Conv2d(
            in_channels, intermediate_channels,
            kernel_size=1, stride=1, padding=0, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(intermediate_channels)
        self.conv2 = nn.Conv2d(
            intermediate_channels, intermediate_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(intermediate_channels)
        self.conv3 = nn.Conv2d(
            intermediate_channels, intermediate_channels * self.expansion,
            kernel_size=1, stride=1, padding=0, bias=False,
        )
        self.bn3 = nn.BatchNorm2d(intermediate_channels * self.expansion)
        self.relu = nn.ReLU()
        self.identity_downsample = identity_downsample
        self.stride = stride

    def forward(self, x):
        identity = x.clone()

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.bn3(x)

        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)

        x += identity
        x = self.relu(x)
        return x


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class OutConv(nn.Module):
    """1x1 convolution used as an output projection."""

    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
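
# Hedged usage sketch (illustrative only; this helper is not called anywhere in the
# original code). `filter` above smooths the interior of a point sequence with a
# width-`kernelsize` moving average while keeping both endpoints fixed. The tensor
# shapes below are made up for demonstration.
def _filter_usage_sketch():
    states = torch.rand(2, 100, 2)             # (batch, num_points, xy)
    smoothed = filter(states, kernelsize=5)    # same shape; endpoints unchanged
    assert torch.equal(smoothed[:, 0, :], states[:, 0, :])
    return smoothed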
class PositionalEncoding(nn.Module):
    '''Positional encoding.'''

    def __init__(self, d_hid, n_position, train_shape):
        '''
        Initialize the positional encoding.
        :param d_hid: Dimension of the attention features.
        :param n_position: Number of positions to consider.
        :param train_shape: The 2D shape of the training maps.
        '''
        super(PositionalEncoding, self).__init__()
        self.n_pos_sqrt = int(np.sqrt(n_position))
        self.train_shape = train_shape
        # Not parameters: the lookup tables are registered as buffers.
        self.register_buffer('hashIndex', self._get_hash_table(n_position))
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
        # NOTE: pos_table_train is registered for convenience but is not used by forward(),
        # which selects directly from pos_table.
        self.register_buffer('pos_table_train', self._get_sinusoid_encoding_table_train(n_position, train_shape))

    def _get_hash_table(self, n_position):
        '''
        A simple table converting 1D indexes to a 2D grid.
        :param n_position: The number of positions on the grid.
        '''
        return rearrange(
            torch.arange(n_position), '(h w) -> h w',
            h=int(np.sqrt(n_position)), w=int(np.sqrt(n_position))
        )  # e.g. 40 x 40 for n_position = 1600

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        '''
        Sinusoid position encoding table.
        :param n_position: Number of positions in the table.
        :param d_hid: Dimension of the encoding.
        :returns: A float tensor of shape (1, n_position, d_hid).
        '''
        # TODO: make it with torch instead of numpy

        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        return torch.FloatTensor(sinusoid_table[None, :])

    def _get_sinusoid_encoding_table_train(self, n_position, train_shape):
        '''
        The encoding table to use for training.
        NOTE: It is assumed that all training data comes from a fixed map.
        NOTE: Another assumption is that the training maps are square.
        :param n_position: The maximum number of positions on the table.
        :param train_shape: The 2D dimension of the training maps.
        '''
        selectIndex = rearrange(
            self.hashIndex[:train_shape[0], :train_shape[1]], 'h w -> (h w)'
        )  # (train_h * train_w,) indices into the full grid
        return torch.index_select(self.pos_table, dim=1, index=selectIndex)

    def forward(self, x, conv_shape=None):
        '''
        Add positional encodings to the patch embeddings.
        :param x: Patch embeddings of shape (batch, num_patches, d_model).
        :param conv_shape: The 2D shape of the convolutional feature map at test time.
            If None (training), a randomly shifted crop of size train_shape is used.
        '''
        if conv_shape is None:
            startH, startW = torch.randint(0, self.n_pos_sqrt - self.train_shape[0], (2,))
            selectIndex = rearrange(
                self.hashIndex[startH:startH + self.train_shape[0], startW:startW + self.train_shape[1]],
                'h w -> (h w)'
            )
            return x + torch.index_select(self.pos_table, dim=1, index=selectIndex).clone().detach()

        # assert x.shape[0]==1, "Only valid for testing single image sizes"
        selectIndex = rearrange(self.hashIndex[:conv_shape[0], :conv_shape[1]], 'h w -> (h w)')
        return x + self.pos_table[:, selectIndex.long(), :]
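
# Hedged sketch (not used by the model): the table above is the standard Transformer
# sinusoidal encoding, PE[pos, 2i] = sin(pos / 10000^(2i/d)) and
# PE[pos, 2i+1] = cos(pos / 10000^(2i/d)). A torch-only equivalent of
# _get_sinusoid_encoding_table, addressing its in-code TODO, could look like this
# (function name and exact dtypes are illustrative assumptions):
def _sinusoid_table_torch_sketch(n_position, d_hid):
    position = torch.arange(n_position, dtype=torch.float32).unsqueeze(1)   # (n_position, 1)
    hid_idx = torch.arange(d_hid, dtype=torch.float32).unsqueeze(0)         # (1, d_hid)
    angles = position / torch.pow(10000.0, 2.0 * torch.div(hid_idx, 2, rounding_mode='floor') / d_hid)
    table = torch.empty_like(angles)
    table[:, 0::2] = torch.sin(angles[:, 0::2])  # even dims -> sin
    table[:, 1::2] = torch.cos(angles[:, 1::2])  # odd dims  -> cos
    return table.unsqueeze(0)                    # (1, n_position, d_hid), same layout as pos_table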
class Encoder(nn.Module):
    '''
    The encoder of the planner.
    '''

    def __init__(self, n_layers, n_heads, d_k, d_v, d_model, d_inner, pad_idx, n_position, train_shape):
        '''
        Initialize the encoder.
        :param n_layers: Number of layers of attention and fully connected layers.
        :param n_heads: Number of self-attention heads.
        :param d_k: Dimension of each Key.
        :param d_v: Dimension of each Value.
        :param d_model: Dimension of input/output of the encoder layer.
        :param d_inner: Dimension of the hidden layers of the position-wise FFN.
        :param pad_idx: TODO ....
        :param n_position: Total number of patches the model can handle.
        :param train_shape: The shape of the output of the patch encodings.
        '''
        super().__init__()
        # Patch embedding: a convolutional stack that downsamples the input map to a
        # d_model-channel feature grid.
        self.to_patch_embedding = nn.Sequential(
            DoubleConv(4, 64),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(256, 512, kernel_size=3, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )

        self.reorder_dims = Rearrange('b c h w -> b (h w) c')
        # Position Encoding.
        # NOTE: Current setup adds the position encoding after the patch embedding.
        self.position_enc = PositionalEncoding(d_model, n_position=n_position, train_shape=train_shape)

        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_heads, d_k, d_v)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, input_map, returns_attns=False):
        '''
        The input of the Encoder should be of dim (b, c, h, w).
        :param input_map: The input map for planning.
        :param returns_attns: If True, also return the (currently uncollected) list of
            self-attention maps from each layer.
        '''
        enc_slf_attn_list = []
        enc_output = self.to_patch_embedding(input_map)
        conv_map_shape = enc_output.shape[-2:]
        enc_output = self.reorder_dims(enc_output)

        if self.training:
            enc_output = self.position_enc(enc_output)
        else:
            enc_output = self.position_enc(enc_output, conv_map_shape)

        enc_output = self.layer_norm(enc_output)
        for enc_layer in self.layer_stack:
            enc_output = enc_layer(enc_output, slf_attn_mask=None)

        if returns_attns:
            return enc_output, enc_slf_attn_list
        return enc_output,
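
# Hedged shape walk-through (illustrative helper, not called by the model): for the
# square 200 x 200, 4-channel inputs used in __main__ below, the patch-embedding stack
# above yields a 20 x 20 x 512 feature grid, i.e. 400 tokens of width d_model = 512,
# which matches train_shape=[20, 20] and the hard-coded h=20 in Transformer.forward.
def _encoder_token_grid_sketch(side=200):
    s = side // 2      # DoubleConv keeps the size (padding=1); MaxPool2d(2) halves it
    s = (s - 4) // 2   # two 3x3 convs with padding=0 (-2 each), then MaxPool2d(2)
    s = (s - 4) // 2   # two more padding=0 convs, then MaxPool2d(2)
    s = s - 2          # one padding=0 conv (-2); the final padding=1 conv keeps the size
    return s           # 200 -> 20, so 20 * 20 = 400 tokens per map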
class Transformer(nn.Module):
    ''' A Transformer module. '''

    def __init__(self, n_layers, n_heads, d_k, d_v, d_model, d_inner, pad_idx, n_position, train_shape):
        '''
        Initialize the Transformer model.
        :param n_layers: Number of layers of attention and fully connected layers.
        :param n_heads: Number of self-attention heads.
        :param d_k: Dimension of each Key.
        :param d_v: Dimension of each Value.
        :param d_model: Dimension of input/output of the encoder layer.
        :param d_inner: Dimension of the hidden layers of the position-wise FFN.
        :param pad_idx: TODO ......
        :param n_position: dim*dim of the maximum map size.
        :param train_shape: The shape of the output of the patch encodings.
        '''
        super().__init__()

        self.encoder = Encoder(
            n_layers=n_layers,        # number of encoder sublayers
            n_heads=n_heads,          # number of self-attention heads
            d_k=d_k,                  # dimension of each key
            d_v=d_v,                  # dimension of each value
            d_model=d_model,          # channel width of the convolutional patch embedding
            d_inner=d_inner,          # hidden width of the position-wise FFN
            pad_idx=pad_idx,
            n_position=n_position,    # maximum table size for the position encoding
            train_shape=train_shape   # image size in meters (the 20 x 20 patch grid)
        )

    def forward(self, input_map):
        '''
        Forward pass.
        :param input_map: The input map for planning, of dim (b, c, h, w).
        '''
        enc_output, *_ = self.encoder(input_map)
        enc_output = rearrange(enc_output, 'b c d -> b d c')
        # Fold the 400 tokens back into the 20 x 20 feature grid.
        enc_output = rearrange(enc_output, 'b c (h w) -> b c h w', h=20)
        return enc_output


class AnchorNet25(nn.Module):

    def __init__(self, n_channels, out_channels=1):
        super(AnchorNet25, self).__init__()
        model_args = dict(
            n_layers=6,
            n_heads=3,
            d_k=512,
            d_v=256,
            d_model=512,
            d_inner=1024,
            pad_idx=None,
            n_position=40 * 40,
            # train_shape=[25, 25],
            train_shape=[20, 20]
        )
        self.transformer = Transformer(**model_args)
        self.outc = nn.Sequential(
            nn.Conv2d(512, 1024, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            OutConv(1024, out_channels)
        )

    def forward(self, x):
        x = self.transformer(x)
        result = self.outc(x)
        result_1 = rearrange(result, 'b c h w -> b c (h w)')
        # Flattened raw (pre-softmax) logits of shape (b * out_channels, h * w).
        result_2 = rearrange(result_1, 'b c l -> (b c) l')
        # Per-channel softmax over the 400 grid cells turns each channel into a probability map.
        result = rearrange(result, 'b c h w -> b c (h w)')
        result = torch.softmax(result, dim=2)
        result = rearrange(result, 'b c (h w) -> b c h w', h=20)
        return x, result, result_2
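
# Hedged sketch of the "expected grid coordinate" reduction used by trajFCNet.forward
# below: each per-point probability map is reduced to the probability-weighted mean of
# its row and column indices (a soft argmax over the 20 x 20 grid). Function and
# argument names are illustrative only.
def _expected_grid_coord_sketch(prob_map):
    # prob_map: (batch, num_points, 20, 20), each 20 x 20 slice summing to 1.
    grid = torch.arange(20, dtype=prob_map.dtype, device=prob_map.device)
    ex = (prob_map * grid.view(1, 1, 20, 1)).sum(dim=(2, 3))  # expectation over rows
    ey = (prob_map * grid.view(1, 1, 1, 20)).sum(dim=(2, 3))  # expectation over columns
    return torch.stack((ex, ey), dim=2)                       # (batch, num_points, 2)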
class trajFCNet(nn.Module):

    def __init__(self, image_channels=4, pt_num=100, filter=7, l=1.2, use_groundTruth=True):
        super(trajFCNet, self).__init__()
        self.a = AnchorNet25(image_channels, pt_num)
        self.fcntrajout = nn.Sequential(
            nn.Conv2d(512 + pt_num, 1024, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            OutConv(1024, pt_num * 3)
        )
        self.pt_num = pt_num
        self.filter = filter  # smoothing window size; the module-level filter() is applied when >= 3
        self.l = l
        self.w = (self.l - 1.0) / 2.0
        self.use_groundTruth = use_groundTruth
        self.sig = nn.Sigmoid()  # squashes offsets to 0~1
        # Row/column index grids for the expected-coordinate computation.
        xylim = torch.arange(0, 20).unsqueeze(dim=1)
        xylim = xylim.repeat(1, 20)
        yxlim = torch.arange(0, 20).unsqueeze(dim=0)
        yxlim = yxlim.repeat(20, 1)
        self.register_buffer('xylim', xylim)
        self.register_buffer('yxlim', yxlim)

    def forward(self, x, labelState, labelRot, anchors):
        ft, result_1, result_2 = self.a(x)
        # Per-point probability maps; the first and last points are taken from the anchors.
        prbmap = torch.zeros_like(result_1)
        prbmap[:, 0, :, :] = anchors[:, 0, :, :]
        prbmap[:, -1, :, :] = anchors[:, -1, :, :]
        prbmap[:, 1:-1, :, :] = result_1[:, 1:-1, :, :]

        resFeature = ft
        if self.use_groundTruth:
            anchorsFeature = anchors
        else:
            anchorsFeature = prbmap
        resInput = torch.cat((anchorsFeature, resFeature), dim=1)
        resOutput = self.fcntrajout(resInput)

        # Per-cell offsets (px, py) and yaw (yw), weighted by the probability maps.
        if self.use_groundTruth:
            if self.l > 0:  # hzchzc
                px = (self.l * self.sig(resOutput[:, 0::3, :, :]) - self.w) * anchors
                py = (self.l * self.sig(resOutput[:, 1::3, :, :]) - self.w) * anchors
            else:
                px = resOutput[:, 0::3, :, :] * anchors
                py = resOutput[:, 1::3, :, :] * anchors
            yw = resOutput[:, 2::3, :, :] * anchors
        else:
            if self.l > 0:
                px = (self.l * self.sig(resOutput[:, 0::3, :, :]) - self.w) * prbmap
                py = (self.l * self.sig(resOutput[:, 1::3, :, :]) - self.w) * prbmap
            else:
                px = resOutput[:, 0::3, :, :] * prbmap
                py = resOutput[:, 1::3, :, :] * prbmap
            yw = resOutput[:, 2::3, :, :] * prbmap

        # bias: reduce the weighted maps to one offset / yaw per point.
        px = torch.sum(torch.sum(px, dim=3), dim=2).unsqueeze(dim=2)
        py = torch.sum(torch.sum(py, dim=3), dim=2).unsqueeze(dim=2)
        yw = torch.sum(torch.sum(yw, dim=3), dim=2)

        gridx = Variable(self.xylim, requires_grad=False)
        gridy = Variable(self.yxlim, requires_grad=False)
        if self.use_groundTruth:
            xmap = anchors * gridx
            ymap = anchors * gridy
        else:
            xmap = prbmap * gridx
            ymap = prbmap * gridy
        aveGirdx = torch.sum(torch.sum(xmap, dim=3), dim=2).unsqueeze(dim=2)
        aveGirdy = torch.sum(torch.sum(ymap, dim=3), dim=2).unsqueeze(dim=2)
        # local origin: expected grid coordinate, shifted to map-centred coordinates.
        lo = torch.cat((aveGirdx, aveGirdy), dim=2) * 1.0 - 10.0

        opState = torch.cat((px, py), dim=2) + lo
        cosyaw = torch.cos(yw).unsqueeze(dim=2)
        sinyaw = torch.sin(yw).unsqueeze(dim=2)
        rotOutput = torch.cat((cosyaw, sinyaw), dim=2)

        # Pin the first and last states/rotations to the supervised labels.
        opState[:, 0, :] = labelState[:, 0, :]
        opState[:, -1, :] = labelState[:, -1, :]
        rotOutput[:, 0, :] = labelRot[:, 0, :]
        rotOutput[:, -1, :] = labelRot[:, -1, :]

        if self.filter >= 3:
            opState = filter(opState, self.filter)
            rotOutput = filter(rotOutput, self.filter)
        rotOutput = torch.nn.functional.normalize(rotOutput, dim=2)
        return opState, rotOutput, prbmap, result_2


if __name__ == "__main__":
    pt = 200
    model = trajFCNet(pt_num=pt, filter=7).cuda().half()
    totalt = 0.0
    count = 0
    model.eval()

    input = torch.rand(1, 4, 200, 200).cuda().half()
    labelState = torch.rand(1, pt, 2).cuda().half()
    labelRot = torch.rand(1, pt, 2).cuda().half()
    anchors = torch.rand(1, pt, 20, 20).cuda().half()

    # Warm-up call, then time 200 forward passes (the first 20 iterations are excluded).
    out = model(input, labelState, labelRot, anchors)
    with torch.no_grad():
        for i in range(200):
            torch.cuda.synchronize()
            start = time.time()
            out = model(input, labelState, labelRot, anchors)
            torch.cuda.synchronize()
            end = time.time()
            if i >= 20:
                totalt += 1000.0 * (end - start)
                count += 1
    print("model time: ", totalt / count, " ms")