modify some outputs of tensorrt transfer
This commit is contained in:
parent
a349d0ca13
commit
59364bef31
@ -133,8 +133,7 @@ class YopoNet:
|
|||||||
return obs_norm
|
return obs_norm
|
||||||
|
|
||||||
def callback_depth(self, data):
|
def callback_depth(self, data):
|
||||||
max_dis = 20.0
|
min_dis, max_dis = 0.03, 20.0
|
||||||
min_dis = 0.03
|
|
||||||
scale = {'435': 0.001, 'flightmare': 1.0}.get(self.env, 1.0)
|
scale = {'435': 0.001, 'flightmare': 1.0}.get(self.env, 1.0)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import time
|
||||||
from torch2trt import torch2trt
|
from torch2trt import torch2trt
|
||||||
from flightgym import QuadrotorEnv_v1
|
from flightgym import QuadrotorEnv_v1
|
||||||
from ruamel.yaml import YAML, RoundTripDumper, dump
|
from ruamel.yaml import YAML, RoundTripDumper, dump
|
||||||
@ -45,11 +46,11 @@ def parser():
|
|||||||
parser.add_argument("--trial", type=int, default=1, help="trial number")
|
parser.add_argument("--trial", type=int, default=1, help="trial number")
|
||||||
parser.add_argument("--epoch", type=int, default=0, help="epoch number")
|
parser.add_argument("--epoch", type=int, default=0, help="epoch number")
|
||||||
parser.add_argument("--iter", type=int, default=0, help="iter number")
|
parser.add_argument("--iter", type=int, default=0, help="iter number")
|
||||||
parser.add_argument("--dir", type=str, default='yopo_trt.pth', help="output file name")
|
parser.add_argument("--filename", type=str, default='yopo_trt.pth', help="output file name")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def main():
|
if __name__ == "__main__":
|
||||||
args = parser().parse_args()
|
args = parser().parse_args()
|
||||||
# load configurations
|
# load configurations
|
||||||
cfg = YAML().load(open(os.environ["FLIGHTMARE_PATH"] + "/flightlib/configs/vec_env.yaml", 'r'))
|
cfg = YAML().load(open(os.environ["FLIGHTMARE_PATH"] + "/flightlib/configs/vec_env.yaml", 'r'))
|
||||||
@ -79,23 +80,33 @@ def main():
|
|||||||
lattice_primitive = saved_variables["data"]["lattice_primitive"]
|
lattice_primitive = saved_variables["data"]["lattice_primitive"]
|
||||||
|
|
||||||
# The inputs should be consistent with training
|
# The inputs should be consistent with training
|
||||||
|
print("TensorRT Transfer...")
|
||||||
depth = np.zeros(shape=[1, 1, 96, 160], dtype=np.float32)
|
depth = np.zeros(shape=[1, 1, 96, 160], dtype=np.float32)
|
||||||
obs = np.zeros(shape=[1, 9], dtype=np.float32)
|
obs = np.zeros(shape=[1, 9], dtype=np.float32)
|
||||||
obs_input = prapare_input_observation(obs, lattice_space, lattice_primitive)
|
obs_input = prapare_input_observation(obs, lattice_space, lattice_primitive)
|
||||||
depth_in = torch.from_numpy(depth).cuda()
|
depth_in = torch.from_numpy(depth).cuda()
|
||||||
obs_in = torch.from_numpy(obs_input).cuda()
|
obs_in = torch.from_numpy(obs_input).cuda()
|
||||||
model_trt = torch2trt(model.policy, [depth_in, obs_in])
|
model_trt = torch2trt(model.policy, [depth_in, obs_in])
|
||||||
torch.save(model_trt.state_dict(), args.dir)
|
torch.save(model_trt.state_dict(), args.filename)
|
||||||
|
print("TensorRT Transfer Finish!")
|
||||||
|
|
||||||
# from torch2trt import TRTModule
|
# from torch2trt import TRTModule
|
||||||
# model_trt = TRTModule()
|
# model_trt = TRTModule()
|
||||||
# model_trt.load_state_dict(torch.load('yopo_trt.pth'))
|
# model_trt.load_state_dict(torch.load('yopo_trt.pth'))
|
||||||
|
|
||||||
|
print("Evaluation...")
|
||||||
|
# warm up...
|
||||||
y_trt = model_trt(depth_in, obs_in)
|
y_trt = model_trt(depth_in, obs_in)
|
||||||
y = model.policy(depth_in, obs_in)
|
y = model.policy(depth_in, obs_in)
|
||||||
|
|
||||||
|
torch_start = time.time()
|
||||||
|
y = model.policy(depth_in, obs_in)
|
||||||
|
torch_end = time.time()
|
||||||
|
y_trt = model_trt(depth_in, obs_in)
|
||||||
|
trt_end = time.time()
|
||||||
|
|
||||||
error = torch.mean(torch.abs(y - y_trt))
|
error = torch.mean(torch.abs(y - y_trt))
|
||||||
print("transfer error: ", error)
|
print("Torch Latency: ", 1000 * (torch_end - torch_start),
|
||||||
|
"ms, TensorRT Latency: ", 1000 * (trt_end - torch_end),
|
||||||
|
"ms, Transfer Error: ", error.item())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user