mpd-public/mpd/utils/decorators.py
2023-10-23 15:45:14 +02:00

170 lines
4.8 KiB
Python

import datetime
import os
import wandb
import yaml
from functools import wraps
def save_args(exp_dir, kwargs, filename='experiment_args.yml'):
    """Dump the YAML-friendly subset of ``kwargs`` to ``exp_dir/filename``.

    A value is kept only when its *exact* type is tuple, int, float, bool or
    str, or when it is None. Everything else is skipped silently: lists,
    dicts, arbitrary objects, and subclasses of the listed types.

    Args:
        exp_dir: Directory to write into (must already exist).
        kwargs: Mapping of argument names to values.
        filename: Name of the YAML file inside ``exp_dir``.
    """
    plain_types = (tuple, int, float, bool, str, type(None))
    kept = {name: val for name, val in kwargs.items() if type(val) in plain_types}
    target = os.path.join(exp_dir, filename)
    with open(target, 'w') as handle:
        yaml.safe_dump(kept, handle)
def save_module_args(exp_dir, args, filename='module_args.yml'):
    """Persist module arguments; identical to save_args but with a module-specific default file name."""
    return save_args(exp_dir, args, filename=filename)
def load_args(exp_dir, filename='experiment_args.yml'):
    """Read ``exp_dir/filename`` with ``yaml.safe_load`` and return the parsed object.

    Raises whatever ``open`` raises (e.g. FileNotFoundError) if the file is missing.
    """
    path = os.path.join(exp_dir, filename)
    with open(path, 'r') as stream:
        return yaml.safe_load(stream)
def load_module_args(exp_dir, filename='module.yml'):
    """Load module arguments previously saved for an experiment.

    ``save_module_args`` writes ``module_args.yml`` by default, while this
    loader's default is ``module.yml``. To make the two defaults interoperate,
    the loader falls back to ``module_args.yml`` under two conditions: the
    default ``module.yml`` is requested, it does not exist in ``exp_dir``, and
    ``module_args.yml`` does exist there. When neither file exists, the
    original FileNotFoundError for ``module.yml`` is raised, as before.

    Args:
        exp_dir: Experiment directory containing the args file.
        filename: YAML file name to read.

    Returns:
        Whatever ``load_args`` returns for the chosen file.
    """
    default_name, saved_name = 'module.yml', 'module_args.yml'
    if filename == default_name and not os.path.exists(os.path.join(exp_dir, filename)):
        if os.path.exists(os.path.join(exp_dir, saved_name)):
            filename = saved_name
    return load_args(exp_dir, filename=filename)
def update_args(exp_dir, partial_args):
args = load_args(exp_dir)
for key, value in partial_args.items():
args[key] = value
save_args(exp_dir, args)
def evaluation(eval_func):
    """Decorate an evaluation entry point so it receives the saved experiment args.

    The returned wrapper accepts keyword arguments only and requires
    ``exp_dir``. It loads the experiment args from ``kwargs['exp_dir']`` via
    load_args and calls ``eval_func(experiment_args, **kwargs)``.

    Returns:
        The wrapper. When called, it returns whatever ``eval_func`` returns.
    """
    @wraps(eval_func)
    def wrapper(**kwargs):
        experiment_args = load_args(kwargs["exp_dir"])
        # Run the evaluation and hand its result back to the caller
        # instead of discarding it.
        return eval_func(experiment_args, **kwargs)
    return wrapper
def filter_kwargs(kwargs, blacklist=('device', 'exp_dir')):
    """Return a copy of ``kwargs`` without runtime-only entries.

    A key is dropped when it is in ``blacklist`` or ends with ``'_field'``.
    The remaining items keep their original order.

    Args:
        kwargs: Mapping of keyword arguments to filter.
        blacklist: Container of keys to drop. The default is an immutable
            tuple rather than a shared mutable list, so it cannot be altered
            across calls.

    Returns:
        A new dict with the kept key/value pairs.
    """
    return {
        key: value
        for key, value in kwargs.items()
        if key not in blacklist and not key.endswith('_field')
    }
def pretrain_helper(model_load_function):
    """Decorate a model factory so its filtered kwargs are saved before the model is built.

    The returned wrapper accepts keyword arguments only and requires
    ``exp_dir``. It runs these steps in order:

    1. Filters the kwargs with filter_kwargs (drops ``device``, ``exp_dir``
       and ``*_field`` keys by default).
    2. If ``submodules`` (a mapping name -> submodule) is given, injects each
       submodule as its own keyword argument and records each submodule's
       filtered ``_all_kwargs``.
    3. Writes the filtered kwargs with save_module_args. The file name is
       that function's default, ``module_args.yml`` (not ``module.yml``).
    4. Calls the factory and stores the full kwargs, including the injected
       submodules, on the model as ``_all_kwargs``.

    NOTE(review): save_args keeps only values of exact type
    tuple/int/float/bool/str/None. The ``submodules`` dict recorded here is
    therefore filtered out of the saved file. Only the no-submodules case
    (``submodules: null``) actually gets written.
    """
    @wraps(model_load_function)
    def wrapper(**kwargs):
        model_kwargs = filter_kwargs(kwargs)
        # Inject submodules (if any) as individual keyword arguments and
        # record their own filtered kwargs.
        submodule_kwargs = None
        if "submodules" in kwargs:
            submodule_kwargs = {}
            for module_name, submodule in kwargs["submodules"].items():
                kwargs[module_name] = submodule
                # Presumably each submodule was built by a pretrain_helper-wrapped
                # factory, which sets `_all_kwargs`; otherwise this raises AttributeError.
                submodule_kwargs[module_name] = filter_kwargs(submodule._all_kwargs)
        model_kwargs['submodules'] = submodule_kwargs
        save_module_args(kwargs['exp_dir'], model_kwargs)
        # Build the model and remember every kwarg it was built with.
        model = model_load_function(**kwargs)
        model._all_kwargs = kwargs
        return model
    return wrapper
def model_loader(model_load_function):
    """Decorate a model factory so ``submodules`` entries become individual keyword arguments.

    If the wrapper receives ``submodules`` (a mapping name -> submodule), each
    entry is also passed to the factory as ``name=submodule``. The built model
    gets a ``submodules`` attribute: the mapping that was passed, or ``{}``
    when none was given.
    """
    @wraps(model_load_function)
    def wrapper(**kwargs):
        if "submodules" in kwargs:
            injected = kwargs["submodules"]
            for name, module in injected.items():
                kwargs[name] = module
        model = model_load_function(**kwargs)
        # Keep the submodules reachable from the model (e.g. for saving).
        model.submodules = kwargs.get("submodules", {})
        return model
    return wrapper
# NOTE(review): everything in the triple-quoted string below is disabled code.
# It is a bare module-level string expression and is never executed. It holds
# two older decorators, `single_experiment` and `experiment`.
# Before re-enabling, review these hazards:
#   - `single_experiment` calls `fix_random_seed`, which is not imported in
#     the visible part of this module.
#   - `experiment` reads `opt` by attribute (e.g. `opt.saving_root`) but passes
#     it to `save_args`, which iterates `kwargs.items()`. An object without
#     `.items()` (e.g. an argparse.Namespace — TODO confirm `opt`'s type)
#     would fail there.
# In the visible part of this file, the `datetime` and `wandb` imports are
# referenced only inside this string.
"""
def single_experiment(exp_func):
@wraps(exp_func)
def wrapper(*args, **kwargs):
# Make results directory
assert 'results_dir' in kwargs and 'seed' in kwargs, "results_dir and seed must be arguments"
results_dir = os.path.join(kwargs['results_dir'], str(kwargs['seed']))
os.makedirs(results_dir, exist_ok=True)
kwargs['results_dir'] = results_dir
# Save arguments
save_args(results_dir, kwargs)
# Fix seed
fix_random_seed(kwargs['seed'])
# WandB
if kwargs['wandb_silent']:
os.environ["WANDB_SILENT"] = "true"
init = {"project": kwargs['project'],
"reinit": True,
"entity": kwargs['entity'],
"notes": datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
"config": kwargs}
if 'group' in kwargs and kwargs["group"] is not None:
init["group"] = kwargs['group']
if 'tags' in kwargs and kwargs["tags"] is not None:
init["tags"] = kwargs['tags']
if 'run_name' in kwargs and kwargs["run_name"] is not None:
init["name"] = kwargs['run_name']
run = wandb.init(**init)
# Run the experiment
exp_func(*args, **kwargs)
run.finish()
return wrapper
def experiment(exp_func):
@wraps(exp_func)
def wrapper(opt):
# Make results directory
root_dir = opt.saving_root
opt.exp_dir = str(os.path.join(root_dir, opt.experiment_name, str(opt.seed)))
exists = os.path.exists(os.path.join(opt.exp_dir, 'checkpoints')) or \
os.path.exists(os.path.join(opt.exp_dir, 'summaries'))
if exists:
exit('Experiment already exists.')
os.makedirs(opt.exp_dir, exist_ok=True)
# Save arguments
save_args(opt.exp_dir, opt)
# Run the experiment
exp_func(opt)
return wrapper"""