Compare commits


12 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Cheng Chi | 548a52bbb1 | Merge pull request #27 from pointW/main: Fix typo in rotation_transformer.py | 2023-10-26 22:34:42 -07:00 |
| Dian Wang | de4384e84a | fix typo in rotation_transformer.py | 2023-10-25 10:43:25 -04:00 |
| Cheng Chi | 7dd9dc417a | Merge pull request #21 from columbia-ai-robotics/cchi/fix_cpu_affinity: pinned llvm-openmp version to avoid cpu affinity bug in pytorch | 2023-09-12 23:36:52 -07:00 |
| Cheng Chi | 5aa9996fdc | pinned llvm-openmp version to avoid cpu affinity bug in pytorch | 2023-09-13 02:36:26 -04:00 |
| Cheng Chi | 5c3d54fca3 | Merge pull request #20 from columbia-ai-robotics/cchi/eval_script: added eval script and documentation | 2023-09-09 22:58:56 -07:00 |
| Cheng Chi | a98e74873b | added eval script and documentation | 2023-09-10 01:58:04 -04:00 |
| Cheng Chi | 68eef44d3e | Merge pull request #19 from columbia-ai-robotics/cchi/fix_transformer_impainting: fixed T->To based on suggestion from Dominique-Yiu | 2023-09-09 09:52:24 -07:00 |
| Cheng Chi | c52bac42ee | fixed T->To based on suggestion from Dominique-Yiu | 2023-09-09 12:51:49 -04:00 |
| Cheng Chi | 749db2ce9c | Merge pull request #18 from columbia-ai-robotics/cchi/bug_fix_unet1d: incorporated change from PR #10 | 2023-09-07 16:03:03 -07:00 |
| Cheng Chi | dd2cbac9fa | incorporated change from PR #10 | 2023-09-07 19:02:13 -04:00 |
| Cheng Chi | 0d00e02b45 | added demo script for pusht | 2023-06-07 00:00:04 -04:00 |
| Cheng Chi | 74b6391737 | Merge pull request #7 from columbia-ai-robotics/cchi/bug_fix_eval_sample: fixed bug where only n_envs samples of metrics are used | 2023-06-01 11:09:50 -04:00 |
9 changed files with 232 additions and 5 deletions


@@ -202,6 +202,41 @@ data/outputs/2023.03.01/22.13.58_train_diffusion_unet_hybrid_pusht_image
7 directories, 16 files
```
### 🆕 Evaluate Pre-trained Checkpoints
Download a checkpoint from the published training log folders, such as [https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt](https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt).
Run the evaluation script:
```console
(robodiff)[diffusion_policy]$ python eval.py --checkpoint data/0550-test_mean_score=0.969.ckpt --output_dir data/pusht_eval_output --device cuda:0
```
This will generate the following directory structure:
```console
(robodiff)[diffusion_policy]$ tree data/pusht_eval_output
data/pusht_eval_output
├── eval_log.json
└── media
├── 1fxtno84.mp4
├── 224l7jqd.mp4
├── 2fo4btlf.mp4
├── 2in4cn7a.mp4
├── 34b3o2qq.mp4
└── 3p7jqn32.mp4
1 directory, 7 files
```
`eval_log.json` contains the metrics that are logged to wandb during training:
```console
(robodiff)[diffusion_policy]$ cat data/pusht_eval_output/eval_log.json
{
"test/mean_score": 0.9150393806777066,
"test/sim_max_reward_4300000": 1.0,
"test/sim_max_reward_4300001": 0.9872969750774386,
...
"train/sim_video_1": "data/pusht_eval_output//media/2fo4btlf.mp4"
}
```
## 🦾 Demo, Training and Eval on a Real Robot
Make sure your UR5 robot is running and accepting commands from its network interface (keep the emergency stop button within reach at all times), your RealSense cameras are plugged into your workstation (test with `realsense-viewer`), and your SpaceMouse is connected with the `spacenavd` daemon running (verify with `systemctl status spacenavd`).
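For a quick pre-flight check of the peripherals mentioned above (a sketch, assuming a typical Ubuntu workstation; exact output varies):
```console
(robodiff)[diffusion_policy]$ realsense-viewer            # cameras should be listed and stream
(robodiff)[diffusion_policy]$ systemctl status spacenavd  # should report active (running)
```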


@@ -46,6 +46,8 @@ dependencies:
   - diffusers=0.11.1
   - av=10.0.0
   - cmake=3.24.3
+  # trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625
+  - llvm-openmp=14
   # trick to force reinstall imagecodecs via pip
   - imagecodecs==2022.8.8
   - pip:


@@ -46,6 +46,8 @@ dependencies:
   - diffusers=0.11.1
   - av=10.0.0
   - cmake=3.24.3
+  # trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625
+  - llvm-openmp=14
   # trick to force reinstall imagecodecs via pip
   - imagecodecs==2022.8.8
   - pip:
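Both environment-file hunks above add the same pin. After recreating the conda environment, a quick sanity check (a sketch; output format depends on your conda version):
```console
(robodiff)[diffusion_policy]$ conda list llvm-openmp   # expect a 14.x version
```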

demo_pusht.py (new file, 120 lines)

@@ -0,0 +1,120 @@
import numpy as np
import click
import pygame

from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv


@click.command()
@click.option('-o', '--output', required=True)
@click.option('-rs', '--render_size', default=96, type=int)
@click.option('-hz', '--control_hz', default=10, type=int)
def main(output, render_size, control_hz):
    """
    Collect demonstrations for the Push-T task.

    Usage: python demo_pusht.py -o data/pusht_demo.zarr

    This script is compatible with both Linux and macOS.
    Hover the mouse close to the blue circle to start.
    Push the T block into the green area.
    The episode terminates automatically once the task succeeds.
    Press "Q" to exit.
    Press "R" to retry.
    Hold "Space" to pause.
    """
    # create replay buffer in read-write mode
    replay_buffer = ReplayBuffer.create_from_path(output, mode='a')

    # create PushT env with keypoints
    kp_kwargs = PushTKeypointsEnv.genenerate_keypoint_manager_params()
    env = PushTKeypointsEnv(render_size=render_size, render_action=False, **kp_kwargs)
    agent = env.teleop_agent()
    clock = pygame.time.Clock()

    # episode-level while loop
    while True:
        episode = list()
        # record in seed order, starting with 0
        seed = replay_buffer.n_episodes
        print(f'starting seed {seed}')

        # set seed for env
        env.seed(seed)

        # reset env and get observations (including info and render for recording)
        obs = env.reset()
        info = env._get_info()
        img = env.render(mode='human')

        # loop state
        retry = False
        pause = False
        done = False
        plan_idx = 0
        pygame.display.set_caption(f'plan_idx:{plan_idx}')

        # step-level while loop
        while not done:
            # process keypress events
            for event in pygame.event.get():
                if event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_SPACE:
                        # hold Space to pause
                        plan_idx += 1
                        pygame.display.set_caption(f'plan_idx:{plan_idx}')
                        pause = True
                    elif event.key == pygame.K_r:
                        # press "R" to retry
                        retry = True
                    elif event.key == pygame.K_q:
                        # press "Q" to exit
                        exit(0)
                if event.type == pygame.KEYUP:
                    if event.key == pygame.K_SPACE:
                        pause = False

            # handle control flow
            if retry:
                break
            if pause:
                continue

            # get action from mouse
            # None if mouse is not close to the agent
            act = agent.act(obs)
            if act is not None:
                # teleop started
                # state dim 2+3
                state = np.concatenate([info['pos_agent'], info['block_pose']])
                # discard unused information such as visibility mask and agent pos
                # for compatibility
                keypoint = obs.reshape(2, -1)[0].reshape(-1, 2)[:9]
                data = {
                    'img': img,
                    'state': np.float32(state),
                    'keypoint': np.float32(keypoint),
                    'action': np.float32(act),
                    'n_contacts': np.float32([info['n_contacts']])
                }
                episode.append(data)

            # step env and render
            obs, reward, done, info = env.step(act)
            img = env.render(mode='human')

            # regulate control frequency
            clock.tick(control_hz)

        if not retry:
            # save episode buffer to replay buffer (on disk)
            data_dict = dict()
            for key in episode[0].keys():
                data_dict[key] = np.stack(
                    [x[key] for x in episode])
            replay_buffer.add_episode(data_dict, compressors='disk')
            print(f'saved seed {seed}')
        else:
            print(f'retry seed {seed}')


if __name__ == "__main__":
    main()
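To sanity-check the collected demonstrations, a minimal sketch along these lines can reopen the buffer read-only (assuming `ReplayBuffer` supports a read mode and dict-style access to the stored arrays, as its use above suggests):
```python
# a minimal sketch, assuming ReplayBuffer exposes n_episodes and
# dict-style access to stored arrays (suggested by its use above)
from diffusion_policy.common.replay_buffer import ReplayBuffer

buf = ReplayBuffer.create_from_path('data/pusht_demo.zarr', mode='r')
print(f'episodes: {buf.n_episodes}')
print(f'actions shape: {buf["action"].shape}')  # (total_steps, 2) for pusht
```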


@@ -40,7 +40,7 @@ class RotationTransformer:
                 getattr(pt, f'matrix_to_{from_rep}')
             ]
             if from_convention is not None:
-                funcs = [functools.partial(func, convernsion=from_convention)
+                funcs = [functools.partial(func, convention=from_convention)
                     for func in funcs]
             forward_funcs.append(funcs[0])
             inverse_funcs.append(funcs[1])
@@ -51,7 +51,7 @@ class RotationTransformer:
                 getattr(pt, f'{to_rep}_to_matrix')
             ]
             if to_convention is not None:
-                funcs = [functools.partial(func, convernsion=to_convention)
+                funcs = [functools.partial(func, convention=to_convention)
                     for func in funcs]
             forward_funcs.append(funcs[0])
             inverse_funcs.append(funcs[1])
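The misspelled keyword only blows up at call time, because `functools.partial` does not validate keyword names against the wrapped function. A small sketch of the failure mode (assuming pytorch3d is installed; `euler_angles_to_matrix` is the pytorch3d.transforms function aliased as `pt` above):
```python
import functools
import torch
import pytorch3d.transforms as pt

# partial() happily binds an unknown keyword; the error surfaces only when called
bad = functools.partial(pt.euler_angles_to_matrix, convernsion='XYZ')
good = functools.partial(pt.euler_angles_to_matrix, convention='XYZ')

good(torch.zeros(3))  # 3x3 identity rotation matrix
bad(torch.zeros(3))   # TypeError: unexpected keyword argument 'convernsion'
```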


@@ -226,6 +226,10 @@ class ConditionalUnet1D(nn.Module):
         for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
             x = torch.cat((x, h.pop()), dim=1)
             x = resnet(x, global_feature)
+            # The correct condition should be:
+            # if idx == (len(self.up_modules)-1) and len(h_local) > 0:
+            # However this change will break compatibility with published checkpoints.
+            # Therefore it is left as a comment.
             if idx == len(self.up_modules) and len(h_local) > 0:
                 x = x + h_local[1]
             x = resnet2(x, global_feature)
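The added comment is worth unpacking: `enumerate` over `self.up_modules` yields `idx` in `range(len(self.up_modules))`, so `idx == len(self.up_modules)` can never be true and the `h_local` branch is dead code. A tiny illustration:
```python
modules = ['up0', 'up1', 'up2']
for idx, m in enumerate(modules):
    # idx takes the values 0, 1, 2; it never equals len(modules) == 3
    assert idx != len(modules)
    if idx == len(modules) - 1:
        print(f'last up module: {m}')  # fires exactly once, on 'up2'
```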


@@ -256,8 +256,8 @@ class DiffusionTransformerHybridImagePolicy(BaseImagePolicy):
         # condition through impainting
         this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
         nobs_features = self.obs_encoder(this_nobs)
-        # reshape back to B, T, Do
-        nobs_features = nobs_features.reshape(B, T, -1)
+        # reshape back to B, To, Do
+        nobs_features = nobs_features.reshape(B, To, -1)
         shape = (B, T, Da+Do)
         cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
         cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)


@@ -247,7 +247,7 @@ class DiffusionUnetHybridImagePolicy(BaseImagePolicy):
         # condition through impainting
         this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
         nobs_features = self.obs_encoder(this_nobs)
-        # reshape back to B, T, Do
+        # reshape back to B, To, Do
         nobs_features = nobs_features.reshape(B, To, -1)
         cond_data = torch.zeros(size=(B, T, Da+Do), device=device, dtype=dtype)
         cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
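A toy illustration of the shape issue behind the T -> To fixes above: the observation encoder sees `B*To` frames, so reshaping its output with `T` instead of `To` silently redistributes elements (or errors outright when the sizes do not divide). A hedged sketch with made-up dimensions:
```python
import torch

B, T, To, Do = 2, 16, 2, 8
nobs_features = torch.randn(B * To, Do)  # encoder output: one row per observation frame

ok = nobs_features.reshape(B, To, -1)
print(ok.shape)   # torch.Size([2, 2, 8]) -- correct

bad = nobs_features.reshape(B, T, -1)
print(bad.shape)  # torch.Size([2, 16, 1]) -- silently scrambles the feature dim
```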

eval.py (new file, 64 lines)

@@ -0,0 +1,64 @@
"""
Usage:
python eval.py --checkpoint data/image/pusht/diffusion_policy_cnn/train_0/checkpoints/latest.ckpt -o data/pusht_eval_output
"""
import sys
# use line-buffering for both stdout and stderr
sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)
import os
import pathlib
import click
import hydra
import torch
import dill
import wandb
import json
from diffusion_policy.workspace.base_workspace import BaseWorkspace
@click.command()
@click.option('-c', '--checkpoint', required=True)
@click.option('-o', '--output_dir', required=True)
@click.option('-d', '--device', default='cuda:0')
def main(checkpoint, output_dir, device):
if os.path.exists(output_dir):
click.confirm(f"Output path {output_dir} already exists! Overwrite?", abort=True)
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
# load checkpoint
payload = torch.load(open(checkpoint, 'rb'), pickle_module=dill)
cfg = payload['cfg']
cls = hydra.utils.get_class(cfg._target_)
workspace = cls(cfg, output_dir=output_dir)
workspace: BaseWorkspace
workspace.load_payload(payload, exclude_keys=None, include_keys=None)
# get policy from workspace
policy = workspace.model
if cfg.training.use_ema:
policy = workspace.ema_model
device = torch.device(device)
policy.to(device)
policy.eval()
# run eval
env_runner = hydra.utils.instantiate(
cfg.task.env_runner,
output_dir=output_dir)
runner_log = env_runner.run(policy)
# dump log to json
json_log = dict()
for key, value in runner_log.items():
if isinstance(value, wandb.sdk.data_types.video.Video):
json_log[key] = value._path
else:
json_log[key] = value
out_path = os.path.join(output_dir, 'eval_log.json')
json.dump(json_log, open(out_path, 'w'), indent=2, sort_keys=True)
if __name__ == '__main__':
main()
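After a run, the dumped metrics can be read back with plain `json`; a small sketch using the output path from the README section above:
```python
import json

with open('data/pusht_eval_output/eval_log.json') as f:
    log = json.load(f)
print(log['test/mean_score'])  # e.g. 0.915 for the published pusht checkpoint
```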