From 59364bef312efb4c142d6ab765c20988e5a785cb Mon Sep 17 00:00:00 2001 From: TJU_Lu Date: Mon, 16 Dec 2024 12:03:41 +0800 Subject: [PATCH] modify some outputs of tensorrt transfer --- run/test_yopo_ros.py | 3 +-- run/yopo_trt_transfer.py | 25 ++++++++++++++++++------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/run/test_yopo_ros.py b/run/test_yopo_ros.py index d5cf2be..68cca4f 100644 --- a/run/test_yopo_ros.py +++ b/run/test_yopo_ros.py @@ -133,8 +133,7 @@ class YopoNet: return obs_norm def callback_depth(self, data): - max_dis = 20.0 - min_dis = 0.03 + min_dis, max_dis = 0.03, 20.0 scale = {'435': 0.001, 'flightmare': 1.0}.get(self.env, 1.0) try: diff --git a/run/yopo_trt_transfer.py b/run/yopo_trt_transfer.py index e870f4f..1f464ad 100644 --- a/run/yopo_trt_transfer.py +++ b/run/yopo_trt_transfer.py @@ -11,6 +11,7 @@ import argparse import os import numpy as np import torch +import time from torch2trt import torch2trt from flightgym import QuadrotorEnv_v1 from ruamel.yaml import YAML, RoundTripDumper, dump @@ -45,11 +46,11 @@ def parser(): parser.add_argument("--trial", type=int, default=1, help="trial number") parser.add_argument("--epoch", type=int, default=0, help="epoch number") parser.add_argument("--iter", type=int, default=0, help="iter number") - parser.add_argument("--dir", type=str, default='yopo_trt.pth', help="output file name") + parser.add_argument("--filename", type=str, default='yopo_trt.pth', help="output file name") return parser -def main(): +if __name__ == "__main__": args = parser().parse_args() # load configurations cfg = YAML().load(open(os.environ["FLIGHTMARE_PATH"] + "/flightlib/configs/vec_env.yaml", 'r')) @@ -79,23 +80,33 @@ def main(): lattice_primitive = saved_variables["data"]["lattice_primitive"] # The inputs should be consistent with training + print("TensorRT Transfer...") depth = np.zeros(shape=[1, 1, 96, 160], dtype=np.float32) obs = np.zeros(shape=[1, 9], dtype=np.float32) obs_input = prapare_input_observation(obs, lattice_space, lattice_primitive) depth_in = torch.from_numpy(depth).cuda() obs_in = torch.from_numpy(obs_input).cuda() model_trt = torch2trt(model.policy, [depth_in, obs_in]) - torch.save(model_trt.state_dict(), args.dir) + torch.save(model_trt.state_dict(), args.filename) + print("TensorRT Transfer Finish!") # from torch2trt import TRTModule # model_trt = TRTModule() # model_trt.load_state_dict(torch.load('yopo_trt.pth')) + print("Evaluation...") + # warm up... y_trt = model_trt(depth_in, obs_in) y = model.policy(depth_in, obs_in) + + torch_start = time.time() + y = model.policy(depth_in, obs_in) + torch_end = time.time() + y_trt = model_trt(depth_in, obs_in) + trt_end = time.time() + error = torch.mean(torch.abs(y - y_trt)) - print("transfer error: ", error) + print("Torch Latency: ", 1000 * (torch_end - torch_start), + "ms, TensorRT Latency: ", 1000 * (trt_end - torch_end), + "ms, Transfer Error: ", error.item()) - -if __name__ == "__main__": - main()