238 lines
9.0 KiB
Python
238 lines
9.0 KiB
Python
import abc
|
|
import os.path
|
|
|
|
import git
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
|
|
from mpd.datasets.normalization import DatasetNormalizer
|
|
from mpd.utils.loading import load_params_from_yaml
|
|
from torch_robotics import environments, robots
|
|
from torch_robotics.environments import EnvDense2DExtraObjects
|
|
from torch_robotics.environments.env_simple_2d_extra_objects import EnvSimple2DExtraObjects
|
|
from torch_robotics.tasks.tasks import PlanningTask
|
|
from torch_robotics.visualizers.planning_visualizer import PlanningVisualizer
|
|
|
|
# Git checkout that contains the current working directory (parent directories are
# searched too). This is evaluated once, when the module is imported.
repo = git.Repo(path='.', search_parent_directories=True)

# Every trajectory dataset lives under <checkout root>/data_trajectories.
dataset_base_dir = os.path.join(repo.working_dir, 'data_trajectories')
|
|
|
|
|
|
class TrajectoryDatasetBase(Dataset, abc.ABC):
    """Base dataset of planning trajectories stored under ``dataset_base_dir``.

    On construction it:

    * reads ``args.yaml`` and ``metadata.yaml`` from the ``0`` sub-directory of the dataset;
    * builds the environment, robot and ``PlanningTask`` named in the metadata;
    * loads every ``trajs-free.pt`` file found below the dataset directory;
    * normalizes the trajectory and task fields with ``DatasetNormalizer``.

    Subclasses must implement :meth:`get_hard_conditions`, which :meth:`__getitem__`
    calls for every sample.
    """

    def __init__(self,
                 dataset_subdir=None,
                 include_velocity=False,
                 normalizer='LimitsNormalizer',
                 use_extra_objects=False,
                 obstacle_cutoff_margin=None,
                 tensor_args=None,
                 **kwargs):
        """Load configuration, environment, robot, task and trajectories.

        Args:
            dataset_subdir: name of the dataset directory inside ``dataset_base_dir``.
            include_velocity: if True, the trajectory field keeps the full loaded state;
                otherwise only ``robot.get_position`` of it is kept.
            normalizer: normalizer name forwarded to ``DatasetNormalizer``.
            use_extra_objects: if True, the environment class ``<env_id>ExtraObjects`` is used
                instead of ``<env_id>``.
            obstacle_cutoff_margin: if not None, overrides ``args['obstacle_cutoff_margin']``.
            tensor_args: forwarded to the environment, robot and task; its ``'device'`` entry is
                used as ``map_location`` when loading trajectories.
            **kwargs: accepted and ignored.
        """
        self.tensor_args = tensor_args

        self.dataset_subdir = dataset_subdir
        self.base_dir = os.path.join(dataset_base_dir, self.dataset_subdir)

        # The run configuration is read from the '0' sub-directory of the dataset.
        self.args = load_params_from_yaml(os.path.join(self.base_dir, '0', 'args.yaml'))
        self.metadata = load_params_from_yaml(os.path.join(self.base_dir, '0', 'metadata.yaml'))

        if obstacle_cutoff_margin is not None:
            self.args['obstacle_cutoff_margin'] = obstacle_cutoff_margin

        # -------------------------------- Load env, robot, task ---------------------------------
        # Environment
        env_id = self.metadata['env_id']
        env_class = getattr(environments, (env_id + 'ExtraObjects') if use_extra_objects else env_id)
        self.env = env_class(tensor_args=tensor_args)

        # Robot
        robot_class = getattr(robots, self.metadata['robot_id'])
        self.robot = robot_class(tensor_args=tensor_args)

        # Task
        self.task = PlanningTask(env=self.env, robot=self.robot, tensor_args=tensor_args, **self.args)
        self.planner_visualizer = PlanningVisualizer(task=self.task)

        # -------------------------------- Load trajectories ---------------------------------
        self.threshold_start_goal_pos = self.args['threshold_start_goal_pos']

        self.field_key_traj = 'traj'
        self.field_key_task = 'task'
        self.fields = {}

        # load data
        self.include_velocity = include_velocity
        self.map_task_id_to_trajectories_id = {}
        self.map_trajectory_id_to_task_id = {}
        self.load_trajectories()

        # dimensions
        b, h, d = self.dataset_shape = self.fields[self.field_key_traj].shape
        self.n_trajs = b
        self.n_support_points = h
        self.state_dim = d  # state dimension used for the diffusion model
        self.trajectory_dim = (self.n_support_points, d)

        # normalize the data (for the diffusion model)
        self.normalizer = DatasetNormalizer(self.fields, normalizer=normalizer)
        self.normalizer_keys = [self.field_key_traj, self.field_key_task]
        self.normalize_all_data(*self.normalizer_keys)

    @staticmethod
    def _subdir_sort_key(name):
        """Sort key that puts purely numeric directory names first, in numeric order."""
        # isdecimal (not isdigit) guarantees int() accepts the string; the name itself
        # breaks ties such as '01' vs '1' so the order is fully deterministic.
        return (0, int(name), name) if name.isdecimal() else (1, 0, name)

    def load_trajectories(self):
        """Load every ``trajs-free.pt`` below ``self.base_dir`` into ``self.fields``.

        Each file found is one task. Its trajectories get consecutive global indices,
        which are recorded in ``map_task_id_to_trajectories_id`` (task id -> index array) and
        ``map_trajectory_id_to_task_id`` (index -> task id). Directories are visited in a
        deterministic order (numeric names first, in numeric order), so the ids do not
        depend on the order in which the filesystem lists entries.

        Sets ``self.fields[self.field_key_traj]`` (the loaded states if ``include_velocity``,
        else their positions) and ``self.fields[self.field_key_task]`` (the first and last
        positions of each trajectory, concatenated along the last axis).

        Raises:
            FileNotFoundError: if no ``trajs-free.pt`` exists below ``self.base_dir``.
        """
        trajs_free_l = []
        n_trajs = 0
        for current_dir, subdirs, files in os.walk(self.base_dir, topdown=True):
            # os.walk yields entries in arbitrary filesystem order. Sorting ``subdirs`` in place
            # (honoured because topdown=True) fixes the traversal order, and with it the task ids.
            subdirs.sort(key=self._subdir_sort_key)
            if 'trajs-free.pt' not in files:
                continue
            # NOTE: torch.load unpickles the file; only load datasets from trusted sources.
            trajs_free_tmp = torch.load(
                os.path.join(current_dir, 'trajs-free.pt'), map_location=self.tensor_args['device'])
            task_id = len(trajs_free_l)
            trajectories_idx = n_trajs + np.arange(len(trajs_free_tmp))
            self.map_task_id_to_trajectories_id[task_id] = trajectories_idx
            for j in trajectories_idx:
                self.map_trajectory_id_to_task_id[j] = task_id
            n_trajs += len(trajs_free_tmp)
            trajs_free_l.append(trajs_free_tmp)

        if not trajs_free_l:
            # Without this check torch.cat([]) fails with an unhelpful message.
            raise FileNotFoundError(f"No 'trajs-free.pt' files found under {self.base_dir}")

        trajs_free = torch.cat(trajs_free_l)
        trajs_free_pos = self.robot.get_position(trajs_free)

        self.fields[self.field_key_traj] = trajs_free if self.include_velocity else trajs_free_pos

        # task: start and goal positions -> [n_trajectories, 2 * (last dim of the positions)]
        self.fields[self.field_key_task] = torch.cat(
            (trajs_free_pos[..., 0, :], trajs_free_pos[..., -1, :]), dim=-1)

    def normalize_all_data(self, *keys):
        """Store ``self.normalizer(self.fields[key], key)`` as ``fields['<key>_normalized']``."""
        for key in keys:
            self.fields[f'{key}_normalized'] = self.normalizer(self.fields[key], key)

    def render(self, task_id=3,
               render_joint_trajectories=False,
               render_robot_trajectories=False,
               **kwargs):
        """Plot the trajectories that belong to one task.

        The start and goal shown are the first and last positions of the task's first
        trajectory.

        Args:
            task_id: key into ``map_task_id_to_trajectories_id``.
            render_joint_trajectories: plot the joint-space state trajectories.
            render_robot_trajectories: render the robot along the trajectories.
            **kwargs: accepted and ignored.

        Returns:
            ``(fig1, axs1, fig2, axs2)``; the entries for a plot that was not requested are None.
        """
        # -------------------------------- Visualize ---------------------------------
        idxs = self.map_task_id_to_trajectories_id[task_id]
        pos_trajs = self.robot.get_position(self.fields[self.field_key_traj][idxs])
        start_state_pos = pos_trajs[0][0]
        goal_state_pos = pos_trajs[0][-1]

        fig1, axs1, fig2, axs2 = [None] * 4

        if render_joint_trajectories:
            fig1, axs1 = self.planner_visualizer.plot_joint_space_state_trajectories(
                trajs=pos_trajs,
                pos_start_state=start_state_pos, pos_goal_state=goal_state_pos,
                vel_start_state=torch.zeros_like(start_state_pos), vel_goal_state=torch.zeros_like(goal_state_pos),
            )

        if render_robot_trajectories:
            fig2, axs2 = self.planner_visualizer.render_robot_trajectories(
                trajs=pos_trajs, start_state=start_state_pos, goal_state=goal_state_pos,
            )

        return fig1, axs1, fig2, axs2

    def __repr__(self):
        msg = f'TrajectoryDataset\n' \
              f'n_trajs: {self.n_trajs}\n' \
              f'trajectory_dim: {self.trajectory_dim}\n'
        return msg

    def __len__(self):
        return self.n_trajs

    def __getitem__(self, index):
        """Return one normalized trajectory, its normalized task and its hard conditions.

        The keys are ``'<traj key>_normalized'``, ``'<task key>_normalized'`` and ``'hard_conds'``.
        The hard conditions are computed from the normalized trajectory.
        """
        field_traj_normalized = f'{self.field_key_traj}_normalized'
        field_task_normalized = f'{self.field_key_task}_normalized'
        traj_normalized = self.fields[field_traj_normalized][index]
        task_normalized = self.fields[field_task_normalized][index]
        data = {
            field_traj_normalized: traj_normalized,
            field_task_normalized: task_normalized
        }

        # build hard conditions
        hard_conds = self.get_hard_conditions(traj_normalized, horizon=len(traj_normalized))
        data.update({'hard_conds': hard_conds})

        return data

    def get_hard_conditions(self, traj, horizon=None, normalize=False):
        """Return ``{support point index: state}`` conditions for ``traj``; subclasses implement it."""
        raise NotImplementedError

    def get_unnormalized(self, index):
        """Not implemented; subclasses that need raw (unnormalized) samples must override it."""
        raise NotImplementedError

    def unnormalize(self, x, key):
        return self.normalizer.unnormalize(x, key)

    def normalize(self, x, key):
        return self.normalizer.normalize(x, key)

    def unnormalize_trajectories(self, x):
        return self.unnormalize(x, self.field_key_traj)

    def normalize_trajectories(self, x):
        return self.normalize(x, self.field_key_traj)

    def unnormalize_tasks(self, x):
        return self.unnormalize(x, self.field_key_task)

    def normalize_tasks(self, x):
        return self.normalize(x, self.field_key_task)
|
|
|
|
|
|
class TrajectoryDataset(TrajectoryDatasetBase):
    """Trajectory dataset whose hard conditions pin the first and last states of a trajectory."""

    def __init__(self, **kwargs):
        # No extra setup: every keyword goes straight to TrajectoryDatasetBase.
        super().__init__(**kwargs)

    def get_hard_conditions(self, traj, horizon=None, normalize=False):
        """Return the states the first and last support points are conditioned on.

        Args:
            traj: trajectory indexed along its first axis; ``traj[0]`` and ``traj[-1]`` are
                reduced to positions with ``self.robot.get_position``.
            horizon: number of support points; the goal is keyed at ``horizon - 1``.
                Defaults to ``self.n_support_points``.
            normalize: if True, both states are passed through ``self.normalizer.normalize``
                with the trajectory key.

        Returns:
            dict mapping ``0`` to the start state and ``horizon - 1`` to the goal state.
            With ``include_velocity`` each state is the position followed by a zero tensor of
            the same shape (zero velocity); otherwise it is the position alone.
        """
        first_pos, last_pos = (self.robot.get_position(traj[i]) for i in (0, -1))

        if self.include_velocity:
            # The velocity half of the state is pinned to zero at the beginning and end.
            # NOTE(review): __getitem__ passes an already-normalized trajectory with
            # normalize=False, so these zeros are in normalized space. Confirm that the
            # normalizer maps zero velocity to zero.
            first_state, last_state = (
                torch.cat((pos, torch.zeros_like(pos)), dim=-1) for pos in (first_pos, last_pos)
            )
        else:
            first_state, last_state = first_pos, last_pos

        if normalize:
            first_state, last_state = (
                self.normalizer.normalize(state, key=self.field_key_traj)
                for state in (first_state, last_state)
            )

        last_index = (self.n_support_points if horizon is None else horizon) - 1
        return {0: first_state, last_index: last_state}
|