Merge pull request #20 from columbia-ai-robotics/cchi/eval_script

added eval script and documentation

This commit is contained in:
commit 5c3d54fca3

README.md | +35

@@ -202,6 +202,41 @@ data/outputs/2023.03.01/22.13.58_train_diffusion_unet_hybrid_pusht_image
7 directories, 16 files
### 🆕 Evaluate Pre-trained Checkpoints

Download a checkpoint from the published training log folders, such as [https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt](https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt).
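If you prefer to script the download, here is a minimal sketch (the destination path matches the eval command below; `wget` or a browser download works just as well):

```python
# Minimal download sketch; the URL is the example checkpoint above and the
# destination matches the path used by the eval command below.
import pathlib
import urllib.request

url = (
    "https://diffusion-policy.cs.columbia.edu/data/experiments/low_dim/pusht/"
    "diffusion_policy_cnn/train_0/checkpoints/epoch=0550-test_mean_score=0.969.ckpt"
)
dest = pathlib.Path("data/0550-test_mean_score=0.969.ckpt")
dest.parent.mkdir(parents=True, exist_ok=True)
urllib.request.urlretrieve(url, dest)
```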
Run the evaluation script:

```console
(robodiff)[diffusion_policy]$ python eval.py --checkpoint data/0550-test_mean_score=0.969.ckpt --output_dir data/pusht_eval_output --device cuda:0
```

This will generate the following directory structure:

```console
(robodiff)[diffusion_policy]$ tree data/pusht_eval_output
data/pusht_eval_output
├── eval_log.json
└── media
    ├── 1fxtno84.mp4
    ├── 224l7jqd.mp4
    ├── 2fo4btlf.mp4
    ├── 2in4cn7a.mp4
    ├── 34b3o2qq.mp4
    └── 3p7jqn32.mp4

1 directory, 7 files
```

`eval_log.json` contains metrics that are logged to wandb during training:

```console
(robodiff)[diffusion_policy]$ cat data/pusht_eval_output/eval_log.json
{
  "test/mean_score": 0.9150393806777066,
  "test/sim_max_reward_4300000": 1.0,
  "test/sim_max_reward_4300001": 0.9872969750774386,
  ...
  "train/sim_video_1": "data/pusht_eval_output//media/2fo4btlf.mp4"
}
```

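Because the log is plain JSON, it is easy to post-process. A minimal sketch, assuming the key layout shown above (other tasks may log different keys):

```python
# Minimal post-processing sketch; key names follow the example eval_log.json above.
import json

with open("data/pusht_eval_output/eval_log.json") as f:
    log = json.load(f)

print("mean score:", log["test/mean_score"])

# Per-episode max rewards use keys like "test/sim_max_reward_<seed>".
episode_rewards = {k: v for k, v in log.items() if k.startswith("test/sim_max_reward_")}
print(len(episode_rewards), "test episodes")
```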
## 🦾 Demo, Training and Eval on a Real Robot

Make sure your UR5 robot is running and accepting commands from its network interface (keep the emergency stop button within reach at all times), your RealSense cameras are plugged into your workstation (test with `realsense-viewer`), and your SpaceMouse is connected with the `spacenavd` daemon running (verify with `systemctl status spacenavd`).
eval.py (new file) | +64

@@ -0,0 +1,64 @@
"""
Usage:
python eval.py --checkpoint data/image/pusht/diffusion_policy_cnn/train_0/checkpoints/latest.ckpt -o data/pusht_eval_output
"""

import sys
# use line-buffering for both stdout and stderr
sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)

import os
import pathlib
import click
import hydra
import torch
import dill
import wandb
import json
from diffusion_policy.workspace.base_workspace import BaseWorkspace

@click.command()
@click.option('-c', '--checkpoint', required=True)
@click.option('-o', '--output_dir', required=True)
@click.option('-d', '--device', default='cuda:0')
def main(checkpoint, output_dir, device):
    if os.path.exists(output_dir):
        click.confirm(f"Output path {output_dir} already exists! Overwrite?", abort=True)
    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

    # load checkpoint
    payload = torch.load(open(checkpoint, 'rb'), pickle_module=dill)
    cfg = payload['cfg']
    cls = hydra.utils.get_class(cfg._target_)
    workspace = cls(cfg, output_dir=output_dir)
    workspace: BaseWorkspace
    workspace.load_payload(payload, exclude_keys=None, include_keys=None)

    # get policy from workspace
    policy = workspace.model
    if cfg.training.use_ema:
        policy = workspace.ema_model

    device = torch.device(device)
    policy.to(device)
    policy.eval()

    # run eval
    env_runner = hydra.utils.instantiate(
        cfg.task.env_runner,
        output_dir=output_dir)
    runner_log = env_runner.run(policy)

    # dump log to json
    json_log = dict()
    for key, value in runner_log.items():
        if isinstance(value, wandb.sdk.data_types.video.Video):
            # wandb Video objects are not JSON-serializable; record the file path instead
            json_log[key] = value._path
        else:
            json_log[key] = value
    out_path = os.path.join(output_dir, 'eval_log.json')
    json.dump(json_log, open(out_path, 'w'), indent=2, sort_keys=True)

if __name__ == '__main__':
    main()
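
To sanity-check a checkpoint before launching full environment rollouts, you can load just its config. A minimal sketch that reuses the same `torch.load`/`dill` convention as `eval.py` above (the printed fields are the ones `eval.py` itself reads):

```python
# Minimal checkpoint-inspection sketch, reusing the payload layout eval.py relies on.
import dill
import torch

with open("data/0550-test_mean_score=0.969.ckpt", "rb") as f:
    payload = torch.load(f, pickle_module=dill)

cfg = payload["cfg"]
print(cfg._target_)          # workspace class eval.py will instantiate
print(cfg.training.use_ema)  # whether eval.py will use the EMA weights
```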