98 lines
4.0 KiB
Python
98 lines
4.0 KiB
Python
import einops
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
import wandb
|
|
|
|
from mpd.models import build_context
|
|
from mpd.summaries.summary_base import SummaryBase
|
|
|
|
|
|
class SummaryTrajectoryGeneration(SummaryBase):
    """Summary that samples trajectories from the model and reports them to wandb.

    For one randomly chosen trajectory of the data subset it runs model
    inference, logs collision/success statistics computed by the dataset's
    task, and logs figures of both the dataset trajectories and the sampled
    trajectories.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``SummaryBase``."""
        super().__init__(**kwargs)

    def summary_fn(self, train_step=None, model=None, datasubset=None, prefix='', debug=False, **kwargs):
        """Sample trajectories for a random task and log statistics and figures.

        Args:
            train_step: Value passed as ``step`` to every ``wandb.log`` call.
            model: Object exposing ``run_inference``; also handed to
                ``build_context``.
            datasubset: Object exposing ``dataset`` (the underlying dataset)
                and ``indices`` (the trajectory ids to draw from).
            prefix: String prepended to every logged wandb key.
            debug: When true, ``plt.show()`` is called before the figures
                are closed.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            None. All results are reported through ``wandb.log``.
        """
        dataset = datasubset.dataset

        # ------------------------------------------------------------------------------------
        # Get a random task from the dataset
        trajectory_id = np.random.choice(datasubset.indices)
        task_id = dataset.map_trajectory_id_to_task_id[trajectory_id]

        data_normalized = dataset[trajectory_id]
        context = build_context(model, dataset, data_normalized)

        # ------------------------------------------------------------------------------------
        # Sample trajectories with the diffusion/cvae model
        n_samples = 25
        horizon = dataset.n_support_points
        hard_conds = data_normalized['hard_conds']

        trajs_normalized = model.run_inference(
            context, hard_conds,
            n_samples=n_samples, horizon=horizon,
            deterministic_steps=0
        )

        # unnormalize trajectory samples from the diffusion model
        trajs = dataset.unnormalize(trajs_normalized, dataset.field_key_traj)

        # ------------------------------------------------------------------------------------
        # STATISTICS
        task = dataset.task
        wandb.log({f'{prefix}percentage free trajs': task.compute_fraction_free_trajs(trajs)}, step=train_step)
        wandb.log({f'{prefix}percentage collision intensity': task.compute_collision_intensity_trajs(trajs)}, step=train_step)
        wandb.log({f'{prefix}success': task.compute_success_free_trajs(trajs)}, step=train_step)

        # ------------------------------------------------------------------------------------
        # Render
        # Figures are collected by wandb key suffix, in logging order. Entries
        # are added as soon as a figure exists so that the ``finally`` clause
        # can close everything created so far, even if a later rendering or
        # logging call raises; otherwise pyplot would keep the figures alive.
        figures = {}
        try:
            # dataset trajectory
            fig_joint_trajs_dataset, _, fig_robot_trajs_dataset, _ = dataset.render(
                task_id=task_id,
                render_joint_trajectories=True
            )
            figures['joint trajectories DATASET'] = fig_joint_trajs_dataset
            figures['robot trajectories DATASET'] = fig_robot_trajs_dataset

            # diffusion trajectory
            pos_trajs = dataset.robot.get_position(trajs)
            # Start/goal markers are taken from the first sampled trajectory.
            start_state_pos = pos_trajs[0][0]
            goal_state_pos = pos_trajs[0][-1]

            fig_joint_trajs_diffusion, _ = dataset.planner_visualizer.plot_joint_space_state_trajectories(
                trajs=pos_trajs,
                pos_start_state=start_state_pos, pos_goal_state=goal_state_pos,
                vel_start_state=torch.zeros_like(start_state_pos), vel_goal_state=torch.zeros_like(goal_state_pos),
                linestyle='dashed'
            )
            figures['joint trajectories DIFFUSION'] = fig_joint_trajs_diffusion

            # Rendering of the sampled robot trajectories is currently
            # disabled, so there is no figure to log for this key.
            figures['robot trajectories DIFFUSION'] = None

            for label, fig in figures.items():
                if fig is not None:
                    wandb.log({f"{prefix}{label}": wandb.Image(fig)}, step=train_step)

            if debug:
                plt.show()
        finally:
            for fig in figures.values():
                if fig is not None:
                    plt.close(fig)
|
|
|