modify some outputs of tensorrt transfer

This commit is contained in:
TJU_Lu 2024-12-16 12:03:41 +08:00
parent a349d0ca13
commit 59364bef31
2 changed files with 19 additions and 9 deletions

View File

@ -133,8 +133,7 @@ class YopoNet:
return obs_norm return obs_norm
def callback_depth(self, data): def callback_depth(self, data):
max_dis = 20.0 min_dis, max_dis = 0.03, 20.0
min_dis = 0.03
scale = {'435': 0.001, 'flightmare': 1.0}.get(self.env, 1.0) scale = {'435': 0.001, 'flightmare': 1.0}.get(self.env, 1.0)
try: try:

View File

@ -11,6 +11,7 @@ import argparse
import os import os
import numpy as np import numpy as np
import torch import torch
import time
from torch2trt import torch2trt from torch2trt import torch2trt
from flightgym import QuadrotorEnv_v1 from flightgym import QuadrotorEnv_v1
from ruamel.yaml import YAML, RoundTripDumper, dump from ruamel.yaml import YAML, RoundTripDumper, dump
@ -45,11 +46,11 @@ def parser():
parser.add_argument("--trial", type=int, default=1, help="trial number") parser.add_argument("--trial", type=int, default=1, help="trial number")
parser.add_argument("--epoch", type=int, default=0, help="epoch number") parser.add_argument("--epoch", type=int, default=0, help="epoch number")
parser.add_argument("--iter", type=int, default=0, help="iter number") parser.add_argument("--iter", type=int, default=0, help="iter number")
parser.add_argument("--dir", type=str, default='yopo_trt.pth', help="output file name") parser.add_argument("--filename", type=str, default='yopo_trt.pth', help="output file name")
return parser return parser
def main(): if __name__ == "__main__":
args = parser().parse_args() args = parser().parse_args()
# load configurations # load configurations
cfg = YAML().load(open(os.environ["FLIGHTMARE_PATH"] + "/flightlib/configs/vec_env.yaml", 'r')) cfg = YAML().load(open(os.environ["FLIGHTMARE_PATH"] + "/flightlib/configs/vec_env.yaml", 'r'))
@ -79,23 +80,33 @@ def main():
lattice_primitive = saved_variables["data"]["lattice_primitive"] lattice_primitive = saved_variables["data"]["lattice_primitive"]
# The inputs should be consistent with training # The inputs should be consistent with training
print("TensorRT Transfer...")
depth = np.zeros(shape=[1, 1, 96, 160], dtype=np.float32) depth = np.zeros(shape=[1, 1, 96, 160], dtype=np.float32)
obs = np.zeros(shape=[1, 9], dtype=np.float32) obs = np.zeros(shape=[1, 9], dtype=np.float32)
obs_input = prapare_input_observation(obs, lattice_space, lattice_primitive) obs_input = prapare_input_observation(obs, lattice_space, lattice_primitive)
depth_in = torch.from_numpy(depth).cuda() depth_in = torch.from_numpy(depth).cuda()
obs_in = torch.from_numpy(obs_input).cuda() obs_in = torch.from_numpy(obs_input).cuda()
model_trt = torch2trt(model.policy, [depth_in, obs_in]) model_trt = torch2trt(model.policy, [depth_in, obs_in])
torch.save(model_trt.state_dict(), args.dir) torch.save(model_trt.state_dict(), args.filename)
print("TensorRT Transfer Finish!")
# from torch2trt import TRTModule # from torch2trt import TRTModule
# model_trt = TRTModule() # model_trt = TRTModule()
# model_trt.load_state_dict(torch.load('yopo_trt.pth')) # model_trt.load_state_dict(torch.load('yopo_trt.pth'))
print("Evaluation...")
# warm up...
y_trt = model_trt(depth_in, obs_in) y_trt = model_trt(depth_in, obs_in)
y = model.policy(depth_in, obs_in) y = model.policy(depth_in, obs_in)
torch_start = time.time()
y = model.policy(depth_in, obs_in)
torch_end = time.time()
y_trt = model_trt(depth_in, obs_in)
trt_end = time.time()
error = torch.mean(torch.abs(y - y_trt)) error = torch.mean(torch.abs(y - y_trt))
print("transfer error: ", error) print("Torch Latency: ", 1000 * (torch_end - torch_start),
"ms, TensorRT Latency: ", 1000 * (trt_end - torch_end),
"ms, Transfer Error: ", error.item())
if __name__ == "__main__":
main()