# diffusion_policy/demo_pusht.py
# 2023-06-07 00:00:04 -04:00
#
# 121 lines
# 4.2 KiB
# Python

import numpy as np
import click
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
import pygame
@click.command()
@click.option('-o', '--output', required=True)
@click.option('-rs', '--render_size', default=96, type=int)
@click.option('-hz', '--control_hz', default=10, type=int)
def main(output, render_size, control_hz):
    """
    Collect demonstration for the Push-T task.

    Usage: python demo_pusht.py -o data/pusht_demo.zarr

    This script is compatible with both Linux and MacOS.
    Hover mouse close to the blue circle to start.
    Push the T block into the green area.
    The episode will automatically terminate once the task succeeds.
    Press "Q" to exit.
    Press "R" to retry.
    Hold "Space" to pause.
    """
    # Parameters (filled in by click from the command line):
    #   output       path handed to ReplayBuffer.create_from_path (mode 'a')
    #   render_size  forwarded to PushTKeypointsEnv(render_size=...)
    #   control_hz   rate passed to clock.tick() to pace the step loop

    # create replay buffer in read-write mode
    replay_buffer = ReplayBuffer.create_from_path(output, mode='a')

    # create PushT env with keypoints
    kp_kwargs = PushTKeypointsEnv.genenerate_keypoint_manager_params()
    env = PushTKeypointsEnv(
        render_size=render_size, render_action=False, **kp_kwargs)
    agent = env.teleop_agent()
    clock = pygame.time.Clock()

    # episode-level while loop
    while True:
        episode = []
        # record in seed order, starting with 0; a retried or unsaved episode
        # leaves n_episodes unchanged, so the same seed is replayed
        seed = replay_buffer.n_episodes
        print(f'starting seed {seed}')

        # set seed for env
        env.seed(seed)

        # reset env and get observations
        # (including info and render for recording)
        obs = env.reset()
        info = env._get_info()
        img = env.render(mode='human')

        # loop state
        retry = False
        pause = False
        done = False
        plan_idx = 0
        pygame.display.set_caption(f'plan_idx:{plan_idx}')

        # step-level while loop
        while not done:
            # process keypress events
            for event in pygame.event.get():
                if event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_SPACE:
                        # hold Space to pause; each pause bumps plan_idx
                        plan_idx += 1
                        pygame.display.set_caption(f'plan_idx:{plan_idx}')
                        pause = True
                    elif event.key == pygame.K_r:
                        # press "R" to retry
                        retry = True
                    elif event.key == pygame.K_q:
                        # press "Q" to exit; SystemExit is raised directly
                        # because the exit() helper is only injected by the
                        # site module and is not guaranteed to exist
                        raise SystemExit(0)
                if event.type == pygame.KEYUP:
                    if event.key == pygame.K_SPACE:
                        pause = False

            # handle control flow
            if retry:
                break
            if pause:
                # keep polling events at the control rate instead of
                # busy-spinning while Space is held
                clock.tick(control_hz)
                continue

            # get action from mouse
            # None if mouse is not close to the agent
            act = agent.act(obs)
            if act is not None:
                # teleop started
                # state dim 2+3
                state = np.concatenate(
                    [info['pos_agent'], info['block_pose']])
                # discard unused information such as visibility mask and
                # agent pos for compatibility
                keypoint = obs.reshape(2, -1)[0].reshape(-1, 2)[:9]
                # img/obs/info here are the ones captured before this step
                episode.append({
                    'img': img,
                    'state': np.float32(state),
                    'keypoint': np.float32(keypoint),
                    'action': np.float32(act),
                    'n_contacts': np.float32([info['n_contacts']]),
                })

            # step env and render (the env is stepped even when act is None)
            obs, _reward, done, info = env.step(act)
            img = env.render(mode='human')

            # regulate control frequency
            clock.tick(control_hz)

        if retry:
            print(f'retry seed {seed}')
            continue

        if not episode:
            # the env finished before any action was recorded; there is
            # nothing to stack, so skip saving and replay the same seed
            print(f'no steps recorded for seed {seed}, retrying')
            continue

        # save episode buffer to replay buffer (on disk)
        data_dict = {
            key: np.stack([step[key] for step in episode])
            for key in episode[0]
        }
        replay_buffer.add_episode(data_dict, compressors='disk')
        print(f'saved seed {seed}')
# Run the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()