2023-03-07 16:07:15 -05:00
|
|
|
"""
|
|
|
|
Usage:
|
|
|
|
Training:
|
|
|
|
python train.py --config-name=train_diffusion_lowdim_workspace -- logger.mode=online
|
|
|
|
"""
|
|
|
|
import os
|
|
|
|
import ray
|
|
|
|
import click
|
|
|
|
|
|
|
|
def worker_fn(command_args, data_src=None, unbuffer_python=False, use_shell=False):
    """Run a command as a subprocess and block until it exits.

    Written to run inside a Ray worker: a KeyboardInterrupt raised in this
    function is forwarded to the child process as SIGINT instead of aborting
    the wait loop.

    Args:
        command_args: Sequence of strings forming the command line.
        data_src: Optional path. When given, a symlink named ``data`` that
            points at it is created in the current working directory
            (an already existing ``data`` entry is left untouched).
        unbuffer_python: When true, ``PYTHONUNBUFFERED`` is set in the child's
            environment so a Python child does not buffer stdout/stderr.
        use_shell: When true, ``command_args`` is joined with single spaces
            (no quoting is applied) and executed through the shell.

    Returns:
        The child's return code: ``0``, or ``-signal.SIGINT`` when the child
        was terminated by SIGINT.

    Raises:
        RuntimeError: If the child exits with any other return code.
    """
    # Imports are kept local so the function is self-contained.
    import os
    import subprocess
    import signal
    import time

    # setup data symlink
    if data_src is not None:
        cwd = os.getcwd()
        src = os.path.expanduser(data_src)
        dst = os.path.join(cwd, 'data')
        try:
            os.symlink(src=src, dst=dst)
        except FileExistsError:
            # it's fine if it already exists
            pass

    # run command
    process_env = os.environ.copy()
    if unbuffer_python:
        # disable stdout/stderr buffering for subprocess (if python)
        # to remove latency between print statement and receiving printed result
        process_env['PYTHONUNBUFFERED'] = 'TRUE'

    # ray worker masks out Ctrl-C signal (ie SIGINT)
    # here we unblock this signal for the child process
    def preexec_function():
        import signal
        signal.pthread_sigmask(signal.SIG_UNBLOCK, {signal.SIGINT})

    if use_shell:
        command_args = ' '.join(command_args)

    # stdout passthrough to ray worker, which is then passed to ray driver
    process = subprocess.Popen(
        args=command_args,
        env=process_env,
        preexec_fn=preexec_function,
        shell=use_shell)

    while process.poll() is None:
        try:
            # sleep to ensure that monitor thread can acquire gil
            # and raise KeyboardInterrupt here.
            time.sleep(0.01)
        except KeyboardInterrupt:
            # forward the interrupt and keep waiting for the child to exit
            process.send_signal(signal.SIGINT)
            print('SIGINT sent to subprocess')
        except Exception:
            process.terminate()
            # bare raise keeps the original traceback intact
            raise

    # A negative return code -N means the child was killed by signal N;
    # being stopped by the forwarded SIGINT is treated as a clean exit.
    if process.returncode not in (0, -signal.SIGINT):
        print("Failed execution!")
        raise RuntimeError(
            "Failed execution. Return code: {}".format(process.returncode))
    return process.returncode
|
|
|
|
|
|
|
|
|
|
|
|
@click.command()
@click.option('--ray_address', '-ra', default='auto')
@click.option('--num_cpus', '-nc', default=7, type=float)
@click.option('--num_gpus', '-ng', default=1, type=float)
@click.option('--max_retries', '-mr', default=0, type=int)
@click.option('--data_src', '-d', default='./data', type=str)
@click.option('--unbuffer_python', '-u', is_flag=True, default=False)
@click.argument('command_args', nargs=-1, type=str)
def main(ray_address,
         num_cpus, num_gpus, max_retries,
         data_src, unbuffer_python,
         command_args):
    """Run COMMAND_ARGS as a subprocess inside a Ray remote task.

    Connects to the Ray cluster at ``ray_address``, uploads the directory
    containing this script as the task's working directory (excluding
    ``.git``), submits ``worker_fn`` with the requested CPU/GPU resources and
    retry count, waits for it, and prints the command's return code.

    Ctrl-C requests a non-forced cancel of the task and then waits for its
    result; any other exception force-cancels the task and is re-raised.
    """
    # An empty command can never succeed; fail here rather than inside
    # a remote worker after a task has already been scheduled.
    if not command_args:
        raise click.UsageError('No command given to execute.')

    # expand path
    if data_src is not None:
        data_src = os.path.abspath(os.path.expanduser(data_src))

    # init ray
    # Make the script path absolute first: dirname() of a bare relative
    # file name is '', which is not a usable working_dir.
    root_dir = os.path.dirname(os.path.abspath(__file__))
    runtime_env = {
        'working_dir': root_dir,
        'excludes': ['.git']
    }
    ray.init(
        address=ray_address,
        runtime_env=runtime_env
    )
    # remote worker func
    worker_ray = ray.remote(worker_fn).options(
        num_cpus=num_cpus,
        num_gpus=num_gpus,
        max_retries=max_retries,
        # resources=resources,
        retry_exceptions=True
    )
    # run
    task_ref = worker_ray.remote(command_args, data_src, unbuffer_python)

    try:
        # normal case
        result = ray.get(task_ref)
        print('Return code: ', result)
    except KeyboardInterrupt:
        # a KeyboardInterrupt will be raised in worker
        ray.cancel(task_ref, force=False)
        result = ray.get(task_ref)
        print('Return code: ', result)
    except Exception:
        # worker will be terminated
        ray.cancel(task_ref, force=True)
        # bare raise keeps the original traceback intact
        raise
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run the click command, which parses the
    # command line itself.
    main()
|