diffusion_policy/diffusion_policy/env/pusht/pymunk_keypoint_manager.py
2023-03-07 16:07:15 -05:00

147 lines
5.1 KiB
Python

from typing import Dict, Sequence, Union, Optional
import numpy as np
import skimage.transform as st
import pymunk
import pygame
from matplotlib import cm
import cv2
from diffusion_policy.env.pusht.pymunk_override import DrawOptions
def farthest_point_sampling(points: np.ndarray, n_points: int, init_idx: int):
"""
Naive O(N^2)
"""
assert(n_points >= 1)
chosen_points = [points[init_idx]]
for _ in range(n_points-1):
cpoints = np.array(chosen_points)
all_dists = np.linalg.norm(points[:,None,:] - cpoints[None,:,:], axis=-1)
min_dists = all_dists.min(axis=1)
next_idx = np.argmax(min_dists)
next_pt = points[next_idx]
chosen_points.append(next_pt)
result = np.array(chosen_points)
return result
class PymunkKeypointManager:
def __init__(self,
local_keypoint_map: Dict[str, np.ndarray],
color_map: Optional[Dict[str, np.ndarray]]=None):
"""
local_keypoint_map:
"<attribute_name>": (N,2) floats in object local coordinate
"""
if color_map is None:
cmap = cm.get_cmap('tab10')
color_map = dict()
for i, key in enumerate(local_keypoint_map.keys()):
color_map[key] = (np.array(cmap.colors[i]) * 255).astype(np.uint8)
self.local_keypoint_map = local_keypoint_map
self.color_map = color_map
@property
def kwargs(self):
return {
'local_keypoint_map': self.local_keypoint_map,
'color_map': self.color_map
}
@classmethod
def create_from_pusht_env(cls, env, n_block_kps=9, n_agent_kps=3, seed=0, **kwargs):
rng = np.random.default_rng(seed=seed)
local_keypoint_map = dict()
for name in ['block','agent']:
self = env
self.space = pymunk.Space()
if name == 'agent':
self.agent = obj = self.add_circle((256, 400), 15)
n_kps = n_agent_kps
else:
self.block = obj = self.add_tee((256, 300), 0)
n_kps = n_block_kps
self.screen = pygame.Surface((512,512))
self.screen.fill(pygame.Color("white"))
draw_options = DrawOptions(self.screen)
self.space.debug_draw(draw_options)
# pygame.display.flip()
img = np.uint8(pygame.surfarray.array3d(self.screen).transpose(1, 0, 2))
obj_mask = (img != np.array([255,255,255],dtype=np.uint8)).any(axis=-1)
tf_img_obj = cls.get_tf_img_obj(obj)
xy_img = np.moveaxis(np.array(np.indices((512,512))), 0, -1)[:,:,::-1]
local_coord_img = tf_img_obj.inverse(xy_img.reshape(-1,2)).reshape(xy_img.shape)
obj_local_coords = local_coord_img[obj_mask]
# furthest point sampling
init_idx = rng.choice(len(obj_local_coords))
obj_local_kps = farthest_point_sampling(obj_local_coords, n_kps, init_idx)
small_shift = rng.uniform(0, 1, size=obj_local_kps.shape)
obj_local_kps += small_shift
local_keypoint_map[name] = obj_local_kps
return cls(local_keypoint_map=local_keypoint_map, **kwargs)
@staticmethod
def get_tf_img(pose: Sequence):
pos = pose[:2]
rot = pose[2]
tf_img_obj = st.AffineTransform(
translation=pos, rotation=rot)
return tf_img_obj
@classmethod
def get_tf_img_obj(cls, obj: pymunk.Body):
pose = tuple(obj.position) + (obj.angle,)
return cls.get_tf_img(pose)
def get_keypoints_global(self,
pose_map: Dict[set, Union[Sequence, pymunk.Body]],
is_obj=False):
kp_map = dict()
for key, value in pose_map.items():
if is_obj:
tf_img_obj = self.get_tf_img_obj(value)
else:
tf_img_obj = self.get_tf_img(value)
kp_local = self.local_keypoint_map[key]
kp_global = tf_img_obj(kp_local)
kp_map[key] = kp_global
return kp_map
def draw_keypoints(self, img, kps_map, radius=1):
scale = np.array(img.shape[:2]) / np.array([512,512])
for key, value in kps_map.items():
color = self.color_map[key].tolist()
coords = (value * scale).astype(np.int32)
for coord in coords:
cv2.circle(img, coord, radius=radius, color=color, thickness=-1)
return img
def draw_keypoints_pose(self, img, pose_map, is_obj=False, **kwargs):
kp_map = self.get_keypoints_global(pose_map, is_obj=is_obj)
return self.draw_keypoints(img, kps_map=kp_map, **kwargs)
def test():
from diffusion_policy.environment.push_t_env import PushTEnv
from matplotlib import pyplot as plt
env = PushTEnv(headless=True, obs_state=False, draw_action=False)
kp_manager = PymunkKeypointManager.create_from_pusht_env(env=env)
env.reset()
obj_map = {
'block': env.block,
'agent': env.agent
}
obs = env.render()
img = obs.astype(np.uint8)
kp_manager.draw_keypoints_pose(img=img, pose_map=obj_map, is_obj=True)
plt.imshow(img)