mpd-public/mpd/summaries/summary_trajectory_generation.py
2023-10-23 15:45:14 +02:00

98 lines
4.0 KiB
Python

import einops
import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb
from mpd.models import build_context
from mpd.summaries.summary_base import SummaryBase
class SummaryTrajectoryGeneration(SummaryBase):
    """Summary that samples trajectories from the model and reports them to wandb.

    For one randomly chosen trajectory of the data subset, new trajectories
    are generated with ``model.run_inference``, task statistics computed on
    them are logged, and joint-space figures of both the dataset entry and the
    generated samples are uploaded as wandb images.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to ``SummaryBase``."""
        super().__init__(**kwargs)

    def summary_fn(self, train_step=None, model=None, datasubset=None, prefix='', debug=False, **kwargs):
        """Sample trajectories for a random dataset entry and log stats and figures.

        Args:
            train_step: Passed as ``step`` to every ``wandb.log`` call.
            model: Object providing ``run_inference``; also handed to
                ``build_context``.
            datasubset: Object exposing ``dataset`` and ``indices``
                (presumably a ``torch.utils.data.Subset`` -- TODO confirm
                against the caller).
            prefix: String prepended to every wandb key.
            debug: If true, ``plt.show()`` is called before the figures are
                closed.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            None. Results are only reported through ``wandb.log``.
        """
        dataset = datasubset.dataset

        # --------------------------------------------------------------------
        # Draw one trajectory index at random from the subset and build the
        # model context from its (normalized) dataset entry.
        sampled_traj_id = np.random.choice(datasubset.indices)
        task_id = dataset.map_trajectory_id_to_task_id[sampled_traj_id]
        sample_normalized = dataset[sampled_traj_id]
        context = build_context(model, dataset, sample_normalized)

        # --------------------------------------------------------------------
        # Sample trajectories with the diffusion/cvae model
        generated_normalized = model.run_inference(
            context,
            sample_normalized['hard_conds'],
            n_samples=25,
            horizon=dataset.n_support_points,
            deterministic_steps=0,
        )
        # The model output is in normalized space; map it back before use.
        trajs = dataset.unnormalize(generated_normalized, dataset.field_key_traj)

        # --------------------------------------------------------------------
        # STATISTICS
        task = dataset.task
        fraction_free = task.compute_fraction_free_trajs(trajs)
        wandb.log({f'{prefix}percentage free trajs': fraction_free}, step=train_step)
        collision_intensity = task.compute_collision_intensity_trajs(trajs)
        wandb.log({f'{prefix}percentage collision intensity': collision_intensity}, step=train_step)
        success = task.compute_success_free_trajs(trajs)
        wandb.log({f'{prefix}success': success}, step=train_step)

        # --------------------------------------------------------------------
        # Render
        # Dataset trajectory figures (joint space and robot space).
        fig_joint_dataset, _, fig_robot_dataset, _ = dataset.render(
            task_id=task_id,
            render_joint_trajectories=True,
        )

        # Generated trajectories: the start/goal markers are taken from the
        # first and last position of the first generated trajectory.
        pos_trajs = dataset.robot.get_position(trajs)
        first_traj = pos_trajs[0]
        start_pos = first_traj[0]
        goal_pos = first_traj[-1]
        zero_vel_start = torch.zeros_like(start_pos)
        zero_vel_goal = torch.zeros_like(goal_pos)
        fig_joint_diffusion, _ = dataset.planner_visualizer.plot_joint_space_state_trajectories(
            trajs=pos_trajs,
            pos_start_state=start_pos,
            pos_goal_state=goal_pos,
            vel_start_state=zero_vel_start,
            vel_goal_state=zero_vel_goal,
            linestyle='dashed',
        )

        # Robot-space rendering of the generated trajectories is disabled, so
        # the corresponding figure slot stays empty:
        # fig_robot_diffusion, _ = dataset.planner_visualizer.render_robot_trajectories(
        #     trajs=pos_trajs, start_state=start_pos, goal_state=goal_pos,
        #     linestyle='dashed'
        # )
        fig_robot_diffusion = None

        # (wandb key, figure) pairs in the order they are logged and closed.
        keyed_figures = (
            (f"{prefix}joint trajectories DATASET", fig_joint_dataset),
            (f"{prefix}robot trajectories DATASET", fig_robot_dataset),
            (f"{prefix}joint trajectories DIFFUSION", fig_joint_diffusion),
            (f"{prefix}robot trajectories DIFFUSION", fig_robot_diffusion),
        )

        for wandb_key, figure in keyed_figures:
            if figure is None:
                continue
            wandb.log({wandb_key: wandb.Image(figure)}, step=train_step)

        if debug:
            plt.show()

        # Close every figure that was produced.
        for _, figure in keyed_figures:
            if figure is None:
                continue
            plt.close(figure)