from typing import Optional
import pathlib
import numpy as np
import time
import shutil
import math
from multiprocessing.managers import SharedMemoryManager
from diffusion_policy.real_world.rtde_interpolation_controller import RTDEInterpolationController
from diffusion_policy.real_world.multi_realsense import MultiRealsense, SingleRealsense
from diffusion_policy.real_world.video_recorder import VideoRecorder
from diffusion_policy.common.timestamp_accumulator import (
TimestampObsAccumulator,
TimestampActionAccumulator,
align_timestamps
)
from diffusion_policy.real_world.multi_camera_visualizer import MultiCameraVisualizer
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.common.cv2_util import (
get_image_transform, optimal_row_cols)
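
# Maps RTDE receive keys (left) to the observation names (right) that
# RealEnv.get_obs() returns and the replay buffer stores.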
DEFAULT_OBS_KEY_MAP = {
# robot
'ActualTCPPose': 'robot_eef_pose',
'ActualTCPSpeed': 'robot_eef_pose_vel',
'ActualQ': 'robot_joint',
'ActualQd': 'robot_joint_vel',
# timestamps
'step_idx': 'step_idx',
'timestamp': 'timestamp'
}
class RealEnv:
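    """
    Real-robot environment wrapping a UR arm (via RTDEInterpolationController)
    and one or more RealSense cameras behind a get_obs / exec_actions interface,
    with episode recording to per-camera videos and a zarr ReplayBuffer.
    """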
def __init__(self,
# required params
output_dir,
robot_ip,
# env params
frequency=10,
n_obs_steps=2,
# obs
obs_image_resolution=(640,480),
max_obs_buffer_size=30,
camera_serial_numbers=None,
obs_key_map=DEFAULT_OBS_KEY_MAP,
obs_float32=False,
# action
max_pos_speed=0.25,
max_rot_speed=0.6,
# robot
tcp_offset=0.13,
init_joints=False,
# video capture params
video_capture_fps=30,
video_capture_resolution=(1280,720),
# saving params
record_raw_video=True,
thread_per_video=2,
video_crf=21,
# vis params
enable_multi_cam_vis=True,
multi_cam_vis_resolution=(1280,720),
# shared memory
shm_manager=None
):
assert frequency <= video_capture_fps
output_dir = pathlib.Path(output_dir)
assert output_dir.parent.is_dir()
video_dir = output_dir.joinpath('videos')
video_dir.mkdir(parents=True, exist_ok=True)
zarr_path = str(output_dir.joinpath('replay_buffer.zarr').absolute())
replay_buffer = ReplayBuffer.create_from_path(
zarr_path=zarr_path, mode='a')
if shm_manager is None:
shm_manager = SharedMemoryManager()
shm_manager.start()
if camera_serial_numbers is None:
camera_serial_numbers = SingleRealsense.get_connected_devices_serial()
color_tf = get_image_transform(
input_res=video_capture_resolution,
output_res=obs_image_resolution,
# obs output rgb
bgr_to_rgb=True)
color_transform = color_tf
if obs_float32:
color_transform = lambda x: color_tf(x).astype(np.float32) / 255
def transform(data):
data['color'] = color_transform(data['color'])
return data
rw, rh, col, row = optimal_row_cols(
n_cameras=len(camera_serial_numbers),
in_wh_ratio=obs_image_resolution[0]/obs_image_resolution[1],
max_resolution=multi_cam_vis_resolution
)
vis_color_transform = get_image_transform(
input_res=video_capture_resolution,
output_res=(rw,rh),
bgr_to_rgb=False
)
def vis_transform(data):
data['color'] = vis_color_transform(data['color'])
return data
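
        # Recording settings: raw video keeps the native capture fps and BGR
        # frames; otherwise frames go through the obs transform and are saved
        # at the control frequency in RGB.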
        recording_transform = None
recording_fps = video_capture_fps
recording_pix_fmt = 'bgr24'
if not record_raw_video:
            recording_transform = transform
recording_fps = frequency
recording_pix_fmt = 'rgb24'
video_recorder = VideoRecorder.create_h264(
fps=recording_fps,
codec='h264',
input_pix_fmt=recording_pix_fmt,
crf=video_crf,
thread_type='FRAME',
thread_count=thread_per_video)
realsense = MultiRealsense(
serial_numbers=camera_serial_numbers,
shm_manager=shm_manager,
resolution=video_capture_resolution,
capture_fps=video_capture_fps,
put_fps=video_capture_fps,
# send every frame immediately after arrival
# ignores put_fps
put_downsample=False,
record_fps=recording_fps,
enable_color=True,
enable_depth=False,
enable_infrared=False,
get_max_k=max_obs_buffer_size,
transform=transform,
vis_transform=vis_transform,
            recording_transform=recording_transform,
video_recorder=video_recorder,
verbose=False
)
multi_cam_vis = None
if enable_multi_cam_vis:
multi_cam_vis = MultiCameraVisualizer(
realsense=realsense,
row=row,
col=col,
rgb_to_bgr=False
)
cube_diag = np.linalg.norm([1,1,1])
j_init = np.array([0,-90,-90,-90,90,0]) / 180 * np.pi
if not init_joints:
j_init = None
robot = RTDEInterpolationController(
shm_manager=shm_manager,
robot_ip=robot_ip,
frequency=125, # UR5 CB3 RTDE
lookahead_time=0.1,
gain=300,
max_pos_speed=max_pos_speed*cube_diag,
max_rot_speed=max_rot_speed*cube_diag,
launch_timeout=3,
tcp_offset_pose=[0,0,tcp_offset,0,0,0],
payload_mass=None,
payload_cog=None,
joints_init=j_init,
joints_init_speed=1.05,
soft_real_time=False,
verbose=False,
receive_keys=None,
get_max_k=max_obs_buffer_size
)
self.realsense = realsense
self.robot = robot
self.multi_cam_vis = multi_cam_vis
self.video_capture_fps = video_capture_fps
self.frequency = frequency
self.n_obs_steps = n_obs_steps
self.max_obs_buffer_size = max_obs_buffer_size
self.max_pos_speed = max_pos_speed
self.max_rot_speed = max_rot_speed
self.obs_key_map = obs_key_map
# recording
self.output_dir = output_dir
self.video_dir = video_dir
self.replay_buffer = replay_buffer
# temp memory buffers
self.last_realsense_data = None
# recording buffers
self.obs_accumulator = None
self.action_accumulator = None
self.stage_accumulator = None
self.start_time = None
# ======== start-stop API =============
@property
def is_ready(self):
return self.realsense.is_ready and self.robot.is_ready
def start(self, wait=True):
self.realsense.start(wait=False)
self.robot.start(wait=False)
if self.multi_cam_vis is not None:
self.multi_cam_vis.start(wait=False)
if wait:
self.start_wait()
def stop(self, wait=True):
self.end_episode()
if self.multi_cam_vis is not None:
self.multi_cam_vis.stop(wait=False)
self.robot.stop(wait=False)
self.realsense.stop(wait=False)
if wait:
self.stop_wait()
def start_wait(self):
self.realsense.start_wait()
self.robot.start_wait()
if self.multi_cam_vis is not None:
self.multi_cam_vis.start_wait()
def stop_wait(self):
self.robot.stop_wait()
self.realsense.stop_wait()
if self.multi_cam_vis is not None:
self.multi_cam_vis.stop_wait()
# ========= context manager ===========
def __enter__(self):
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
# ========= async env API ===========
def get_obs(self) -> dict:
"observation dict"
assert self.is_ready
# get data
# 30 Hz, camera_receive_timestamp
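        # fetch enough of the newest frames to cover n_obs_steps control periods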
k = math.ceil(self.n_obs_steps * (self.video_capture_fps / self.frequency))
self.last_realsense_data = self.realsense.get(
k=k,
out=self.last_realsense_data)
        # 125 Hz, robot_receive_timestamp
        last_robot_data = self.robot.get_all_state()
        # both buffers contain more than n_obs_steps samples
# align camera obs timestamps
dt = 1 / self.frequency
last_timestamp = np.max([x['timestamp'][-1] for x in self.last_realsense_data.values()])
obs_align_timestamps = last_timestamp - (np.arange(self.n_obs_steps)[::-1] * dt)
camera_obs = dict()
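        # for each camera, pick the latest frame captured before each aligned
        # timestamp (falling back to the oldest buffered frame if none precedes it)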
for camera_idx, value in self.last_realsense_data.items():
this_timestamps = value['timestamp']
this_idxs = list()
for t in obs_align_timestamps:
is_before_idxs = np.nonzero(this_timestamps < t)[0]
this_idx = 0
if len(is_before_idxs) > 0:
this_idx = is_before_idxs[-1]
this_idxs.append(this_idx)
# remap key
camera_obs[f'camera_{camera_idx}'] = value['color'][this_idxs]
# align robot obs
robot_timestamps = last_robot_data['robot_receive_timestamp']
this_timestamps = robot_timestamps
this_idxs = list()
for t in obs_align_timestamps:
is_before_idxs = np.nonzero(this_timestamps < t)[0]
this_idx = 0
if len(is_before_idxs) > 0:
this_idx = is_before_idxs[-1]
this_idxs.append(this_idx)
robot_obs_raw = dict()
for k, v in last_robot_data.items():
if k in self.obs_key_map:
robot_obs_raw[self.obs_key_map[k]] = v
robot_obs = dict()
for k, v in robot_obs_raw.items():
robot_obs[k] = v[this_idxs]
# accumulate obs
if self.obs_accumulator is not None:
self.obs_accumulator.put(
robot_obs_raw,
robot_timestamps
)
# return obs
obs_data = dict(camera_obs)
obs_data.update(robot_obs)
obs_data['timestamp'] = obs_align_timestamps
return obs_data
def exec_actions(self,
actions: np.ndarray,
timestamps: np.ndarray,
stages: Optional[np.ndarray]=None):
assert self.is_ready
if not isinstance(actions, np.ndarray):
actions = np.array(actions)
if not isinstance(timestamps, np.ndarray):
timestamps = np.array(timestamps)
if stages is None:
stages = np.zeros_like(timestamps, dtype=np.int64)
elif not isinstance(stages, np.ndarray):
stages = np.array(stages, dtype=np.int64)
        # discard actions whose target time has already passed
receive_time = time.time()
is_new = timestamps > receive_time
new_actions = actions[is_new]
new_timestamps = timestamps[is_new]
new_stages = stages[is_new]
# schedule waypoints
for i in range(len(new_actions)):
self.robot.schedule_waypoint(
pose=new_actions[i],
target_time=new_timestamps[i]
)
# record actions
if self.action_accumulator is not None:
self.action_accumulator.put(
new_actions,
new_timestamps
)
if self.stage_accumulator is not None:
self.stage_accumulator.put(
new_stages,
new_timestamps
)
def get_robot_state(self):
return self.robot.get_state()
# recording API
def start_episode(self, start_time=None):
"Start recording and return first obs"
if start_time is None:
start_time = time.time()
self.start_time = start_time
assert self.is_ready
# prepare recording stuff
episode_id = self.replay_buffer.n_episodes
this_video_dir = self.video_dir.joinpath(str(episode_id))
this_video_dir.mkdir(parents=True, exist_ok=True)
n_cameras = self.realsense.n_cameras
video_paths = list()
for i in range(n_cameras):
video_paths.append(
str(this_video_dir.joinpath(f'{i}.mp4').absolute()))
# start recording on realsense
self.realsense.restart_put(start_time=start_time)
self.realsense.start_recording(video_path=video_paths, start_time=start_time)
# create accumulators
self.obs_accumulator = TimestampObsAccumulator(
start_time=start_time,
dt=1/self.frequency
)
self.action_accumulator = TimestampActionAccumulator(
start_time=start_time,
dt=1/self.frequency
)
self.stage_accumulator = TimestampActionAccumulator(
start_time=start_time,
dt=1/self.frequency
)
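        # all three accumulators share start_time and dt, so their rows align
        # index-for-index when the episode is assembled in end_episode()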
print(f'Episode {episode_id} started!')
def end_episode(self):
"Stop recording"
assert self.is_ready
# stop video recorder
self.realsense.stop_recording()
if self.obs_accumulator is not None:
# recording
assert self.action_accumulator is not None
assert self.stage_accumulator is not None
            # Obs and actions are only accumulated through get_obs and
            # exec_actions, which run on the same thread as this call,
            # so no new data can arrive while we finalize the episode here.
obs_data = self.obs_accumulator.data
obs_timestamps = self.obs_accumulator.timestamps
actions = self.action_accumulator.actions
action_timestamps = self.action_accumulator.timestamps
stages = self.stage_accumulator.actions
n_steps = min(len(obs_timestamps), len(action_timestamps))
if n_steps > 0:
episode = dict()
episode['timestamp'] = obs_timestamps[:n_steps]
episode['action'] = actions[:n_steps]
episode['stage'] = stages[:n_steps]
for key, value in obs_data.items():
episode[key] = value[:n_steps]
self.replay_buffer.add_episode(episode, compressors='disk')
episode_id = self.replay_buffer.n_episodes - 1
print(f'Episode {episode_id} saved!')
self.obs_accumulator = None
self.action_accumulator = None
self.stage_accumulator = None
def drop_episode(self):
self.end_episode()
self.replay_buffer.drop_episode()
episode_id = self.replay_buffer.n_episodes
this_video_dir = self.video_dir.joinpath(str(episode_id))
if this_video_dir.exists():
shutil.rmtree(str(this_video_dir))
print(f'Episode {episode_id} dropped!')
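

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The robot IP,
# output directory and the "hold current pose" action below are placeholder
# assumptions; running this requires a real UR arm and RealSense cameras.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    with SharedMemoryManager() as shm_manager:
        with RealEnv(
                output_dir='data/demo_session',  # assumed output location
                robot_ip='192.168.0.2',          # placeholder robot address
                frequency=10,
                shm_manager=shm_manager) as env:
            env.start_episode()
            dt = 1.0 / env.frequency
            for step_idx in range(20):
                obs = env.get_obs()
                # a policy would produce the next pose here; this sketch just
                # re-sends the latest end-effector pose as the target
                target_pose = obs['robot_eef_pose'][-1]
                env.exec_actions(
                    actions=[target_pose],
                    timestamps=[time.time() + dt])
                time.sleep(dt)
            env.end_episode()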