72 lines
2.4 KiB
Python
72 lines
2.4 KiB
Python
# The backbone and the custom gradient layer.
|
|
import time
|
|
import torch as th
|
|
import torch.nn
|
|
import numpy as np
|
|
from torchvision.models import mobilenet_v3_small
|
|
from flightpolicy.yopo.resnet import resnet18
|
|
from torch.autograd import Function
|
|
|
|
|
|
# 18 ms; fast and effective.
|
|
class ResNet18(th.nn.Module):
    """ResNet-18 backbone adapted to a single input channel, with a 1x1-conv head.

    Wraps ``resnet18`` from ``flightpolicy.yopo.resnet`` (built with
    ``pretrained=False``) and swaps three of its submodules: the stem
    convolution, optionally the average pool, and the final ``fc`` layer.

    Args:
        output_dim: Number of output channels of the 1x1 convolution head;
            also stored as ``features_dim``.
        primitive_shape: If it is anything other than 1, the backbone's
            ``avgpool`` is replaced with an empty ``Sequential`` (which
            returns its input unchanged), so no pooling is applied there.
    """

    def __init__(self, output_dim: int, primitive_shape: int):
        super().__init__()
        backbone = resnet18(pretrained=False)

        # Stem convolution that takes a single input channel.
        backbone.conv1 = th.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

        if primitive_shape != 1:
            # Empty Sequential = pass-through, i.e. the pooling stage is disabled.
            backbone.avgpool = th.nn.Sequential()

        # Head: 1x1 convolution mapping the 512 backbone channels to output_dim.
        backbone.fc = th.nn.Conv2d(512, output_dim, kernel_size=1, stride=1, padding=0, bias=False)

        self.cnn = backbone
        self.features_dim = output_dim

    def forward(self, depth: th.Tensor) -> th.Tensor:
        """Run the adapted backbone on a single-channel input tensor."""
        features = self.cnn(depth)
        return features
|
|
|
|
|
|
# 20 ms; performs worse than ResNet-18 and is also slower.
|
|
class MobileNet(th.nn.Module):
    """MobileNetV3-Small backbone adapted to a single input channel.

    Built from torchvision's ``mobilenet_v3_small`` (``pretrained=False``).
    Two submodules are swapped: the first stem convolution and the classifier.

    Args:
        output_dim: Size of the output of the replacement linear classifier;
            also stored as ``features_dim``.
    """

    def __init__(self, output_dim: int):
        super().__init__()
        net = mobilenet_v3_small(pretrained=False)

        # First stem convolution: one input channel, stride 1.
        net.features[0][0] = th.nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1, bias=False)

        # Single linear layer from the 576 pooled features to output_dim.
        net.classifier = th.nn.Linear(576, output_dim)

        self.cnn = net
        self.features_dim = output_dim

    def forward(self, depth: th.Tensor) -> th.Tensor:
        """Run the adapted MobileNet on a single-channel input tensor."""
        features = self.cnn(depth)
        return features
|
|
|
|
|
|
def YopoBackbone(output_dim, primitive_shape):
    """Build the feature-extraction backbone.

    Currently always a ``ResNet18``; both arguments are forwarded to it unchanged.

    Args:
        output_dim: Output channel count of the backbone head.
        primitive_shape: Forwarded to ``ResNet18`` (a value other than 1
            disables its average pooling).

    Returns:
        A ``ResNet18`` module instance.
    """
    backbone = ResNet18(output_dim, primitive_shape)
    return backbone
|
|
|
|
|
|
class CostAndGradLayer(Function):
    """Autograd function whose cost and gradient are supplied by an external environment.

    The forward pass asks ``train_env.getCostAndGradient`` for the cost of
    ``input_dp`` and its gradient with respect to ``input_dp``. The gradient
    is cached and, in the backward pass, scaled by the upstream gradient of
    the cost (chain rule). ``train_env`` and ``primitive_id`` receive no
    gradient.
    """

    @staticmethod
    def forward(ctx, input_dp, train_env, primitive_id):
        """Return the environment's cost for ``input_dp`` as a tensor on its device.

        Args:
            ctx: Autograd context used to stash the gradient for ``backward``.
            input_dp: Tensor handed to the environment. Presumably
                ``(batch, dim)`` -- TODO confirm against callers.
            train_env: Object exposing
                ``getCostAndGradient(input_dp, primitive_id) -> (cost, grad)``,
                where both values are convertible with ``torch.tensor``.
            primitive_id: Passed through unchanged to the environment.

        Returns:
            The cost tensor on ``input_dp.device``. Its leading dimension
            should be the batch size so ``backward`` can match costs to samples.
        """
        device = input_dp.device
        cost, grad = train_env.getCostAndGradient(input_dp, primitive_id)

        # Gradient clipping: cap entries at 1.0 to prevent excessively large values.
        # NOTE(review): np.minimum bounds only from above; large *negative*
        # entries pass through unchanged. Confirm the one-sided clip is intended
        # (np.clip(grad, -1.0, 1.0) would bound both sides).
        grad = np.minimum(grad, 1.0)

        # torch.tensor always copies, so these tensors never alias the
        # arrays returned by the environment.
        cost = torch.tensor(cost).to(device)
        grad = torch.tensor(grad).to(device)
        ctx.save_for_backward(grad)

        # Kept from the original implementation: marks the freshly created
        # cost tensor as requiring grad before it is returned.
        cost.requires_grad = True
        return cost

    @staticmethod
    def backward(ctx, cost_grad_input):
        """Scale each sample's cached gradient by its upstream cost gradient.

        Accepts an upstream gradient with one value per sample, e.g. shaped
        ``(batch,)`` or ``(batch, 1)``.

        Returns:
            ``(grad_wrt_input_dp, None, None)``; ``train_env`` and
            ``primitive_id`` are not differentiable.
        """
        grad, = ctx.saved_tensors
        # Reshape the upstream gradient to (batch, 1, ..., 1) so it broadcasts
        # over grad's trailing dims. For a (B, D) grad and a (B, 1) cost this
        # gives the same products as the previous bmm formulation. It also
        # supports a (B,) cost, on which bmm raised, and tolerates differing
        # dtypes via type promotion.
        scale = cost_grad_input.reshape(grad.shape[0], *([1] * (grad.dim() - 1)))
        return_grad = grad * scale
        return return_grad, None, None
|
|
|
|
|
|
if __name__ == '__main__':
    # Rough CPU latency check of the default backbone on one 1x96x96 input.
    net = YopoBackbone(64, 3)
    # Inference mode for any layers whose behavior depends on train/eval.
    net.eval()
    input_ = torch.zeros((1, 1, 96, 96))

    # no_grad: latency measurement should not pay for building an autograd graph.
    with torch.no_grad():
        # Warm-up call so one-time first-call overhead is not measured.
        net(input_)
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        output = net(input_)
        elapsed = time.perf_counter() - start

    print(elapsed)
|