"""
Start local ray cluster
(robodiff)$ export CUDA_VISIBLE_DEVICES=0,1,2 # select GPUs to be managed by the ray cluster
(robodiff)$ ray start --head --num-gpus=3

Training:
python ray_train_multirun.py --config-name=train_diffusion_unet_lowdim_workspace --seeds=42,43,44 --monitor_key=test/mean_score -- logging.mode=online training.eval_first=True
"""
|
||
|
import os
import ray
import click
import hydra
import yaml
import wandb
import pathlib
import collections
from pprint import pprint
from omegaconf import OmegaConf
from ray_exec import worker_fn
from ray.util.placement_group import (
    placement_group,
)
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

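# allow ${eval:...} expressions in the composed Hydra/OmegaConf config to be
# evaluated as Python when the config is resolved below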
OmegaConf.register_new_resolver("eval", eval, replace=True)

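# CLI: launch one train.py run per seed as a Ray task, plus a metrics-monitor
# task that aggregates the per-seed results under a shared multi_run output dir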
@click.command()
@click.option('--config-name', '-cn', required=True, type=str)
@click.option('--config-dir', '-cd', default=None, type=str)
@click.option('--seeds', '-s', default='42,43,44', type=str)
@click.option('--monitor_key', '-k', multiple=True, default=['test/mean_score'])
@click.option('--ray_address', '-ra', default='auto')
@click.option('--num_cpus', '-nc', default=7, type=float)
@click.option('--num_gpus', '-ng', default=1, type=float)
@click.option('--max_retries', '-mr', default=0, type=int)
@click.option('--monitor_max_retries', default=3, type=int)
@click.option('--data_src', '-d', default='./data', type=str)
@click.option('--unbuffer_python', '-u', is_flag=True, default=False)
@click.option('--single_node', '-sn', is_flag=True, default=False, help='run all experiments on a single machine')
@click.argument('command_args', nargs=-1, type=str)
def main(config_name, config_dir, seeds, monitor_key, ray_address,
        num_cpus, num_gpus, max_retries, monitor_max_retries,
        data_src, unbuffer_python,
        single_node, command_args):
    # parse args
    seeds = [int(x) for x in seeds.split(',')]
    # expand path
    if data_src is not None:
        data_src = os.path.abspath(os.path.expanduser(data_src))

    # initialize hydra
    if config_dir is None:
        config_path_abs = pathlib.Path(__file__).parent.joinpath(
            'diffusion_policy','config')
        config_path_rel = str(config_path_abs.relative_to(pathlib.Path.cwd()))
    else:
        config_path_rel = config_dir

    run_command_args = list()
    monitor_command_args = list()
    with hydra.initialize(
        version_base=None,
        config_path=config_path_rel):

        # generate raw config
        cfg = hydra.compose(
            config_name=config_name,
            overrides=command_args)
        OmegaConf.resolve(cfg)

        # manually create output dir
        output_dir = pathlib.Path(cfg.multi_run.run_dir)
        output_dir.mkdir(parents=True, exist_ok=False)
        config_path = output_dir.joinpath('config.yaml')
        print(output_dir)

        # save current config
        yaml.dump(OmegaConf.to_container(cfg, resolve=True),
            config_path.open('w'), default_flow_style=False)

        # wandb
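        # every training run and the metrics task share one wandb group id so
        # they appear grouped together in the wandb UI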
        wandb_group_id = wandb.util.generate_id()
        name_base = cfg.multi_run.wandb_name_base

        # create monitor command args
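        # multirun_metrics.py reads the per-seed runs under output_dir and logs
        # the requested --key metrics to wandb under the shared group id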
        monitor_command_args = [
            'python',
            'multirun_metrics.py',
            '--input', str(output_dir),
            '--use_wandb',
            '--project', 'diffusion_policy_metrics',
            '--group', wandb_group_id
        ]
        for k in monitor_key:
            monitor_command_args.extend([
                '--key', k
            ])

        # generate command args
        run_command_args = list()
        for i, seed in enumerate(seeds):
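            # space evaluation start seeds far apart so runs trained with
            # different seeds do not share environment rollout seeds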
            test_start_seed = (seed + 1) * 100000
            this_output_dir = output_dir.joinpath(f'train_{i}')
            this_output_dir.mkdir()
            wandb_name = name_base + f'_train_{i}'
            wandb_run_id = wandb_group_id + f'_train_{i}'

            this_command_args = [
                'python',
                'train.py',
                '--config-name='+config_name,
                '--config-dir='+config_path_rel
            ]
            this_command_args.extend(command_args)
            this_command_args.extend([
                f'training.seed={seed}',
                f'task.env_runner.test_start_seed={test_start_seed}',
                f'logging.name={wandb_name}',
                f'logging.id={wandb_run_id}',
                f'logging.group={wandb_group_id}',
                f'hydra.run.dir={this_output_dir}'
            ])
            run_command_args.append(this_command_args)

    # init ray
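    # the runtime_env ships this script's directory (minus .git) to every Ray
    # worker and installs a pinned dm-control in each worker's environment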
    root_dir = os.path.dirname(__file__)
    runtime_env = {
        'working_dir': root_dir,
        'excludes': ['.git'],
        'pip': ['dm-control==1.0.9']
    }
    ray.init(
        address=ray_address,
        runtime_env=runtime_env
    )

    # create resources for train
    train_resources = dict()

    train_bundle = dict(train_resources)
    train_bundle['CPU'] = num_cpus
    train_bundle['GPU'] = num_gpus

    # create resources for monitor
    monitor_resources = dict()
    monitor_resources['CPU'] = 1

    monitor_bundle = dict(monitor_resources)

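    # sum the per-task requests into a single placement-group bundle: monitor
    # plus one training task normally, or monitor plus all of them with --single_node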
    # aggregate bundle
    bundle = collections.defaultdict(lambda:0)
    n_train_bundles = 1
    if single_node:
        n_train_bundles = len(seeds)
    for _ in range(n_train_bundles):
        for k, v in train_bundle.items():
            bundle[k] += v
    for k, v in monitor_bundle.items():
        bundle[k] += v
    bundle = dict(bundle)

    # create placement group
    print("Creating placement group with resources:")
    pprint(bundle)
    pg = placement_group([bundle])

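    # launch training tasks: the last one (or every one with --single_node) is
    # scheduled inside the placement group, the rest can run anywhere on the cluster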
    # run
    task_name_map = dict()
    task_refs = list()
    for i, this_command_args in enumerate(run_command_args):
        if single_node or i == (len(run_command_args) - 1):
            print(f'Training worker {i} with placement group.')
            ray.get(pg.ready())
            print("Placement Group created!")
            worker_ray = ray.remote(worker_fn).options(
                num_cpus=num_cpus,
                num_gpus=num_gpus,
                max_retries=max_retries,
                resources=train_resources,
                retry_exceptions=True,
                scheduling_strategy=PlacementGroupSchedulingStrategy(
                    placement_group=pg)
            )
        else:
            print(f'Training worker {i} without placement group.')
            worker_ray = ray.remote(worker_fn).options(
                num_cpus=num_cpus,
                num_gpus=num_gpus,
                max_retries=max_retries,
                resources=train_resources,
                retry_exceptions=True,
            )
        task_ref = worker_ray.remote(
            this_command_args, data_src, unbuffer_python)
        task_refs.append(task_ref)
        task_name_map[task_ref] = f'train_{i}'

    # the monitor worker is always packed onto the same node as the
    # placement-group training worker(s)
    ray.get(pg.ready())
    monitor_worker_ray = ray.remote(worker_fn).options(
        num_cpus=1,
        num_gpus=0,
        max_retries=monitor_max_retries,
        # resources=monitor_resources,
        retry_exceptions=True,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg)
    )
    monitor_ref = monitor_worker_ray.remote(
        monitor_command_args, data_src, unbuffer_python)
    task_name_map[monitor_ref] = 'metrics'

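    # wait for training tasks one at a time; if any task fails for a reason other
    # than being cancelled, cancel the remaining tasks, then stop the monitor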
    try:
        # normal case
        ready_refs = list()
        rest_refs = task_refs
        while len(ready_refs) < len(task_refs):
            this_ready_refs, rest_refs = ray.wait(rest_refs,
                num_returns=1, timeout=None, fetch_local=True)
            cancel_other_tasks = False
            for ref in this_ready_refs:
                task_name = task_name_map[ref]
                try:
                    result = ray.get(ref)
                    print(f"Task {task_name} finished with result: {result}")
                except KeyboardInterrupt as e:
                    # skip to outer try/except
                    raise KeyboardInterrupt
                except Exception as e:
                    print(f"Task {task_name} raised exception: {e}")
                    this_cancel_other_tasks = True
                    if isinstance(e, ray.exceptions.RayTaskError):
                        if isinstance(e.cause, ray.exceptions.TaskCancelledError):
                            this_cancel_other_tasks = False
                    cancel_other_tasks = cancel_other_tasks or this_cancel_other_tasks
                ready_refs.append(ref)
            if cancel_other_tasks:
                print('Exception! Cancelling all other tasks.')
                # cancel all other refs
                for _ref in rest_refs:
                    ray.cancel(_ref, force=False)
        print("Training tasks done.")
        ray.cancel(monitor_ref, force=False)
    except KeyboardInterrupt:
        print('KeyboardInterrupt received in the driver.')
        # a KeyboardInterrupt will be raised in the workers
        _ = [ray.cancel(x, force=False) for x in task_refs + [monitor_ref]]
        print('KeyboardInterrupt sent to workers.')
    except Exception as e:
        # the workers will be terminated
        _ = [ray.cancel(x, force=True) for x in task_refs + [monitor_ref]]
        raise e

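    # final pass: block until every task, including the monitor, has exited and
    # report its result or exception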
    for ref in task_refs + [monitor_ref]:
        task_name = task_name_map[ref]
        try:
            result = ray.get(ref)
            print(f"Task {task_name} finished with result: {result}")
        except KeyboardInterrupt as e:
            # force kill everything.
            print("Force killing all workers")
            _ = [ray.cancel(x, force=True) for x in task_refs]
            ray.cancel(monitor_ref, force=True)
        except Exception as e:
            print(f"Task {task_name} raised exception: {e}")


if __name__ == "__main__":
    main()