Compare commits
12 Commits
cchi/bug_f
...
main
Author | SHA1 | Date | |
---|---|---|---|
|
548a52bbb1 | ||
|
de4384e84a | ||
|
7dd9dc417a | ||
|
5aa9996fdc | ||
|
5c3d54fca3 | ||
|
a98e74873b | ||
|
68eef44d3e | ||
|
c52bac42ee | ||
|
749db2ce9c | ||
|
dd2cbac9fa | ||
|
0d00e02b45 | ||
|
74b6391737 |
35
README.md
35
README.md
@ -202,6 +202,41 @@ data/outputs/2023.03.01/22.13.58_train_diffusion_unet_hybrid_pusht_image
|
||||
|
||||
7 directories, 16 files
|
||||
```
|
||||
### 🆕 Evaluate Pre-trained Checkpoints
|
||||
Download a checkpoint from the published training log folders, such as [https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt](https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt).
|
||||
|
||||
Run the evaluation script:
|
||||
```console
|
||||
(robodiff)[diffusion_policy]$ python eval.py --checkpoint data/0550-test_mean_score=0.969.ckpt --output_dir data/pusht_eval_output --device cuda:0
|
||||
```
|
||||
|
||||
This will generate the following directory structure:
|
||||
```console
|
||||
(robodiff)[diffusion_policy]$ tree data/pusht_eval_output
|
||||
data/pusht_eval_output
|
||||
├── eval_log.json
|
||||
└── media
|
||||
├── 1fxtno84.mp4
|
||||
├── 224l7jqd.mp4
|
||||
├── 2fo4btlf.mp4
|
||||
├── 2in4cn7a.mp4
|
||||
├── 34b3o2qq.mp4
|
||||
└── 3p7jqn32.mp4
|
||||
|
||||
1 directory, 7 files
|
||||
```
|
||||
|
||||
`eval_log.json` contains metrics that is logged to wandb during training:
|
||||
```console
|
||||
(robodiff)[diffusion_policy]$ cat data/pusht_eval_output/eval_log.json
|
||||
{
|
||||
"test/mean_score": 0.9150393806777066,
|
||||
"test/sim_max_reward_4300000": 1.0,
|
||||
"test/sim_max_reward_4300001": 0.9872969750774386,
|
||||
...
|
||||
"train/sim_video_1": "data/pusht_eval_output//media/2fo4btlf.mp4"
|
||||
}
|
||||
```
|
||||
|
||||
## 🦾 Demo, Training and Eval on a Real Robot
|
||||
Make sure your UR5 robot is running and accepting command from its network interface (emergency stop button within reach at all time), your RealSense cameras plugged in to your workstation (tested with `realsense-viewer`) and your SpaceMouse connected with the `spacenavd` daemon running (verify with `systemctl status spacenavd`).
|
||||
|
@ -46,6 +46,8 @@ dependencies:
|
||||
- diffusers=0.11.1
|
||||
- av=10.0.0
|
||||
- cmake=3.24.3
|
||||
# trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625
|
||||
- llvm-openmp=14
|
||||
# trick to force reinstall imagecodecs via pip
|
||||
- imagecodecs==2022.8.8
|
||||
- pip:
|
||||
|
@ -46,6 +46,8 @@ dependencies:
|
||||
- diffusers=0.11.1
|
||||
- av=10.0.0
|
||||
- cmake=3.24.3
|
||||
# trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625
|
||||
- llvm-openmp=14
|
||||
# trick to force reinstall imagecodecs via pip
|
||||
- imagecodecs==2022.8.8
|
||||
- pip:
|
||||
|
120
demo_pusht.py
Normal file
120
demo_pusht.py
Normal file
@ -0,0 +1,120 @@
|
||||
import numpy as np
|
||||
import click
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer
|
||||
from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
|
||||
import pygame
|
||||
|
||||
@click.command()
|
||||
@click.option('-o', '--output', required=True)
|
||||
@click.option('-rs', '--render_size', default=96, type=int)
|
||||
@click.option('-hz', '--control_hz', default=10, type=int)
|
||||
def main(output, render_size, control_hz):
|
||||
"""
|
||||
Collect demonstration for the Push-T task.
|
||||
|
||||
Usage: python demo_pusht.py -o data/pusht_demo.zarr
|
||||
|
||||
This script is compatible with both Linux and MacOS.
|
||||
Hover mouse close to the blue circle to start.
|
||||
Push the T block into the green area.
|
||||
The episode will automatically terminate if the task is succeeded.
|
||||
Press "Q" to exit.
|
||||
Press "R" to retry.
|
||||
Hold "Space" to pause.
|
||||
"""
|
||||
|
||||
# create replay buffer in read-write mode
|
||||
replay_buffer = ReplayBuffer.create_from_path(output, mode='a')
|
||||
|
||||
# create PushT env with keypoints
|
||||
kp_kwargs = PushTKeypointsEnv.genenerate_keypoint_manager_params()
|
||||
env = PushTKeypointsEnv(render_size=render_size, render_action=False, **kp_kwargs)
|
||||
agent = env.teleop_agent()
|
||||
clock = pygame.time.Clock()
|
||||
|
||||
# episode-level while loop
|
||||
while True:
|
||||
episode = list()
|
||||
# record in seed order, starting with 0
|
||||
seed = replay_buffer.n_episodes
|
||||
print(f'starting seed {seed}')
|
||||
|
||||
# set seed for env
|
||||
env.seed(seed)
|
||||
|
||||
# reset env and get observations (including info and render for recording)
|
||||
obs = env.reset()
|
||||
info = env._get_info()
|
||||
img = env.render(mode='human')
|
||||
|
||||
# loop state
|
||||
retry = False
|
||||
pause = False
|
||||
done = False
|
||||
plan_idx = 0
|
||||
pygame.display.set_caption(f'plan_idx:{plan_idx}')
|
||||
# step-level while loop
|
||||
while not done:
|
||||
# process keypress events
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.KEYDOWN:
|
||||
if event.key == pygame.K_SPACE:
|
||||
# hold Space to pause
|
||||
plan_idx += 1
|
||||
pygame.display.set_caption(f'plan_idx:{plan_idx}')
|
||||
pause = True
|
||||
elif event.key == pygame.K_r:
|
||||
# press "R" to retry
|
||||
retry=True
|
||||
elif event.key == pygame.K_q:
|
||||
# press "Q" to exit
|
||||
exit(0)
|
||||
if event.type == pygame.KEYUP:
|
||||
if event.key == pygame.K_SPACE:
|
||||
pause = False
|
||||
|
||||
# handle control flow
|
||||
if retry:
|
||||
break
|
||||
if pause:
|
||||
continue
|
||||
|
||||
# get action from mouse
|
||||
# None if mouse is not close to the agent
|
||||
act = agent.act(obs)
|
||||
if not act is None:
|
||||
# teleop started
|
||||
# state dim 2+3
|
||||
state = np.concatenate([info['pos_agent'], info['block_pose']])
|
||||
# discard unused information such as visibility mask and agent pos
|
||||
# for compatibility
|
||||
keypoint = obs.reshape(2,-1)[0].reshape(-1,2)[:9]
|
||||
data = {
|
||||
'img': img,
|
||||
'state': np.float32(state),
|
||||
'keypoint': np.float32(keypoint),
|
||||
'action': np.float32(act),
|
||||
'n_contacts': np.float32([info['n_contacts']])
|
||||
}
|
||||
episode.append(data)
|
||||
|
||||
# step env and render
|
||||
obs, reward, done, info = env.step(act)
|
||||
img = env.render(mode='human')
|
||||
|
||||
# regulate control frequency
|
||||
clock.tick(control_hz)
|
||||
if not retry:
|
||||
# save episode buffer to replay buffer (on disk)
|
||||
data_dict = dict()
|
||||
for key in episode[0].keys():
|
||||
data_dict[key] = np.stack(
|
||||
[x[key] for x in episode])
|
||||
replay_buffer.add_episode(data_dict, compressors='disk')
|
||||
print(f'saved seed {seed}')
|
||||
else:
|
||||
print(f'retry seed {seed}')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -40,7 +40,7 @@ class RotationTransformer:
|
||||
getattr(pt, f'matrix_to_{from_rep}')
|
||||
]
|
||||
if from_convention is not None:
|
||||
funcs = [functools.partial(func, convernsion=from_convention)
|
||||
funcs = [functools.partial(func, convention=from_convention)
|
||||
for func in funcs]
|
||||
forward_funcs.append(funcs[0])
|
||||
inverse_funcs.append(funcs[1])
|
||||
@ -51,7 +51,7 @@ class RotationTransformer:
|
||||
getattr(pt, f'{to_rep}_to_matrix')
|
||||
]
|
||||
if to_convention is not None:
|
||||
funcs = [functools.partial(func, convernsion=to_convention)
|
||||
funcs = [functools.partial(func, convention=to_convention)
|
||||
for func in funcs]
|
||||
forward_funcs.append(funcs[0])
|
||||
inverse_funcs.append(funcs[1])
|
||||
|
@ -226,6 +226,10 @@ class ConditionalUnet1D(nn.Module):
|
||||
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
||||
x = torch.cat((x, h.pop()), dim=1)
|
||||
x = resnet(x, global_feature)
|
||||
# The correct condition should be:
|
||||
# if idx == (len(self.up_modules)-1) and len(h_local) > 0:
|
||||
# However this change will break compatibility with published checkpoints.
|
||||
# Therefore it is left as a comment.
|
||||
if idx == len(self.up_modules) and len(h_local) > 0:
|
||||
x = x + h_local[1]
|
||||
x = resnet2(x, global_feature)
|
||||
|
@ -256,8 +256,8 @@ class DiffusionTransformerHybridImagePolicy(BaseImagePolicy):
|
||||
# condition through impainting
|
||||
this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
|
||||
nobs_features = self.obs_encoder(this_nobs)
|
||||
# reshape back to B, T, Do
|
||||
nobs_features = nobs_features.reshape(B, T, -1)
|
||||
# reshape back to B, To, Do
|
||||
nobs_features = nobs_features.reshape(B, To, -1)
|
||||
shape = (B, T, Da+Do)
|
||||
cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
|
||||
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||
|
@ -247,7 +247,7 @@ class DiffusionUnetHybridImagePolicy(BaseImagePolicy):
|
||||
# condition through impainting
|
||||
this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
|
||||
nobs_features = self.obs_encoder(this_nobs)
|
||||
# reshape back to B, T, Do
|
||||
# reshape back to B, To, Do
|
||||
nobs_features = nobs_features.reshape(B, To, -1)
|
||||
cond_data = torch.zeros(size=(B, T, Da+Do), device=device, dtype=dtype)
|
||||
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||
|
64
eval.py
Normal file
64
eval.py
Normal file
@ -0,0 +1,64 @@
|
||||
"""
|
||||
Usage:
|
||||
python eval.py --checkpoint data/image/pusht/diffusion_policy_cnn/train_0/checkpoints/latest.ckpt -o data/pusht_eval_output
|
||||
"""
|
||||
|
||||
import sys
|
||||
# use line-buffering for both stdout and stderr
|
||||
sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
|
||||
sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import click
|
||||
import hydra
|
||||
import torch
|
||||
import dill
|
||||
import wandb
|
||||
import json
|
||||
from diffusion_policy.workspace.base_workspace import BaseWorkspace
|
||||
|
||||
@click.command()
|
||||
@click.option('-c', '--checkpoint', required=True)
|
||||
@click.option('-o', '--output_dir', required=True)
|
||||
@click.option('-d', '--device', default='cuda:0')
|
||||
def main(checkpoint, output_dir, device):
|
||||
if os.path.exists(output_dir):
|
||||
click.confirm(f"Output path {output_dir} already exists! Overwrite?", abort=True)
|
||||
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# load checkpoint
|
||||
payload = torch.load(open(checkpoint, 'rb'), pickle_module=dill)
|
||||
cfg = payload['cfg']
|
||||
cls = hydra.utils.get_class(cfg._target_)
|
||||
workspace = cls(cfg, output_dir=output_dir)
|
||||
workspace: BaseWorkspace
|
||||
workspace.load_payload(payload, exclude_keys=None, include_keys=None)
|
||||
|
||||
# get policy from workspace
|
||||
policy = workspace.model
|
||||
if cfg.training.use_ema:
|
||||
policy = workspace.ema_model
|
||||
|
||||
device = torch.device(device)
|
||||
policy.to(device)
|
||||
policy.eval()
|
||||
|
||||
# run eval
|
||||
env_runner = hydra.utils.instantiate(
|
||||
cfg.task.env_runner,
|
||||
output_dir=output_dir)
|
||||
runner_log = env_runner.run(policy)
|
||||
|
||||
# dump log to json
|
||||
json_log = dict()
|
||||
for key, value in runner_log.items():
|
||||
if isinstance(value, wandb.sdk.data_types.video.Video):
|
||||
json_log[key] = value._path
|
||||
else:
|
||||
json_log[key] = value
|
||||
out_path = os.path.join(output_dir, 'eval_log.json')
|
||||
json.dump(json_log, open(out_path, 'w'), indent=2, sort_keys=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user