252 lines
8.6 KiB
Python
252 lines
8.6 KiB
Python
|
import os
|
||
|
import pickle
|
||
|
import time
|
||
|
|
||
|
import torch
|
||
|
import yaml
|
||
|
from matplotlib import pyplot as plt
|
||
|
|
||
|
from experiment_launcher import single_experiment_yaml, run_experiment
|
||
|
from experiment_launcher.utils import fix_random_seed
|
||
|
from mp_baselines.planners.gpmp2 import GPMP2
|
||
|
from mp_baselines.planners.hybrid_planner import HybridPlanner
|
||
|
from mp_baselines.planners.multi_sample_based_planner import MultiSampleBasedPlanner
|
||
|
from mp_baselines.planners.rrt_connect import RRTConnect
|
||
|
from torch_robotics import environments, robots
|
||
|
from torch_robotics.tasks.tasks import PlanningTask
|
||
|
from torch_robotics.visualizers.planning_visualizer import PlanningVisualizer
|
||
|
|
||
|
|
||
|
def _sample_start_goal_states(task, threshold_start_goal_pos, n_tries):
    """Sample a collision-free start/goal pair that is far enough apart.

    Each try draws two collision-free configurations with
    ``task.random_coll_free_q(n_samples=2)`` and accepts them if their
    Euclidean distance is strictly greater than ``threshold_start_goal_pos``.

    Returns:
        (start_state_pos, goal_state_pos) of the first accepted pair.

    Raises:
        ValueError: if no pair is accepted within ``n_tries`` tries
            (this includes ``n_tries <= 0``).
    """
    start_state_pos, goal_state_pos = None, None
    for _ in range(n_tries):
        q_free = task.random_coll_free_q(n_samples=2)
        start_state_pos, goal_state_pos = q_free[0], q_free[1]
        if torch.linalg.norm(start_state_pos - goal_state_pos) > threshold_start_goal_pos:
            return start_state_pos, goal_state_pos

    # Do not fall back to the last (too-close) sample: the distance threshold
    # is the whole point of this step.
    raise ValueError(f"No collision free start/goal pair with distance > {threshold_start_goal_pos} "
                     f"was found in {n_tries} tries\n"
                     f"start_state_pos: {start_state_pos}\n"
                     f"goal_state_pos: {goal_state_pos}\n")


def _plot_free_trajectories(task, robot, trajs_free, results_dir):
    """Plot joint-space trajectories and save them to ``results_dir/trajectories.png``.

    The start/goal markers are the first and last waypoint positions of the
    first trajectory; start/goal velocities are drawn as zeros.
    Requires ``trajs_free`` to contain at least one trajectory.
    """
    planner_visualizer = PlanningVisualizer(task=task)

    pos_trajs = robot.get_position(trajs_free)
    pos_start = pos_trajs[0][0]
    pos_goal = pos_trajs[0][-1]

    fig, _ = planner_visualizer.plot_joint_space_state_trajectories(
        trajs=trajs_free,
        pos_start_state=pos_start, pos_goal_state=pos_goal,
        vel_start_state=torch.zeros_like(pos_start), vel_goal_state=torch.zeros_like(pos_goal),
    )
    try:
        fig.savefig(os.path.join(results_dir, 'trajectories.png'), dpi=300)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def generate_collision_free_trajectories(
        env_id,
        robot_id,
        num_trajectories_per_context,
        results_dir,
        threshold_start_goal_pos=1.0,
        obstacle_cutoff_margin=0.03,
        n_tries=1000,
        rrt_max_time=300,
        gpmp_opt_iters=500,
        n_support_points=64,
        duration=5.0,
        tensor_args=None,
        debug=False,
):
    """Plan trajectories for one random start/goal context and save them to disk.

    Builds the environment/robot (looked up by name in ``environments`` /
    ``robots``) and a ``PlanningTask``. Samples a collision-free start/goal pair
    farther apart than ``threshold_start_goal_pos``. Runs a ``HybridPlanner``
    (multi-trajectory RRTConnect followed by GPMP2) and writes to ``results_dir``:
    ``trajs-collision.pt``, ``trajs-free.pt``, ``results_data_dict.pickle`` and,
    if any trajectory is collision-free, ``trajectories.png``.

    Args:
        env_id: name of an environment class in ``torch_robotics.environments``.
        robot_id: name of a robot class in ``torch_robotics.robots``.
        num_trajectories_per_context: number of trajectories to plan.
        results_dir: directory the outputs are written to (must exist).
        threshold_start_goal_pos: minimum (exclusive) start/goal distance.
        obstacle_cutoff_margin: forwarded to ``PlanningTask``.
        n_tries: maximum number of start/goal sampling attempts.
        rrt_max_time: overrides the env's RRTConnect ``max_time``.
        gpmp_opt_iters: overrides the env's GPMP2 ``opt_iters``.
        n_support_points: number of trajectory support points.
        duration: trajectory duration; GPMP2 ``dt`` is ``duration / n_support_points``.
        tensor_args: forwarded to env, robot, task and planners.
        debug: forwarded to ``HybridPlanner.optimize``.

    Returns:
        (num_trajectories_coll, num_trajectories_free): counts of colliding and
        collision-free trajectories from the last planner iteration.

    Raises:
        ValueError: if no suitable start/goal pair is found in ``n_tries`` tries.
    """
    # -------------------------------- Load env, robot, task ---------------------------------
    env = getattr(environments, env_id)(tensor_args=tensor_args)
    robot = getattr(robots, robot_id)(tensor_args=tensor_args)
    task = PlanningTask(
        env=env,
        robot=robot,
        obstacle_cutoff_margin=obstacle_cutoff_margin,
        tensor_args=tensor_args,
    )

    # -------------------------------- Start, Goal states ---------------------------------
    start_state_pos, goal_state_pos = _sample_start_goal_states(task, threshold_start_goal_pos, n_tries)

    n_trajectories = num_trajectories_per_context

    # -------------------------------- Hybrid Planner ---------------------------------
    # Sample-based planner. Copy the env defaults so the overrides below do not
    # mutate a dict the environment may hand out again.
    rrt_connect_default_params_env = dict(env.get_rrt_connect_params(robot=robot))
    rrt_connect_default_params_env['max_time'] = rrt_max_time

    rrt_connect_params = dict(
        **rrt_connect_default_params_env,
        task=task,
        start_state_pos=start_state_pos,
        goal_state_pos=goal_state_pos,
        tensor_args=tensor_args,
    )
    sample_based_planner_base = RRTConnect(**rrt_connect_params)
    sample_based_planner = MultiSampleBasedPlanner(
        sample_based_planner_base,
        n_trajectories=n_trajectories,
        max_processes=-1,
        optimize_sequentially=True,
    )

    # Optimization-based planner
    gpmp_default_params_env = dict(env.get_gpmp2_params(robot=robot))
    gpmp_default_params_env['opt_iters'] = gpmp_opt_iters
    gpmp_default_params_env['n_support_points'] = n_support_points
    gpmp_default_params_env['dt'] = duration / n_support_points

    planner_params = dict(
        **gpmp_default_params_env,
        robot=robot,
        n_dof=robot.q_dim,
        num_particles_per_goal=n_trajectories,
        start_state=start_state_pos,
        multi_goal_states=goal_state_pos.unsqueeze(0),  # add batch dim for interface
        collision_fields=task.get_collision_fields(),
        tensor_args=tensor_args,
    )
    opt_based_planner = GPMP2(**planner_params)

    # Hybrid planner
    planner = HybridPlanner(
        sample_based_planner,
        opt_based_planner,
        tensor_args=tensor_args,
    )

    # Optimize
    trajs_iters = planner.optimize(debug=debug, print_times=True, return_iterations=True)
    trajs_last_iter = trajs_iters[-1]

    # -------------------------------- Save trajectories ---------------------------------
    print(f'----------------STATISTICS----------------')
    print(f'percentage free trajs: {task.compute_fraction_free_trajs(trajs_last_iter)*100:.2f}')
    print(f'percentage collision intensity {task.compute_collision_intensity_trajs(trajs_last_iter)*100:.2f}')
    print(f'success {task.compute_success_free_trajs(trajs_last_iter)}')

    torch.cuda.empty_cache()
    # Split once; the same split feeds the .pt files and the results dict.
    trajs_coll, trajs_free = task.get_trajs_collision_and_free(trajs_last_iter)

    # Empty placeholders so both .pt files always exist and len() works below.
    trajs_last_iter_coll = trajs_coll if trajs_coll is not None else torch.empty(0)
    trajs_last_iter_free = trajs_free if trajs_free is not None else torch.empty(0)
    torch.save(trajs_last_iter_coll, os.path.join(results_dir, 'trajs-collision.pt'))
    torch.save(trajs_last_iter_free, os.path.join(results_dir, 'trajs-free.pt'))

    # The results dict keeps None (not an empty tensor) when a split is missing.
    results_data_dict = {
        'duration': duration,
        'n_support_points': n_support_points,
        'dt': planner_params['dt'],
        'trajs_iters_coll': trajs_coll.unsqueeze(0) if trajs_coll is not None else None,
        'trajs_iters_free': trajs_free.unsqueeze(0) if trajs_free is not None else None,
    }
    with open(os.path.join(results_dir, 'results_data_dict.pickle'), 'wb') as handle:
        pickle.dump(results_data_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # -------------------------------- Visualize ---------------------------------
    # The plot uses the first free trajectory for its start/goal markers, so it
    # can only be produced when at least one collision-free trajectory exists.
    if len(trajs_last_iter_free) > 0:
        _plot_free_trajectories(task, robot, trajs_last_iter_free, results_dir)
    else:
        print('No collision-free trajectories found: skipping trajectories.png')

    return len(trajs_last_iter_coll), len(trajs_last_iter_free)
|
||
|
|
||
|
|
||
|
@single_experiment_yaml
def experiment(
    # other options: 'EnvDense2D', 'EnvSimple2D', 'EnvNarrowPassageDense2D'
    env_id: str = 'EnvSpheres3D',

    # other option: 'RobotPointMass'
    robot_id: str = 'RobotPanda',

    n_support_points: int = 64,
    duration: float = 5.0,  # seconds

    threshold_start_goal_pos: float = 1.83,  # previously 1.0

    obstacle_cutoff_margin: float = 0.05,

    num_trajectories: int = 5,

    device: str = 'cuda',  # or 'cpu'

    debug: bool = True,

    #######################################
    # MANDATORY
    seed: int = int(time.time()),
    results_dir: str = "data",

    #######################################
    **kwargs
):
    """Generate one batch of trajectories and record run metadata.

    Prints a summary banner, writes ``metadata.yaml`` (env, robot, requested
    trajectory count) into ``results_dir``, and calls
    ``generate_collision_free_trajectories``. It then rewrites ``metadata.yaml``
    with the number of colliding / collision-free trajectories produced.
    """
    # NOTE(review): the seed is only applied when debug is True, although it
    # is always printed below — confirm this is intended.
    if debug:
        fix_random_seed(seed)

    banner_lines = (
        f'\n\n-------------------- Generating data --------------------',
        f'Seed: {seed}',
        f'Env: {env_id}',
        f'Robot: {robot_id}',
        f'num_trajectories: {num_trajectories}',
    )
    for banner_line in banner_lines:
        print(banner_line)

    ####################################################################################################################
    tensor_args = dict(device=device, dtype=torch.float32)

    # NOTE(review): assumes results_dir already exists — confirm it is created
    # upstream (presumably by the experiment launcher).
    metadata_path = os.path.join(results_dir, 'metadata.yaml')

    def _write_metadata(data):
        # Overwrite the metadata file with the current contents of `data`.
        with open(metadata_path, 'w') as stream:
            yaml.dump(data, stream, Dumper=yaml.Dumper)

    metadata = dict(
        env_id=env_id,
        robot_id=robot_id,
        num_trajectories=num_trajectories,
    )
    # Written before generation so a partial record exists even if generation fails.
    _write_metadata(metadata)

    n_coll, n_free = generate_collision_free_trajectories(
        env_id,
        robot_id,
        num_trajectories,
        results_dir,
        threshold_start_goal_pos=threshold_start_goal_pos,
        obstacle_cutoff_margin=obstacle_cutoff_margin,
        n_support_points=n_support_points,
        duration=duration,
        tensor_args=tensor_args,
        debug=debug,
    )

    metadata['num_trajectories_generated'] = n_coll + n_free
    metadata['num_trajectories_generated_coll'] = n_coll
    metadata['num_trajectories_generated_free'] = n_free
    _write_metadata(metadata)
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Script entry point: hand the decorated `experiment` function to
    # experiment_launcher's `run_experiment`.
    run_experiment(experiment)
|