102 lines
3.7 KiB
Python
102 lines
3.7 KiB
Python
"""
|
|
Convert the YOPO model to TensorRT (将yopo模型转换为Tensorrt).
|
|
prepare:
|
|
0. Make sure TensorRT is already installed.
|
|
1. pip install -U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com
|
|
2. git clone https://github.com/NVIDIA-AI-IOT/torch2trt
|
|
cd torch2trt
|
|
python setup.py install
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import numpy as np
|
|
import torch
|
|
import time
|
|
from torch2trt import torch2trt
|
|
from flightgym import QuadrotorEnv_v1
|
|
from ruamel.yaml import YAML, RoundTripDumper, dump
|
|
from flightpolicy.envs import vec_env_wrapper as wrapper
|
|
from flightpolicy.yopo.yopo_algorithm import YopoAlgorithm
|
|
|
|
|
|
def parser():
    """Build the command-line parser for the YOPO -> TensorRT conversion script.

    Returns:
        argparse.ArgumentParser: a parser accepting ``--trial``, ``--epoch`` and
        ``--iter`` (integers identifying which saved checkpoint to convert),
        ``--fp16_mode`` (integer flag, non-zero for fp16) and ``--filename``
        (output path for the converted model's state dict).
    """
    cli = argparse.ArgumentParser()

    # (flag, default, help) for every integer-valued option, in registration order.
    integer_options = (
        ("--trial", 1, "trial number"),
        ("--epoch", 0, "epoch number"),
        ("--iter", 0, "iter number"),
        ("--fp16_mode", 1, "fp16 or fp32"),
    )
    for flag, default_value, help_text in integer_options:
        cli.add_argument(flag, type=int, default=default_value, help=help_text)

    cli.add_argument("--filename", type=str, default="yopo_trt.pth", help="output file name")
    return cli
|
|
|
|
|
|
# Number of timed forward passes averaged when reporting latency; a single
# sample is dominated by timer resolution and scheduling noise.
_BENCHMARK_RUNS = 100


def _timed(fn, inputs, runs):
    """Call ``fn(*inputs)`` ``runs`` times and report the mean latency.

    ``torch.cuda.synchronize()`` is issued before starting the clock and after
    every call, so the measurement covers the queued GPU work rather than only
    the kernel launch.

    Args:
        fn: callable to benchmark (here the PyTorch policy or its TensorRT twin).
        inputs: tuple of positional arguments forwarded to ``fn``.
        runs: number of timed invocations; must be at least 1.

    Returns:
        tuple: (output of the last call, mean latency in milliseconds).

    Raises:
        ValueError: if ``runs`` is smaller than 1.
    """
    if runs < 1:
        raise ValueError("runs must be >= 1, got {}".format(runs))
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(runs):
        output = fn(*inputs)
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    return output, 1000.0 * elapsed / runs


if __name__ == "__main__":
    args = parser().parse_args()

    # load configurations
    flightmare_path = os.environ.get("FLIGHTMARE_PATH")
    if flightmare_path is None:
        raise RuntimeError("FLIGHTMARE_PATH environment variable is not set")
    cfg_path = os.path.join(flightmare_path, "flightlib", "configs", "vec_env.yaml")
    # Context manager so the config file handle is closed after parsing.
    with open(cfg_path, "r") as cfg_file:
        cfg = YAML().load(cfg_file)
    cfg["env"]["num_envs"] = 1
    cfg["env"]["supervised"] = False
    cfg["env"]["imitation"] = False
    cfg["env"]["render"] = False

    # create environment
    train_env = QuadrotorEnv_v1(dump(cfg, Dumper=RoundTripDumper), False)
    train_env = wrapper.FlightEnvVec(train_env)
    model = YopoAlgorithm(env=train_env,
                          policy_kwargs=dict(
                              activation_fn=torch.nn.ReLU,
                              net_arch=[256, 256],
                              hidden_state=64
                          ))

    # Checkpoint: <script dir>/saved/YOPO_<trial>/Policy/epoch<epoch>_iter<iter>.pth
    rsg_root = os.path.dirname(os.path.abspath(__file__))
    weight = os.path.join(rsg_root, "saved", "YOPO_{}".format(args.trial), "Policy",
                          "epoch{}_iter{}.pth".format(args.epoch, args.iter))
    device = torch.device("cuda")
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    saved_variables = torch.load(weight, map_location=device)
    # strict=False tolerates key mismatches; report them instead of hiding them.
    incompatible = model.policy.load_state_dict(saved_variables["state_dict"], strict=False)
    missing_keys = list(getattr(incompatible, "missing_keys", []))
    unexpected_keys = list(getattr(incompatible, "unexpected_keys", []))
    if missing_keys or unexpected_keys:
        print("Warning: checkpoint/model mismatch; missing keys: {}, unexpected keys: {}".format(
            missing_keys, unexpected_keys))
    model.policy.set_training_mode(False)
    torch.set_grad_enabled(False)

    lattice_space = saved_variables["data"]["lattice_space"]

    # The inputs should be consistent with training
    print("TensorRT Transfer...")
    # Dummy inputs only fix the traced shapes/dtypes: depth (1, 1, 96, 160) and
    # observation (1, 9, vertical_num, horizon_num), float32 on the GPU.
    depth_in = torch.zeros((1, 1, 96, 160), dtype=torch.float32, device=device)
    obs_in = torch.zeros((1, 9, lattice_space.vertical_num, lattice_space.horizon_num),
                         dtype=torch.float32, device=device)
    model_trt = torch2trt(model.policy, [depth_in, obs_in], fp16_mode=bool(args.fp16_mode))
    torch.save(model_trt.state_dict(), args.filename)
    print("TensorRT Transfer Finish!")

    # To reload the converted model later:
    #   from torch2trt import TRTModule
    #   model_trt = TRTModule()
    #   model_trt.load_state_dict(torch.load(args.filename))

    print("Evaluation...")
    # Warm Up: the first calls pay one-off CUDA / TensorRT initialisation costs.
    model_trt(depth_in, obs_in)
    model.policy(depth_in, obs_in)
    torch.cuda.synchronize()

    # Mean latency over _BENCHMARK_RUNS synchronized forward passes each.
    y, torch_ms = _timed(model.policy, (depth_in, obs_in), _BENCHMARK_RUNS)
    y_trt, trt_ms = _timed(model_trt, (depth_in, obs_in), _BENCHMARK_RUNS)

    # Transfer Error: mean absolute difference between PyTorch and TensorRT outputs.
    error = torch.mean(torch.abs(y - y_trt))

    print(f"Torch Latency: {torch_ms:.3f} ms, "
          f"TensorRT Latency: {trt_ms:.3f} ms, "
          f"Transfer Error: {error.item():.8f}")
|
|
|