""" Start local ray cluster (robodiff)$ export CUDA_VISIBLE_DEVICES=0,1,2 # select GPUs to be managed by the ray cluster (robodiff)$ ray start --head --num-gpus=3 Training: python ray_train_multirun.py --config-name=train_diffusion_unet_lowdim_workspace --seeds=42,43,44 --monitor_key=test/mean_score -- logger.mode=online training.eval_first=True """ import os import ray import click import hydra import yaml import wandb import pathlib import collections from pprint import pprint from omegaconf import OmegaConf from ray_exec import worker_fn from ray.util.placement_group import ( placement_group, ) from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy OmegaConf.register_new_resolver("eval", eval, replace=True) @click.command() @click.option('--config-name', '-cn', required=True, type=str) @click.option('--config-dir', '-cd', default=None, type=str) @click.option('--seeds', '-s', default='42,43,44', type=str) @click.option('--monitor_key', '-k', multiple=True, default=['test/mean_score']) @click.option('--ray_address', '-ra', default='auto') @click.option('--num_cpus', '-nc', default=7, type=float) @click.option('--num_gpus', '-ng', default=1, type=float) @click.option('--max_retries', '-mr', default=0, type=int) @click.option('--monitor_max_retires', default=3, type=int) @click.option('--data_src', '-d', default='./data', type=str) @click.option('--unbuffer_python', '-u', is_flag=True, default=False) @click.option('--single_node', '-sn', is_flag=True, default=False, help='run all experiments on a single machine') @click.argument('command_args', nargs=-1, type=str) def main(config_name, config_dir, seeds, monitor_key, ray_address, num_cpus, num_gpus, max_retries, monitor_max_retires, data_src, unbuffer_python, single_node, command_args): # parse args seeds = [int(x) for x in seeds.split(',')] # expand path if data_src is not None: data_src = os.path.abspath(os.path.expanduser(data_src)) # initialize hydra if config_dir is None: config_path_abs = pathlib.Path(__file__).parent.joinpath( 'diffusion_policy','config') config_path_rel = str(config_path_abs.relative_to(pathlib.Path.cwd())) else: config_path_rel = config_dir run_command_args = list() monitor_command_args = list() with hydra.initialize( version_base=None, config_path=config_path_rel): # generate raw config cfg = hydra.compose( config_name=config_name, overrides=command_args) OmegaConf.resolve(cfg) # manually create output dir output_dir = pathlib.Path(cfg.multi_run.run_dir) output_dir.mkdir(parents=True, exist_ok=False) config_path = output_dir.joinpath('config.yaml') print(output_dir) # save current config yaml.dump(OmegaConf.to_container(cfg, resolve=True), config_path.open('w'), default_flow_style=False) # wandb wandb_group_id = wandb.util.generate_id() name_base = cfg.multi_run.wandb_name_base # create monitor command args monitor_command_args = [ 'python', 'multirun_metrics.py', '--input', str(output_dir), '--use_wandb', '--project', 'diffusion_policy_metrics', '--group', wandb_group_id ] for k in monitor_key: monitor_command_args.extend([ '--key', k ]) # generate command args run_command_args = list() for i, seed in enumerate(seeds): test_start_seed = (seed + 1) * 100000 this_output_dir = output_dir.joinpath(f'train_{i}') this_output_dir.mkdir() wandb_name = name_base + f'_train_{i}' wandb_run_id = wandb_group_id + f'_train_{i}' this_command_args = [ 'python', 'train.py', '--config-name='+config_name, '--config-dir='+config_path_rel ] this_command_args.extend(command_args) this_command_args.extend([ 
            this_command_args.extend([
                f'training.seed={seed}',
                f'task.env_runner.test_start_seed={test_start_seed}',
                f'logging.name={wandb_name}',
                f'logging.id={wandb_run_id}',
                f'logging.group={wandb_group_id}',
                f'hydra.run.dir={this_output_dir}'
            ])
            run_command_args.append(this_command_args)

    # init ray
    root_dir = os.path.dirname(__file__)
    runtime_env = {
        'working_dir': root_dir,
        'excludes': ['.git'],
        'pip': ['dm-control==1.0.9']
    }
    ray.init(
        address=ray_address,
        runtime_env=runtime_env
    )

    # create resources for train
    train_resources = dict()
    train_bundle = dict(train_resources)
    train_bundle['CPU'] = num_cpus
    train_bundle['GPU'] = num_gpus

    # create resources for monitor
    monitor_resources = dict()
    monitor_resources['CPU'] = 1
    monitor_bundle = dict(monitor_resources)

    # aggregate bundle
    bundle = collections.defaultdict(lambda: 0)
    n_train_bundles = 1
    if single_node:
        n_train_bundles = len(seeds)
    for _ in range(n_train_bundles):
        for k, v in train_bundle.items():
            bundle[k] += v
    for k, v in monitor_bundle.items():
        bundle[k] += v
    bundle = dict(bundle)

    # create placement group
    print("Creating placement group with resources:")
    pprint(bundle)
    pg = placement_group([bundle])

    # run
    task_name_map = dict()
    task_refs = list()
    for i, this_command_args in enumerate(run_command_args):
        if single_node or i == (len(run_command_args) - 1):
            print(f'Training worker {i} with placement group.')
            ray.get(pg.ready())
            print("Placement Group created!")
            worker_ray = ray.remote(worker_fn).options(
                num_cpus=num_cpus,
                num_gpus=num_gpus,
                max_retries=max_retries,
                resources=train_resources,
                retry_exceptions=True,
                scheduling_strategy=PlacementGroupSchedulingStrategy(
                    placement_group=pg)
            )
        else:
            print(f'Training worker {i} without placement group.')
            worker_ray = ray.remote(worker_fn).options(
                num_cpus=num_cpus,
                num_gpus=num_gpus,
                max_retries=max_retries,
                resources=train_resources,
                retry_exceptions=True,
            )
        task_ref = worker_ray.remote(
            this_command_args, data_src, unbuffer_python)
        task_refs.append(task_ref)
        task_name_map[task_ref] = f'train_{i}'

    # the monitor worker is scheduled into the same placement group as the
    # placement-group training worker, so both always land on the same node
    ray.get(pg.ready())
    monitor_worker_ray = ray.remote(worker_fn).options(
        num_cpus=1,
        num_gpus=0,
        max_retries=monitor_max_retries,
        # resources=monitor_resources,
        retry_exceptions=True,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg)
    )
    monitor_ref = monitor_worker_ray.remote(
        monitor_command_args, data_src, unbuffer_python)
    task_name_map[monitor_ref] = 'metrics'

    try:
        # normal case
        ready_refs = list()
        rest_refs = task_refs
        while len(ready_refs) < len(task_refs):
            this_ready_refs, rest_refs = ray.wait(rest_refs,
                num_returns=1, timeout=None, fetch_local=True)
            cancel_other_tasks = False
            for ref in this_ready_refs:
                task_name = task_name_map[ref]
                try:
                    result = ray.get(ref)
                    print(f"Task {task_name} finished with result: {result}")
                except KeyboardInterrupt as e:
                    # skip to outer try catch
                    raise KeyboardInterrupt
                except Exception as e:
                    print(f"Task {task_name} raised exception: {e}")
                    this_cancel_other_tasks = True
                    if isinstance(e, ray.exceptions.RayTaskError):
                        if isinstance(e.cause, ray.exceptions.TaskCancelledError):
                            this_cancel_other_tasks = False
                    cancel_other_tasks = cancel_other_tasks or this_cancel_other_tasks
                ready_refs.append(ref)
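            # If any worker failed with a real error (i.e. not a cancellation),
            # cancel the runs that are still in flight.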
            if cancel_other_tasks:
                print('Exception! Cancelling all other tasks.')
                # cancel all other refs
                for _ref in rest_refs:
                    ray.cancel(_ref, force=False)
        print("Training tasks done.")
        ray.cancel(monitor_ref, force=False)
    except KeyboardInterrupt:
        print('KeyboardInterrupt received in the driver.')
        # a KeyboardInterrupt will be raised in worker
        _ = [ray.cancel(x, force=False) for x in task_refs + [monitor_ref]]
        print('KeyboardInterrupt sent to workers.')
    except Exception as e:
        # worker will be terminated
        _ = [ray.cancel(x, force=True) for x in task_refs + [monitor_ref]]
        raise e

    for ref in task_refs + [monitor_ref]:
        task_name = task_name_map[ref]
        try:
            result = ray.get(ref)
            print(f"Task {task_name} finished with result: {result}")
        except KeyboardInterrupt as e:
            # force kill everything.
            print("Force killing all workers")
            _ = [ray.cancel(x, force=True) for x in task_refs]
            ray.cancel(monitor_ref, force=True)
        except Exception as e:
            print(f"Task {task_name} raised exception: {e}")

if __name__ == "__main__":
    main()