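# Load a saved batch of generated trajectories (collision and collision-free),
# rebuild the environment, robot and planning task from the generation arguments
# stored alongside them, and visualize the collision-free trajectories in joint
# space and in the workspace (static render and animation).
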
import os.path
import pickle

import matplotlib.pyplot as plt
import torch
import yaml

from mpd.utils.loading import load_params_from_yaml
from torch_robotics import environments, robots
from torch_robotics.tasks.tasks import PlanningTask
from torch_robotics.torch_utils.torch_utils import DEFAULT_TENSOR_ARGS, to_torch
from torch_robotics.visualizers.planning_visualizer import PlanningVisualizer

DATA_DIR = '../../data_trajectories/EnvSpheres3D-RobotPanda-cluster/66'

tensor_args = DEFAULT_TENSOR_ARGS

# Arguments used to generate this trajectory batch (env id, robot id, margins, duration, ...)
args = load_params_from_yaml(os.path.join(DATA_DIR, 'args.yaml'))

# Print the batch metadata for reference
metadata = load_params_from_yaml(os.path.join(DATA_DIR, 'metadata.yaml'))
print("\n-------------- METADATA --------------")
print(yaml.dump(metadata))
print("\n--------------------------------------")
print()

# -------------------------------- Load env, robot, task ---------------------------------
# Environment (class resolved by name from the saved generation arguments)
env_class = getattr(environments, args['env_id'])
env = env_class(tensor_args=tensor_args)

# Robot
robot_class = getattr(robots, args['robot_id'])
robot = robot_class(
    tensor_args=tensor_args
)

# Task
task = PlanningTask(
    env=env,
    robot=robot,
    obstacle_cutoff_margin=args['obstacle_cutoff_margin'],
    tensor_args=tensor_args
)

# -------------------------------- Load trajectories -------------------------
# Each file holds a batch of trajectories, assumed here to be shaped
# (n_trajectories, horizon, state_dim).
# If the tensors were saved on GPU and are loaded on a CPU-only machine,
# pass map_location to torch.load.
trajs_collision = torch.load(os.path.join(DATA_DIR, 'trajs-collision.pt')).to(**tensor_args)
trajs_free = torch.load(os.path.join(DATA_DIR, 'trajs-free.pt')).to(**tensor_args)

# Visualize only the collision-free trajectories
# trajs = torch.cat((trajs_collision, trajs_free))
trajs = trajs_free

# -------------------------------- Visualize ---------------------------------
|
||
|
planner_visualizer = PlanningVisualizer(task=task)
|
||
|
|
||
|
pos_trajs = robot.get_position(trajs)
|
||
|
start_state_pos = pos_trajs[0][0]
|
||
|
goal_state_pos = pos_trajs[0][-1]
|
||
|
|
||
|
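# Joint-space view: plot each trajectory's joint positions (and velocities) over time,
# with zero velocity assumed at the start and goal states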
planner_visualizer.plot_joint_space_state_trajectories(
|
||
|
trajs=trajs,
|
||
|
pos_start_state=start_state_pos, pos_goal_state=goal_state_pos,
|
||
|
vel_start_state=torch.zeros_like(start_state_pos), vel_goal_state=torch.zeros_like(goal_state_pos),
|
||
|
)
|
||
|
|
||
|
plt.show()
|
||
|
|
||
|
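# Workspace view: static render of the robot executing the trajectories in the environment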
planner_visualizer.render_robot_trajectories(
|
||
|
trajs=trajs, start_state=start_state_pos, goal_state=goal_state_pos,
|
||
|
render_planner=False,
|
||
|
)
|
||
|
|
||
|
plt.show()
|
||
|
|
||
|
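# Animate the trajectories and save the video next to the data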
planner_visualizer.animate_robot_trajectories(
|
||
|
trajs=trajs, start_state=start_state_pos, goal_state=goal_state_pos,
|
||
|
plot_trajs=True,
|
||
|
video_filepath=os.path.join(DATA_DIR, 'robot-traj.mp4'),
|
||
|
# n_frames=max((2, pos_trajs_iters[-1].shape[1]//10)),
|
||
|
n_frames=trajs.shape[1],
|
||
|
anim_time=args['duration']
|
||
|
)
|
||
|
|
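
# Note: DATA_DIR is a relative path; run the script from its original location (or point
# DATA_DIR at another generated trajectory folder) so args.yaml, metadata.yaml and the
# trajs-*.pt files can be found.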