170 lines
4.8 KiB
Python
170 lines
4.8 KiB
Python
|
import datetime
|
||
|
import os
|
||
|
import wandb
|
||
|
import yaml
|
||
|
from functools import wraps
|
||
|
|
||
|
|
||
|
def save_args(exp_dir, kwargs, filename='experiment_args.yml'):
    """Write the YAML-serializable subset of ``kwargs`` to ``exp_dir/filename``.

    A value is kept when its exact type (subclasses excluded, as before) is
    ``int``, ``float``, ``bool``, ``str`` or ``NoneType``, or when it is a
    plain ``tuple`` that ``yaml.safe_dump`` can represent. Everything else
    (models, devices, lists, dicts, tuples holding arbitrary objects, ...)
    is skipped silently, so the final dump never fails on an unsupported
    value. An existing file is overwritten.

    Args:
        exp_dir: Existing directory the file is written into.
        kwargs: Mapping of argument names to values. An object without an
            ``items`` method (e.g. an ``argparse.Namespace``) is read through
            ``vars()`` instead.
        filename: Name of the YAML file to write.
    """
    items = kwargs.items() if hasattr(kwargs, 'items') else vars(kwargs).items()
    filtered = {key: value for key, value in items
                if type(value) in _YAML_SCALAR_TYPES
                or (type(value) is tuple and _is_safe_dumpable(value))}
    with open(os.path.join(exp_dir, filename), 'w') as f:
        yaml.safe_dump(filtered, f)


# Scalar types kept by save_args. Matching on the exact type keeps out
# subclasses (e.g. numpy scalars) that safe_dump may refuse to represent.
_YAML_SCALAR_TYPES = (int, float, bool, str, type(None))


def _is_safe_dumpable(value):
    """Return True if ``yaml.safe_dump`` can serialize ``value``."""
    try:
        yaml.safe_dump(value)
    except yaml.YAMLError:
        # e.g. a RepresenterError for a tuple containing a tensor or module;
        # such a tuple previously made the whole save_args call fail.
        return False
    return True
|
||
|
|
||
|
|
||
|
def save_module_args(exp_dir, args, filename='module_args.yml'):
    """Persist module/model arguments.

    A thin alias of ``save_args`` whose only difference is the default
    output filename (``module_args.yml``).
    """
    save_args(exp_dir, args, filename)
|
||
|
|
||
|
|
||
|
def load_args(exp_dir, filename='experiment_args.yml'):
    """Read back a YAML argument file from ``exp_dir``.

    Returns whatever ``yaml.safe_load`` produces for ``exp_dir/filename``
    (normally a dict; None for an empty file). ``open`` raises
    FileNotFoundError if the file is missing.
    """
    path = os.path.join(exp_dir, filename)
    with open(path, 'r') as stream:
        return yaml.safe_load(stream)
|
||
|
|
||
|
|
||
|
def load_module_args(exp_dir, filename='module.yml'):
    """Load module/model arguments previously saved for an experiment.

    ``save_module_args`` writes to ``module_args.yml`` by default, while this
    loader has always defaulted to ``module.yml``. As a result, a save/load
    round-trip using the defaults raised FileNotFoundError. When
    ``module.yml`` is requested but missing and ``module_args.yml`` exists,
    the latter is loaded instead. Any other filename is used as given.

    Returns:
        The parsed YAML content (normally a dict).

    Raises:
        FileNotFoundError: if the requested file (and, for the default,
            the fallback) does not exist.
    """
    if filename == 'module.yml' and not os.path.exists(os.path.join(exp_dir, filename)):
        fallback = 'module_args.yml'
        if os.path.exists(os.path.join(exp_dir, fallback)):
            filename = fallback
    return load_args(exp_dir, filename=filename)
|
||
|
|
||
|
|
||
|
def update_args(exp_dir, partial_args, filename='experiment_args.yml'):
    """Merge ``partial_args`` into a saved argument file and write it back.

    Args:
        exp_dir: Experiment directory containing the YAML file.
        partial_args: Mapping of keys to add or overwrite.
        filename: Argument file to update. Defaults to the same file that
            ``save_args``/``load_args`` use by default.

    The file is rewritten through ``save_args``, so new values it does not
    keep (e.g. lists or dicts) are dropped.
    """
    # An empty YAML file loads as None; treat it as "no arguments yet"
    # instead of failing on item assignment.
    stored = load_args(exp_dir, filename=filename) or {}
    # save_args keeps tuples but drops lists, and YAML reads tuples back as
    # lists. Convert them back so an update does not erase existing entries.
    args = {key: tuple(value) if type(value) is list else value
            for key, value in stored.items()}
    args.update(partial_args)
    save_args(exp_dir, args, filename=filename)
|
||
|
|
||
|
|
||
|
def evaluation(eval_func):
    """Decorator that loads an experiment's saved arguments before evaluating.

    The wrapped function is called as ``eval_func(experiment_args, **kwargs)``.
    ``experiment_args`` is read with ``load_args`` from ``kwargs["exp_dir"]``.
    The wrapper accepts keyword arguments only and requires an ``exp_dir``
    keyword (KeyError otherwise). It returns whatever ``eval_func`` returns.
    """
    @wraps(eval_func)
    def wrapper(**kwargs):
        experiment_args = load_args(kwargs["exp_dir"])
        # Run the evaluation and propagate its result (previously discarded).
        return eval_func(experiment_args, **kwargs)

    return wrapper
|
||
|
|
||
|
|
||
|
def filter_kwargs(kwargs, blacklist=('device', 'exp_dir')):
    """Return a copy of ``kwargs`` without runtime-only entries.

    Drops every key contained in ``blacklist`` and every key ending in
    ``'_field'``. The input mapping is not modified.

    Args:
        kwargs: Mapping of keyword arguments (string keys).
        blacklist: Keys to exclude. The default is an immutable tuple rather
            than a list, so it cannot be mutated and shared between calls.

    Returns:
        A new dict with the remaining items, in their original order.
    """
    return {key: value for key, value in kwargs.items()
            if key not in blacklist and not key.endswith('_field')}
|
||
|
|
||
|
|
||
|
def pretrain_helper(model_load_function):
    """Decorator that records a model's constructor kwargs before building it.

    The wrapper accepts keyword arguments only and:

    1. Filters ``kwargs`` with ``filter_kwargs`` (by default dropping
       ``device``, ``exp_dir`` and ``*_field`` entries).
    2. If a ``submodules`` mapping is given:
       - injects each submodule into ``kwargs`` under its name;
       - stores each submodule's filtered ``_all_kwargs`` under
         ``model_kwargs['submodules']``.
       Every submodule must carry ``_all_kwargs`` (e.g. was itself built
       through this decorator); otherwise AttributeError is raised.
    3. Saves the filtered kwargs with ``save_module_args`` into
       ``kwargs['exp_dir']`` (default file ``module_args.yml``). A missing
       ``exp_dir`` raises KeyError.
    4. Builds the model and attaches the full ``kwargs`` (including the
       injected submodules) to it as ``model._all_kwargs``.

    NOTE(review): ``save_args`` only keeps scalar and tuple values, so the
    nested ``submodules`` dict is not written to the file. Confirm whether
    that is intended.
    """
    @wraps(model_load_function)
    def wrapper(**kwargs):
        model_kwargs = filter_kwargs(kwargs)

        # Inject submodules (if any) as individual keyword arguments and
        # record the kwargs they were built with.
        if "submodules" in kwargs:
            submodule_kwargs = {}
            for module_name, submodule in kwargs["submodules"].items():
                kwargs[module_name] = submodule
                submodule_kwargs[module_name] = filter_kwargs(submodule._all_kwargs)
            model_kwargs['submodules'] = submodule_kwargs

        save_module_args(kwargs['exp_dir'], model_kwargs)

        model = model_load_function(**kwargs)
        # Kept so this model can itself be passed as a submodule later on.
        model._all_kwargs = kwargs
        return model

    return wrapper
|
||
|
|
||
|
|
||
|
def model_loader(model_load_function):
    """Decorator that wires pre-built submodules into a model constructor.

    Each entry of the optional ``submodules`` keyword mapping is also passed
    to ``model_load_function`` as an individual keyword argument
    (name -> module). The returned model gets a ``submodules`` attribute
    holding that mapping, or an empty dict when none was given.
    """
    @wraps(model_load_function)
    def wrapper(**kwargs):
        has_submodules = "submodules" in kwargs
        if has_submodules:
            kwargs.update(kwargs["submodules"].items())

        model = model_load_function(**kwargs)

        # Keep a handle on the submodules (e.g. for saving later).
        model.submodules = kwargs["submodules"] if has_submodules else {}
        return model

    return wrapper
|
||
|
"""
|
||
|
|
||
|
def single_experiment(exp_func):
|
||
|
@wraps(exp_func)
|
||
|
def wrapper(*args, **kwargs):
|
||
|
# Make results directory
|
||
|
assert 'results_dir' in kwargs and 'seed' in kwargs, "results_dir and seed must be arguments"
|
||
|
results_dir = os.path.join(kwargs['results_dir'], str(kwargs['seed']))
|
||
|
os.makedirs(results_dir, exist_ok=True)
|
||
|
kwargs['results_dir'] = results_dir
|
||
|
|
||
|
# Save arguments
|
||
|
save_args(results_dir, kwargs)
|
||
|
|
||
|
# Fix seed
|
||
|
fix_random_seed(kwargs['seed'])
|
||
|
|
||
|
# WandB
|
||
|
if kwargs['wandb_silent']:
|
||
|
os.environ["WANDB_SILENT"] = "true"
|
||
|
|
||
|
init = {"project": kwargs['project'],
|
||
|
"reinit": True,
|
||
|
"entity": kwargs['entity'],
|
||
|
"notes": datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
|
||
|
"config": kwargs}
|
||
|
|
||
|
if 'group' in kwargs and kwargs["group"] is not None:
|
||
|
init["group"] = kwargs['group']
|
||
|
if 'tags' in kwargs and kwargs["tags"] is not None:
|
||
|
init["tags"] = kwargs['tags']
|
||
|
if 'run_name' in kwargs and kwargs["run_name"] is not None:
|
||
|
init["name"] = kwargs['run_name']
|
||
|
|
||
|
run = wandb.init(**init)
|
||
|
|
||
|
# Run the experiment
|
||
|
exp_func(*args, **kwargs)
|
||
|
|
||
|
run.finish()
|
||
|
|
||
|
return wrapper
|
||
|
|
||
|
|
||
|
def experiment(exp_func):
|
||
|
@wraps(exp_func)
|
||
|
def wrapper(opt):
|
||
|
# Make results directory
|
||
|
root_dir = opt.saving_root
|
||
|
opt.exp_dir = str(os.path.join(root_dir, opt.experiment_name, str(opt.seed)))
|
||
|
|
||
|
exists = os.path.exists(os.path.join(opt.exp_dir, 'checkpoints')) or \
|
||
|
os.path.exists(os.path.join(opt.exp_dir, 'summaries'))
|
||
|
|
||
|
if exists:
|
||
|
exit('Experiment already exists.')
|
||
|
os.makedirs(opt.exp_dir, exist_ok=True)
|
||
|
|
||
|
# Save arguments
|
||
|
save_args(opt.exp_dir, opt)
|
||
|
|
||
|
# Run the experiment
|
||
|
exp_func(opt)
|
||
|
|
||
|
return wrapper"""
|