# YOPO/flightpolicy/yopo/yopo_network.py
#
# 72 lines
# 2.4 KiB
# Python

# The backbone and the custom gradient layer.
import time
import torch as th
import torch.nn
import numpy as np
from torchvision.models import mobilenet_v3_small
from flightpolicy.yopo.resnet import resnet18
from torch.autograd import Function
# 18ms, Fast and effective.
class ResNet18(th.nn.Module):
    """Depth-image backbone built on the project's ``resnet18``.

    The stock network is adapted in three ways:

    * the stem convolution is replaced so it accepts a 1-channel input;
    * when ``primitive_shape`` is not 1, the average-pool stage is swapped
      for an empty ``Sequential`` (a pass-through);
    * the classification head is replaced by a bias-free 1x1 convolution
      mapping 512 channels to ``output_dim`` channels.

    Args:
        output_dim: Number of output channels produced by the 1x1 head.
        primitive_shape: Pooling is kept only when this equals 1.
            NOTE(review): presumably the side length of the primitive grid;
            confirm against the caller.

    Attributes:
        cnn: The adapted ResNet-18 network.
        features_dim: Same value as ``output_dim``.
    """

    def __init__(self, output_dim: int, primitive_shape: int):
        super().__init__()
        backbone = resnet18(pretrained=False)

        # Stem re-created with in_channels=1 instead of the default stem.
        backbone.conv1 = th.nn.Conv2d(
            1, 64, kernel_size=7, stride=2, padding=3, bias=False
        )

        # An empty Sequential acts as identity, so no pooling is applied
        # unless a single primitive is requested.
        keep_pooling = primitive_shape == 1
        if not keep_pooling:
            backbone.avgpool = th.nn.Sequential()

        # 1x1 convolution as the output head (applied per spatial location).
        backbone.fc = th.nn.Conv2d(
            512, output_dim, kernel_size=1, stride=1, padding=0, bias=False
        )

        self.cnn = backbone
        self.features_dim = output_dim

    def forward(self, depth: th.Tensor) -> th.Tensor:
        """Run ``depth`` through the adapted network and return its output."""
        features = self.cnn(depth)
        return features
# 20ms, Performs worse than ResNet and is slower than ResNet-18.
class MobileNet(th.nn.Module):
    """Depth-image backbone built on torchvision's ``mobilenet_v3_small``.

    The first convolution of the feature extractor is replaced by a
    bias-free 3x3, stride-1 convolution taking a 1-channel input, and the
    classifier is replaced by a single linear layer from 576 features to
    ``output_dim``.

    Args:
        output_dim: Size of the output produced by the linear classifier.

    Attributes:
        cnn: The adapted MobileNetV3-small network.
        features_dim: Same value as ``output_dim``.
    """

    def __init__(self, output_dim: int):
        super().__init__()
        model = mobilenet_v3_small(pretrained=False)

        # Swap only the conv inside the first feature block; the rest of
        # that block is left as constructed.
        model.features[0][0] = th.nn.Conv2d(
            1, 16, kernel_size=3, stride=1, padding=1, bias=False
        )
        model.classifier = th.nn.Linear(576, output_dim)

        self.cnn = model
        self.features_dim = output_dim

    def forward(self, depth: th.Tensor) -> th.Tensor:
        """Run ``depth`` through the adapted network and return its output."""
        features = self.cnn(depth)
        return features
def YopoBackbone(output_dim, primitive_shape):
    """Construct the backbone network; currently always a ``ResNet18``.

    Both arguments are forwarded unchanged to the ``ResNet18`` constructor.
    """
    backbone = ResNet18(output_dim, primitive_shape)
    return backbone
class CostAndGradLayer(Function):
    """Autograd bridge to an external cost/gradient provider.

    ``forward`` asks ``train_env`` for a cost and its gradient, returns the
    cost as a tensor and stashes the gradient; ``backward`` multiplies the
    stashed gradient by the incoming gradient of the cost.
    """

    @staticmethod
    def forward(ctx, input_dp, train_env, primitive_id):
        """Query ``train_env`` and return the cost as a tensor.

        Args:
            ctx: Autograd context used to save the gradient for ``backward``.
            input_dp: Tensor passed as-is to
                ``train_env.getCostAndGradient``; its device decides where
                the returned cost and saved gradient live.
            train_env: Object exposing
                ``getCostAndGradient(input_dp, primitive_id)`` that returns
                a ``(cost, grad)`` pair.
            primitive_id: Forwarded unchanged to ``train_env``.

        Returns:
            The cost as a tensor on ``input_dp``'s device with
            ``requires_grad`` set.
        """
        target_device = input_dp.device
        raw_cost, raw_grad = train_env.getCostAndGradient(input_dp, primitive_id)

        # Gradient clipping: cap every element at 1.0.
        # NOTE(review): np.minimum only bounds values from above; large
        # negative entries pass through unchanged -- confirm this is intended.
        clipped_grad = np.minimum(raw_grad, 1.0)

        cost = torch.tensor(raw_cost).to(target_device)
        grad_tensor = torch.tensor(clipped_grad).to(target_device)
        ctx.save_for_backward(grad_tensor)

        cost.requires_grad_(True)
        return cost

    @staticmethod
    def backward(ctx, cost_grad_input):
        """Scale the saved gradient by the incoming gradient of the cost.

        Returns one entry per ``forward`` argument; only ``input_dp``
        receives a gradient, the other two get ``None``.
        """
        (saved_grad,) = ctx.saved_tensors

        # bmm needs 3-D operands, so both get a trailing axis.
        # assumes saved_grad is (batch, dim) and cost_grad_input is
        # (batch, 1) -- TODO confirm against the environment's output.
        grad_column = saved_grad.unsqueeze(-1)
        upstream = cost_grad_input.unsqueeze(-1)
        scaled = th.bmm(grad_column, upstream)

        return scaled.squeeze(dim=2), None, None
if __name__ == '__main__':
    # Rough latency probe: build the backbone with 64 output channels and
    # primitive_shape=3, then time one forward pass on an all-zero
    # 1x1x96x96 input. Only a single, un-warmed call is measured.
    net = YopoBackbone(64, 3)
    input_ = torch.zeros((1, 1, 96, 96))
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump if the system clock is adjusted mid-measurement.
    start = time.perf_counter()
    output = net(input_)
    print(time.perf_counter() - start)