mpd-public/mpd/datasets/trajectories.py

238 lines
9.0 KiB
Python
Raw Normal View History

2023-10-23 15:45:14 +02:00
import abc
import os.path
import git
import numpy as np
import torch
from torch.utils.data import Dataset
from mpd.datasets.normalization import DatasetNormalizer
from mpd.utils.loading import load_params_from_yaml
from torch_robotics import environments, robots
from torch_robotics.environments import EnvDense2DExtraObjects
from torch_robotics.environments.env_simple_2d_extra_objects import EnvSimple2DExtraObjects
from torch_robotics.tasks.tasks import PlanningTask
from torch_robotics.visualizers.planning_visualizer import PlanningVisualizer
# Locate the enclosing git repository by searching upwards from the current
# working directory. NOTE: this runs at import time, so importing this module
# from outside a git checkout fails here.
repo = git.Repo(path='.', search_parent_directories=True)
# All trajectory datasets are looked up under <repository root>/data_trajectories.
dataset_base_dir = os.path.join(repo.working_dir, 'data_trajectories')
class TrajectoryDatasetBase(Dataset, abc.ABC):
    """Base dataset of pre-computed collision-free trajectories.

    On construction the dataset reads ``args.yaml`` and ``metadata.yaml`` from
    ``<dataset_base_dir>/<dataset_subdir>/0``, instantiates the environment,
    robot and planning task named there, loads every ``trajs-free.pt`` file
    found below the dataset directory, and stores the raw and the normalized
    trajectories and "task" vectors (start and goal positions) in
    ``self.fields``.

    Subclasses must implement :meth:`get_hard_conditions`.
    """

    def __init__(self,
                 dataset_subdir=None,
                 include_velocity=False,
                 normalizer='LimitsNormalizer',
                 use_extra_objects=False,
                 obstacle_cutoff_margin=None,
                 tensor_args=None,
                 **kwargs):
        """Load the dataset description, build env/robot/task and load the data.

        Args:
            dataset_subdir: name of the dataset directory below
                ``dataset_base_dir``. Required, despite the ``None`` default.
            include_velocity: if True the loaded states are used as they are;
                otherwise only ``robot.get_position(...)`` of them is kept.
            normalizer: forwarded to ``DatasetNormalizer`` as ``normalizer=``.
            use_extra_objects: if True, ``'ExtraObjects'`` is appended to the
                environment id from the metadata before looking up the class.
            obstacle_cutoff_margin: when not None, overrides the
                ``'obstacle_cutoff_margin'`` entry of the loaded args.
            tensor_args: forwarded unchanged to the environment, robot and
                task constructors; when not None its ``'device'`` entry is the
                ``map_location`` used to load the trajectory files.
                NOTE(review): whether those constructors accept None is not
                visible from this file - confirm before relying on it.
            **kwargs: accepted for call-site compatibility and ignored.

        Raises:
            TypeError: if ``dataset_subdir`` is None.
        """
        if dataset_subdir is None:
            # os.path.join() below would raise an opaque TypeError; fail early
            # with a message that names the actual problem.
            raise TypeError('dataset_subdir must be provided (got None)')

        self.tensor_args = tensor_args

        self.dataset_subdir = dataset_subdir
        self.base_dir = os.path.join(dataset_base_dir, self.dataset_subdir)

        # The generation arguments and metadata are read from sub-directory '0'.
        self.args = load_params_from_yaml(os.path.join(self.base_dir, '0', 'args.yaml'))
        self.metadata = load_params_from_yaml(os.path.join(self.base_dir, '0', 'metadata.yaml'))

        if obstacle_cutoff_margin is not None:
            self.args['obstacle_cutoff_margin'] = obstacle_cutoff_margin

        # -------------------------------- Load env, robot, task ---------------------------------
        # Environment
        env_id = self.metadata['env_id']
        if use_extra_objects:
            env_id = env_id + 'ExtraObjects'
        env_class = getattr(environments, env_id)
        self.env = env_class(tensor_args=tensor_args)

        # Robot
        robot_class = getattr(robots, self.metadata['robot_id'])
        self.robot = robot_class(tensor_args=tensor_args)

        # Task
        self.task = PlanningTask(env=self.env, robot=self.robot, tensor_args=tensor_args, **self.args)
        self.planner_visualizer = PlanningVisualizer(task=self.task)

        # -------------------------------- Load trajectories ---------------------------------
        self.threshold_start_goal_pos = self.args['threshold_start_goal_pos']

        self.field_key_traj = 'traj'
        self.field_key_task = 'task'
        self.fields = {}

        # load data
        self.include_velocity = include_velocity
        self.map_task_id_to_trajectories_id = {}
        self.map_trajectory_id_to_task_id = {}
        self.load_trajectories()

        # dimensions (the trajectory field must be a rank-3 tensor)
        b, h, d = self.dataset_shape = self.fields[self.field_key_traj].shape
        self.n_trajs = b
        self.n_support_points = h
        self.state_dim = d  # state dimension used for the diffusion model
        self.trajectory_dim = (self.n_support_points, d)

        # normalize the data (for the diffusion model)
        self.normalizer = DatasetNormalizer(self.fields, normalizer=normalizer)
        self.normalizer_keys = [self.field_key_traj, self.field_key_task]
        self.normalize_all_data(*self.normalizer_keys)

    @staticmethod
    def _subdir_sort_key(name):
        """Sort key ordering purely numeric directory names numerically first."""
        if name.isdecimal():
            return (0, int(name), name)
        return (1, 0, name)

    def load_trajectories(self):
        """Load every ``trajs-free.pt`` below ``self.base_dir`` into ``self.fields``.

        Each file found defines one task; task ids are assigned consecutively
        in the order the files are visited and both id maps
        (``map_task_id_to_trajectories_id`` / ``map_trajectory_id_to_task_id``)
        are filled accordingly. Directories are visited in sorted order
        (numeric names numerically) so the ids are reproducible across
        machines and file systems.

        Raises:
            RuntimeError: if no ``trajs-free.pt`` file is found.
        """
        # tensor_args may be None (the constructor default); in that case let
        # torch.load keep the tensors on the device they were saved from.
        map_location = None if self.tensor_args is None else self.tensor_args['device']

        # load free trajectories
        trajs_free_l = []
        task_id = 0
        n_trajs = 0
        for current_dir, subdirs, files in os.walk(self.base_dir, topdown=True):
            # os.walk yields directory entries in arbitrary order; with
            # topdown=True, sorting the list in place fixes the visiting order.
            subdirs.sort(key=self._subdir_sort_key)
            if 'trajs-free.pt' not in files:
                continue
            # NOTE(security): torch.load unpickles the file - only load
            # datasets that come from a trusted source.
            trajs_free_tmp = torch.load(
                os.path.join(current_dir, 'trajs-free.pt'), map_location=map_location)
            trajectories_idx = n_trajs + np.arange(len(trajs_free_tmp))
            self.map_task_id_to_trajectories_id[task_id] = trajectories_idx
            for j in trajectories_idx:
                self.map_trajectory_id_to_task_id[j] = task_id
            task_id += 1
            n_trajs += len(trajs_free_tmp)
            trajs_free_l.append(trajs_free_tmp)

        if not trajs_free_l:
            # torch.cat([]) would fail with a message that does not mention
            # the dataset directory at all.
            raise RuntimeError(f"No 'trajs-free.pt' file found under {self.base_dir}")

        trajs_free = torch.cat(trajs_free_l)
        trajs_free_pos = self.robot.get_position(trajs_free)

        if self.include_velocity:
            trajs = trajs_free
        else:
            trajs = trajs_free_pos
        self.fields[self.field_key_traj] = trajs

        # task: start and goal state positions [n_trajectories, 2 * state_dim]
        task = torch.cat((trajs_free_pos[..., 0, :], trajs_free_pos[..., -1, :]), dim=-1)
        self.fields[self.field_key_task] = task

    def normalize_all_data(self, *keys):
        """Store a normalized copy of each field in ``keys`` under ``'<key>_normalized'``."""
        for key in keys:
            self.fields[f'{key}_normalized'] = self.normalizer(self.fields[key], key)

    def render(self, task_id=3,
               render_joint_trajectories=False,
               render_robot_trajectories=False,
               **kwargs):
        """Plot the trajectories that belong to one task.

        The start and goal positions are taken from the first and last point
        of the first trajectory of the task.

        Args:
            task_id: key into ``map_task_id_to_trajectories_id``.
            render_joint_trajectories: if True, draw the joint-space plot.
            render_robot_trajectories: if True, draw the robot trajectories.
            **kwargs: accepted and ignored.

        Returns:
            ``(fig1, axs1, fig2, axs2)``; the pair of a plot that was not
            requested is ``(None, None)``.
        """
        # -------------------------------- Visualize ---------------------------------
        idxs = self.map_task_id_to_trajectories_id[task_id]
        pos_trajs = self.robot.get_position(self.fields[self.field_key_traj][idxs])
        start_state_pos = pos_trajs[0][0]
        goal_state_pos = pos_trajs[0][-1]

        fig1, axs1, fig2, axs2 = [None] * 4

        if render_joint_trajectories:
            fig1, axs1 = self.planner_visualizer.plot_joint_space_state_trajectories(
                trajs=pos_trajs,
                pos_start_state=start_state_pos, pos_goal_state=goal_state_pos,
                vel_start_state=torch.zeros_like(start_state_pos), vel_goal_state=torch.zeros_like(goal_state_pos),
            )

        if render_robot_trajectories:
            fig2, axs2 = self.planner_visualizer.render_robot_trajectories(
                trajs=pos_trajs, start_state=start_state_pos, goal_state=goal_state_pos,
            )

        return fig1, axs1, fig2, axs2

    def __repr__(self):
        msg = f'TrajectoryDataset\n' \
              f'n_trajs: {self.n_trajs}\n' \
              f'trajectory_dim: {self.trajectory_dim}\n'
        return msg

    def __len__(self):
        """Return the number of trajectories in the dataset."""
        return self.n_trajs

    def __getitem__(self, index):
        """Return one sample: normalized trajectory, normalized task and hard conditions.

        The hard conditions are built from the normalized trajectory, with the
        horizon set to its length.
        """
        # Generates one sample of data - one trajectory and tasks
        field_traj_normalized = f'{self.field_key_traj}_normalized'
        field_task_normalized = f'{self.field_key_task}_normalized'

        traj_normalized = self.fields[field_traj_normalized][index]
        task_normalized = self.fields[field_task_normalized][index]

        data = {
            field_traj_normalized: traj_normalized,
            field_task_normalized: task_normalized
        }

        # build hard conditions
        hard_conds = self.get_hard_conditions(traj_normalized, horizon=len(traj_normalized))
        data.update({'hard_conds': hard_conds})

        return data

    def get_hard_conditions(self, traj, horizon=None, normalize=False):
        """Build the hard conditions for ``traj``; must be implemented by subclasses."""
        raise NotImplementedError

    def get_unnormalized(self, index):
        """Not implemented.

        The statements that used to follow the ``raise`` were unreachable and
        referenced attributes this class never defines, so they were removed.
        """
        raise NotImplementedError

    def unnormalize(self, x, key):
        """Map ``x`` back from normalized space using the statistics of field ``key``."""
        return self.normalizer.unnormalize(x, key)

    def normalize(self, x, key):
        """Normalize ``x`` using the statistics of field ``key``."""
        return self.normalizer.normalize(x, key)

    def unnormalize_trajectories(self, x):
        """Unnormalize ``x`` with the trajectory field statistics."""
        return self.unnormalize(x, self.field_key_traj)

    def normalize_trajectories(self, x):
        """Normalize ``x`` with the trajectory field statistics."""
        return self.normalize(x, self.field_key_traj)

    def unnormalize_tasks(self, x):
        """Unnormalize ``x`` with the task field statistics."""
        return self.unnormalize(x, self.field_key_task)

    def normalize_tasks(self, x):
        """Normalize ``x`` with the task field statistics."""
        return self.normalize(x, self.field_key_task)
class TrajectoryDataset(TrajectoryDatasetBase):
    """Trajectory dataset whose hard conditions pin the first and last state."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_hard_conditions(self, traj, horizon=None, normalize=False):
        """Return ``{0: start_state, horizon - 1: goal_state}`` for ``traj``.

        Args:
            traj: a single trajectory; its first and last entries provide the
                start and goal positions (via ``robot.get_position``).
            horizon: index range of the conditions; ``self.n_support_points``
                is used when None.
            normalize: if True both states are normalized with the trajectory
                field statistics before being returned.
        """
        # start and goal positions, kept in [start, goal] order throughout
        endpoints = [self.robot.get_position(traj[0]), self.robot.get_position(traj[-1])]

        if self.include_velocity:
            # If velocities are part of the state, set them to zero at the
            # beginning and end of a trajectory
            endpoints = [torch.cat((pos, torch.zeros_like(pos)), dim=-1) for pos in endpoints]

        if normalize:
            endpoints = [self.normalizer.normalize(state, key=self.field_key_traj) for state in endpoints]

        start_state, goal_state = endpoints
        last_index = (self.n_support_points if horizon is None else horizon) - 1
        return {0: start_state, last_index: goal_state}