mpd-public/scripts/generate_data/visualize_trajectories.py

85 lines
2.5 KiB
Python
Raw Normal View History

2023-10-23 15:45:14 +02:00
import pickle
import matplotlib.pyplot as plt
import os.path
import torch
import yaml
from mpd.utils.loading import load_params_from_yaml
from torch_robotics import environments, robots
from torch_robotics.tasks.tasks import PlanningTask
from torch_robotics.torch_utils.torch_utils import DEFAULT_TENSOR_ARGS, to_torch
from torch_robotics.visualizers.planning_visualizer import PlanningVisualizer
# Directory with the generated data (args.yaml, metadata.yaml and the trajs-*.pt
# files read below). The relative path is resolved against this script's own
# directory instead of the current working directory, so the script works no
# matter where it is launched from.
DATA_DIR = os.path.normpath(os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    '../../data_trajectories/EnvSpheres3D-RobotPanda-cluster/66',
))
tensor_args = DEFAULT_TENSOR_ARGS

# Generation arguments (this script reads env_id, robot_id, obstacle_cutoff_margin
# and duration from them) and dataset metadata (only printed).
args = load_params_from_yaml(os.path.join(DATA_DIR, 'args.yaml'))
metadata = load_params_from_yaml(os.path.join(DATA_DIR, 'metadata.yaml'))

print("\n-------------- METADATA --------------")
print(yaml.dump(metadata))
print("\n--------------------------------------")
print()
# -------------------------------- Load env, robot, task ---------------------------------
# Environment: class looked up by name (args['env_id']) in torch_robotics.environments.
env_class = getattr(environments, args['env_id'])
env = env_class(tensor_args=tensor_args)

# Robot: class looked up by name (args['robot_id']) in torch_robotics.robots.
robot_class = getattr(robots, args['robot_id'])
robot = robot_class(tensor_args=tensor_args)

# Task: bundles env and robot with the obstacle cutoff margin used at generation
# time; the PlanningVisualizer further down is constructed from it.
task_kwargs = dict(
    env=env,
    robot=robot,
    obstacle_cutoff_margin=args['obstacle_cutoff_margin'],
    tensor_args=tensor_args,
)
task = PlanningTask(**task_kwargs)
# -------------------------------- Load trajectories -------------------------
# map_location makes tensors that were saved on a GPU loadable on a machine
# without one; without it torch.load tries to restore them onto their original
# device and fails before .to(**tensor_args) can move them.
_map_location = tensor_args.get('device')
trajs_collision = torch.load(
    os.path.join(DATA_DIR, 'trajs-collision.pt'), map_location=_map_location
).to(**tensor_args)
trajs_free = torch.load(
    os.path.join(DATA_DIR, 'trajs-free.pt'), map_location=_map_location
).to(**tensor_args)

# Only the collision-free trajectories are visualized. To include the colliding
# ones as well, use: trajs = torch.cat((trajs_collision, trajs_free))
trajs = trajs_free
# -------------------------------- Visualize ---------------------------------
# Start/goal are taken from the first trajectory below; fail with a clear
# message instead of an opaque IndexError when the dataset is empty.
if trajs.shape[0] == 0:
    raise ValueError(f"No trajectories to visualize in {DATA_DIR}")

planner_visualizer = PlanningVisualizer(task=task)

# NOTE(review): presumably trajs is (n_trajs, n_waypoints, state_dim) -- confirm.
# Start/goal positions are the first and last entries of trajectory 0.
pos_trajs = robot.get_position(trajs)
start_state_pos = pos_trajs[0][0]
goal_state_pos = pos_trajs[0][-1]

# 1) Joint-space state trajectories; start/goal velocities are passed as zeros.
planner_visualizer.plot_joint_space_state_trajectories(
    trajs=trajs,
    pos_start_state=start_state_pos, pos_goal_state=goal_state_pos,
    vel_start_state=torch.zeros_like(start_state_pos), vel_goal_state=torch.zeros_like(goal_state_pos),
)
plt.show()

# 2) Static rendering of the robot along the trajectories.
planner_visualizer.render_robot_trajectories(
    trajs=trajs, start_state=start_state_pos, goal_state=goal_state_pos,
    render_planner=False,
)
plt.show()

# 3) Animation written to DATA_DIR/robot-traj.mp4, with trajs.shape[1] frames
#    and anim_time taken from args['duration'].
planner_visualizer.animate_robot_trajectories(
    trajs=trajs, start_state=start_state_pos, goal_state=goal_state_pos,
    plot_trajs=True,
    video_filepath=os.path.join(DATA_DIR, 'robot-traj.mp4'),
    n_frames=trajs.shape[1],
    anim_time=args['duration'],
)