From 0d00e02b45e9e3f37f4eeb68bff076b68d9e9d44 Mon Sep 17 00:00:00 2001
From: Cheng Chi
Date: Wed, 7 Jun 2023 00:00:04 -0400
Subject: [PATCH] added demo script for pusht

---
 demo_pusht.py | 134 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 134 insertions(+)
 create mode 100644 demo_pusht.py

diff --git a/demo_pusht.py b/demo_pusht.py
new file mode 100644
index 0000000..0e3cf63
--- /dev/null
+++ b/demo_pusht.py
@@ -0,0 +1,134 @@
+import sys
+
+import numpy as np
+import click
+from diffusion_policy.common.replay_buffer import ReplayBuffer
+from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
+import pygame
+
+@click.command()
+@click.option('-o', '--output', required=True)
+@click.option('-rs', '--render_size', default=96, type=int)
+@click.option('-hz', '--control_hz', default=10, type=int)
+def main(output, render_size, control_hz):
+    """
+    Collect demonstration for the Push-T task.
+
+    Usage: python demo_pusht.py -o data/pusht_demo.zarr
+
+    This script is compatible with both Linux and MacOS.
+    Hover mouse close to the blue circle to start.
+    Push the T block into the green area.
+    The episode will automatically terminate if the task is succeeded.
+    Press "Q" to exit.
+    Press "R" to retry.
+    Hold "Space" to pause.
+    """
+
+    # create replay buffer in read-write mode
+    replay_buffer = ReplayBuffer.create_from_path(output, mode='a')
+
+    # create PushT env with keypoints
+    # NOTE(review): "genenerate" looks like a typo, but the name must match
+    # the method defined on PushTKeypointsEnv, so it is kept as-is here.
+    kp_kwargs = PushTKeypointsEnv.genenerate_keypoint_manager_params()
+    env = PushTKeypointsEnv(render_size=render_size, render_action=False, **kp_kwargs)
+    agent = env.teleop_agent()
+    clock = pygame.time.Clock()
+
+    # episode-level while loop
+    while True:
+        episode = []
+        # record in seed order, starting with 0
+        seed = replay_buffer.n_episodes
+        print(f'starting seed {seed}')
+
+        # set seed for env
+        env.seed(seed)
+
+        # reset env and get observations (including info and render for recording)
+        obs = env.reset()
+        info = env._get_info()
+        img = env.render(mode='human')
+
+        # loop state
+        retry = False
+        pause = False
+        done = False
+        plan_idx = 0
+        pygame.display.set_caption(f'plan_idx:{plan_idx}')
+        # step-level while loop
+        while not done:
+            # process keypress events
+            for event in pygame.event.get():
+                if event.type == pygame.KEYDOWN:
+                    if event.key == pygame.K_SPACE:
+                        # hold Space to pause
+                        plan_idx += 1
+                        pygame.display.set_caption(f'plan_idx:{plan_idx}')
+                        pause = True
+                    elif event.key == pygame.K_r:
+                        # press "R" to retry
+                        retry = True
+                    elif event.key == pygame.K_q:
+                        # press "Q" to exit
+                        sys.exit(0)
+                elif event.type == pygame.KEYUP:
+                    if event.key == pygame.K_SPACE:
+                        pause = False
+
+            # handle control flow
+            if retry:
+                break
+            if pause:
+                # throttle the paused loop to the control rate; without
+                # this the loop busy-waits on the event queue
+                clock.tick(control_hz)
+                continue
+
+            # get action from mouse
+            # None if mouse is not close to the agent
+            act = agent.act(obs)
+            if act is not None:
+                # teleop started
+                # state dim 2+3
+                state = np.concatenate([info['pos_agent'], info['block_pose']])
+                # discard unused information such as visibility mask and agent pos
+                # for compatibility
+                keypoint = obs.reshape(2,-1)[0].reshape(-1,2)[:9]
+                data = {
+                    'img': img,
+                    'state': np.float32(state),
+                    'keypoint': np.float32(keypoint),
+                    'action': np.float32(act),
+                    'n_contacts': np.float32([info['n_contacts']])
+                }
+                episode.append(data)
+
+            # step env and render
+            obs, _reward, done, info = env.step(act)
+            img = env.render(mode='human')
+
+            # regulate control frequency
+            clock.tick(control_hz)
+
+        if retry:
+            print(f'retry seed {seed}')
+            continue
+        if not episode:
+            # the env finished before any teleop step was recorded;
+            # episode[0] would raise IndexError, so skip saving
+            print(f'no data recorded for seed {seed}, skipping save')
+            continue
+
+        # save episode buffer to replay buffer (on disk)
+        data_dict = {
+            key: np.stack([step[key] for step in episode])
+            for key in episode[0]
+        }
+        replay_buffer.add_episode(data_dict, compressors='disk')
+        print(f'saved seed {seed}')
+
+
+if __name__ == "__main__":
+    main()