modify some outputs of tensorrt transfer
This commit is contained in:
parent
a349d0ca13
commit
59364bef31
@ -133,8 +133,7 @@ class YopoNet:
|
||||
return obs_norm
|
||||
|
||||
def callback_depth(self, data):
|
||||
max_dis = 20.0
|
||||
min_dis = 0.03
|
||||
min_dis, max_dis = 0.03, 20.0
|
||||
scale = {'435': 0.001, 'flightmare': 1.0}.get(self.env, 1.0)
|
||||
|
||||
try:
|
||||
|
||||
@ -11,6 +11,7 @@ import argparse
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import time
|
||||
from torch2trt import torch2trt
|
||||
from flightgym import QuadrotorEnv_v1
|
||||
from ruamel.yaml import YAML, RoundTripDumper, dump
|
||||
@ -45,11 +46,11 @@ def parser():
|
||||
parser.add_argument("--trial", type=int, default=1, help="trial number")
|
||||
parser.add_argument("--epoch", type=int, default=0, help="epoch number")
|
||||
parser.add_argument("--iter", type=int, default=0, help="iter number")
|
||||
parser.add_argument("--dir", type=str, default='yopo_trt.pth', help="output file name")
|
||||
parser.add_argument("--filename", type=str, default='yopo_trt.pth', help="output file name")
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
if __name__ == "__main__":
|
||||
args = parser().parse_args()
|
||||
# load configurations
|
||||
cfg = YAML().load(open(os.environ["FLIGHTMARE_PATH"] + "/flightlib/configs/vec_env.yaml", 'r'))
|
||||
@ -79,23 +80,33 @@ def main():
|
||||
lattice_primitive = saved_variables["data"]["lattice_primitive"]
|
||||
|
||||
# The inputs should be consistent with training
|
||||
print("TensorRT Transfer...")
|
||||
depth = np.zeros(shape=[1, 1, 96, 160], dtype=np.float32)
|
||||
obs = np.zeros(shape=[1, 9], dtype=np.float32)
|
||||
obs_input = prapare_input_observation(obs, lattice_space, lattice_primitive)
|
||||
depth_in = torch.from_numpy(depth).cuda()
|
||||
obs_in = torch.from_numpy(obs_input).cuda()
|
||||
model_trt = torch2trt(model.policy, [depth_in, obs_in])
|
||||
torch.save(model_trt.state_dict(), args.dir)
|
||||
torch.save(model_trt.state_dict(), args.filename)
|
||||
print("TensorRT Transfer Finish!")
|
||||
|
||||
# from torch2trt import TRTModule
|
||||
# model_trt = TRTModule()
|
||||
# model_trt.load_state_dict(torch.load('yopo_trt.pth'))
|
||||
|
||||
print("Evaluation...")
|
||||
# warm up...
|
||||
y_trt = model_trt(depth_in, obs_in)
|
||||
y = model.policy(depth_in, obs_in)
|
||||
|
||||
torch_start = time.time()
|
||||
y = model.policy(depth_in, obs_in)
|
||||
torch_end = time.time()
|
||||
y_trt = model_trt(depth_in, obs_in)
|
||||
trt_end = time.time()
|
||||
|
||||
error = torch.mean(torch.abs(y - y_trt))
|
||||
print("transfer error: ", error)
|
||||
print("Torch Latency: ", 1000 * (torch_end - torch_start),
|
||||
"ms, TensorRT Latency: ", 1000 * (trt_end - torch_end),
|
||||
"ms, Transfer Error: ", error.item())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user