modify some outputs of tensorrt transfer

This commit is contained in:
TJU_Lu 2024-12-16 12:03:41 +08:00
parent a349d0ca13
commit 59364bef31
2 changed files with 19 additions and 9 deletions

View File

@ -133,8 +133,7 @@ class YopoNet:
return obs_norm
def callback_depth(self, data):
max_dis = 20.0
min_dis = 0.03
min_dis, max_dis = 0.03, 20.0
scale = {'435': 0.001, 'flightmare': 1.0}.get(self.env, 1.0)
try:

View File

@ -11,6 +11,7 @@ import argparse
import os
import numpy as np
import torch
import time
from torch2trt import torch2trt
from flightgym import QuadrotorEnv_v1
from ruamel.yaml import YAML, RoundTripDumper, dump
@ -45,11 +46,11 @@ def parser():
parser.add_argument("--trial", type=int, default=1, help="trial number")
parser.add_argument("--epoch", type=int, default=0, help="epoch number")
parser.add_argument("--iter", type=int, default=0, help="iter number")
parser.add_argument("--dir", type=str, default='yopo_trt.pth', help="output file name")
parser.add_argument("--filename", type=str, default='yopo_trt.pth', help="output file name")
return parser
def main():
if __name__ == "__main__":
args = parser().parse_args()
# load configurations
cfg = YAML().load(open(os.environ["FLIGHTMARE_PATH"] + "/flightlib/configs/vec_env.yaml", 'r'))
@ -79,23 +80,33 @@ def main():
lattice_primitive = saved_variables["data"]["lattice_primitive"]
# The inputs should be consistent with training
print("TensorRT Transfer...")
depth = np.zeros(shape=[1, 1, 96, 160], dtype=np.float32)
obs = np.zeros(shape=[1, 9], dtype=np.float32)
obs_input = prapare_input_observation(obs, lattice_space, lattice_primitive)
depth_in = torch.from_numpy(depth).cuda()
obs_in = torch.from_numpy(obs_input).cuda()
model_trt = torch2trt(model.policy, [depth_in, obs_in])
torch.save(model_trt.state_dict(), args.dir)
torch.save(model_trt.state_dict(), args.filename)
print("TensorRT Transfer Finish!")
# from torch2trt import TRTModule
# model_trt = TRTModule()
# model_trt.load_state_dict(torch.load('yopo_trt.pth'))
print("Evaluation...")
# warm up...
y_trt = model_trt(depth_in, obs_in)
y = model.policy(depth_in, obs_in)
torch_start = time.time()
y = model.policy(depth_in, obs_in)
torch_end = time.time()
y_trt = model_trt(depth_in, obs_in)
trt_end = time.time()
error = torch.mean(torch.abs(y - y_trt))
print("transfer error: ", error)
print("Torch Latency: ", 1000 * (torch_end - torch_start),
"ms, TensorRT Latency: ", 1000 * (trt_end - torch_end),
"ms, Transfer Error: ", error.item())
if __name__ == "__main__":
main()