# DPtraj/deepPathPlan/PathNet/visualizer_tojit.py
# 2025-08-07 11:13:12 +08:00
#
# 136 lines
# 4.7 KiB
# Python
# Executable File

import matplotlib.pyplot as plt
import numpy as np
import argparse
import torch
from resshep import reedshep_process
import numpy as np
from network import trajFCNet
from data_loader import geom2pix,get_encoder_input,geom2pix
import cv2
import os
# Layout of a path1.dat file as read below: a 10-value header reshaped to
# (2, 5), followed by 200 samples of 6 values each.
_HEADER_LEN = 10
_NUM_SAMPLES = 200
_SAMPLE_FIELDS = 6
# Obstacle files obcs/obc<N>.dat are reshaped to (15000, 2) points.
_NUM_OBSTACLE_POINTS = 15000
# Anchor grid: 20x20 cells of size 1.0 after shifting coordinates by +10.0,
# i.e. it covers [-10, 10) on both axes.
_GRID_SIZE = 20
_GRID_OFFSET = 10.0
_GRID_CELL = 1.0
# Inclusive range of episode indices visited; missing episodes are skipped.
_FIRST_EPISODE = 29000
_LAST_EPISODE = 30000


def _load_model(device, dtype):
    """Build trajFCNet, load ./models/model.pkl and prepare it for inference.

    The checkpoint is mapped onto *device* so a GPU-saved state dict can also
    be loaded on a CPU-only machine. Floating-point parameters are cast to
    *dtype* (the original called ``.half()``).
    """
    net = trajFCNet(4, 200, 7, l=1.2, use_groundTruth=False)
    state = torch.load("./models/model.pkl", map_location=device)
    net.load_state_dict(state)
    net.to(device=device, dtype=dtype)
    net.eval()
    return net


def _read_esdf(xml_file):
    """Return the "instance" matrix stored in an OpenCV FileStorage XML file."""
    storage = cv2.FileStorage(xml_file, cv2.FILE_STORAGE_READ)
    try:
        image = storage.getNode("instance").mat()
    finally:
        # The original never released the handle; close it every episode.
        storage.release()
    if image is None:
        raise ValueError("could not read 'instance' matrix from " + xml_file)
    return image


def _build_network_inputs(header, label, image):
    """Assemble the four numpy inputs fed to trajFCNet for one episode.

    Args:
        header: (2, 5) array. Row 0 describes the start and row 1 the goal.
            Columns 0:2 are a position converted with geom2pix. Column 2 is
            passed next to it to get_encoder_input (presumably a heading
            angle -- confirm against data_loader).
        label: (200, 6) ground-truth samples. Columns 1:3 are x, y and
            column 3 is the angle theta.
        image: 2-D ESDF map read from esdfmaps/<N>.xml.

    Returns:
        A tuple with four entries:
        - the encoder input;
        - label positions, shape (200, 2);
        - label headings as (cos, sin), shape (200, 2);
        - one-hot anchor maps, shape (200, 20, 20).
    """
    env = np.expand_dims(image, 0)
    goalpos = geom2pix(header[1][0:2])
    startpos = geom2pix(header[0][0:2])
    data = get_encoder_input(env, goalpos, header[1][2], startpos, header[0][2])

    label_opState = label[:, 1:3]
    theta = label[:, 3]
    label_Rot = np.stack((np.cos(theta), np.sin(theta)), axis=1)

    # Clip to the grid. Previously a coordinate >= +10 raised IndexError, and
    # one below -10 produced a negative index that silently wrapped to the
    # opposite edge of the grid.
    label_grid = np.floor((label_opState + _GRID_OFFSET) / _GRID_CELL).astype(int)
    label_grid = np.clip(label_grid, 0, _GRID_SIZE - 1)
    labelanchors = np.zeros((_NUM_SAMPLES, _GRID_SIZE, _GRID_SIZE))
    labelanchors[np.arange(_NUM_SAMPLES), label_grid[:, 0], label_grid[:, 1]] = 1
    return data, label_opState, label_Rot, labelanchors


def _to_batch(array, device, dtype):
    """Convert a numpy array to a contiguous (1, ...) tensor on *device*.

    The array goes through float32 first, as the original did, so the final
    fp16 values are unchanged.
    """
    tensor = torch.as_tensor(array.copy()).float().unsqueeze(0)
    return tensor.to(device).to(dtype).contiguous()


def _plot_prediction(opState, opRot):
    """Overlay the post-processed planner output on the current figure, then show it."""
    for state, rot in zip(opState, opRot):
        plt.arrow(state[0], state[1], 0.05 * rot[0], 0.05 * rot[1], width=0.001, color='red')
    plt.plot(opState[:, 0], opState[:, 1], color='green', marker='o', linestyle=':',
             linewidth=2, markersize=1)
    plt.scatter(opState[:, 0], opState[:, 1], c='red', marker='o', s=5, label="planner")
    plt.xlim((-11, 11))
    plt.ylim((-11, 11))
    plt.show()
    plt.close()


def main(args):
    """Run trajFCNet on the ruins1 test episodes, plot results, export TorchScript.

    The data directory is <two levels above this file>/totalData/ruins1. Every
    episode index in [29000, 30000] that has an e<N>/path1.dat file there is
    processed as follows:
    - plot the obstacle points and the ground-truth path;
    - run the network;
    - post-process its output with reedshep_process;
    - overlay the result on the plot.

    The network is traced with the first episode's inputs and saved to
    ./models/model.pt (relative to the working directory).

    Args:
        args: parsed command-line namespace; currently unused.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # GPU keeps the original fp16 model. The original crashed on CPU at the
    # first tensor .cuda() call. Fall back to fp32 there, since fp16 CPU
    # kernels may be missing for some ops depending on the torch version.
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    mpnet = _load_model(device, dtype)

    current_dir = os.path.dirname(os.path.abspath(__file__))
    filepath = os.path.join(os.path.dirname(os.path.dirname(current_dir)), "totalData/ruins1")
    exported = False
    for eidx in range(_FIRST_EPISODE, _LAST_EPISODE + 1):
        path_file = os.path.join(filepath, "e" + str(eidx), "path1.dat")
        if not os.path.exists(path_file):
            continue
        obstacles = np.fromfile(os.path.join(filepath, "obcs", "obc" + str(eidx) + ".dat"))
        obstacles = obstacles.reshape(_NUM_OBSTACLE_POINTS, 2)
        # Read path1.dat once (the original read it twice).
        path = np.fromfile(path_file)
        header = path[:_HEADER_LEN].reshape(2, 5)
        label = path[_HEADER_LEN:].reshape(_NUM_SAMPLES, _SAMPLE_FIELDS)

        plt.figure(figsize=(19.2, 10.8))
        plt.scatter(obstacles[:, 0], obstacles[:, 1], c='black', marker='o', label="obstacles")
        plt.scatter(label[:, 1], label[:, 2], c='blue', marker='o', s=5, label="ground truth")

        image = _read_esdf(os.path.join(filepath, "esdfmaps", str(eidx) + ".xml"))
        inputs = tuple(
            _to_batch(array, device, dtype)
            for array in _build_network_inputs(header, label, image)
        )
        with torch.no_grad():
            opState, opRot, _, _ = mpnet(*inputs)
            if not exported:
                # The original re-traced and overwrote model.pt on every episode.
                # NOTE(review): tracing once assumes trajFCNet.forward has no
                # data-dependent control flow -- confirm.
                torch.jit.trace(mpnet, inputs).save("./models/model.pt")
                exported = True

        opState = opState.detach().cpu().numpy()[0]
        opRot = opRot.detach().cpu().numpy()[0]
        opState, opRot = reedshep_process(opState, opRot, image)
        _plot_prediction(opState, opRot)
if __name__ == '__main__':
    # Command-line entry point.
    # NOTE(review): --path_file is parsed and echoed, but main() never reads
    # it; the episodes processed are fixed inside main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--path_file',
        '-p',
        type=str,
        default='./data/freeEnv/path1002.dat',
    )
    args = parser.parse_args()
    print(args)
    main(args)