146 lines
4.8 KiB
Python
146 lines
4.8 KiB
Python
import collections
import copy
import os
import pathlib
import threading
from typing import Optional

import dill
import hydra
import torch
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf
|
|
|
|
|
|
class BaseWorkspace:
    """Base class for a checkpointable training workspace.

    Subclasses attach stateful objects (e.g. modules, optimizers) as instance
    attributes. ``save_checkpoint`` stores every attribute that exposes both
    ``state_dict`` and ``load_state_dict`` through its state dict, and
    dill-pickles the attributes named in ``include_keys``.

    Security note: checkpoints and snapshots are read back with ``dill`` as the
    pickle module, which can execute arbitrary code while unpickling. Only
    load files from trusted sources.
    """

    # Attribute names that are dill-pickled wholesale into checkpoints.
    include_keys = tuple()
    # Attribute names whose state dicts are skipped when saving checkpoints.
    exclude_keys = tuple()

    def __init__(self, cfg: OmegaConf, output_dir: Optional[str] = None):
        self.cfg = cfg
        self._output_dir = output_dir
        # Background writer started by save_checkpoint(use_thread=True); None until then.
        self._saving_thread = None

    @property
    def output_dir(self):
        """Return the explicit output directory, or Hydra's runtime output dir if none was given."""
        output_dir = self._output_dir
        if output_dir is None:
            output_dir = HydraConfig.get().runtime.output_dir
        return output_dir

    def run(self):
        """
        Create any resources that shouldn't be serialized as local variables.
        """
        pass

    def save_checkpoint(self, path=None, tag='latest',
            exclude_keys=None,
            include_keys=None,
            use_thread=True):
        """Save the workspace state to ``path`` and return its absolute path as a string.

        Args:
            path: destination file; defaults to ``<output_dir>/checkpoints/<tag>.ckpt``.
            tag: checkpoint name used when ``path`` is None.
            exclude_keys: attributes whose state dicts are not saved
                (defaults to the class-level ``exclude_keys``).
            include_keys: attributes to dill-pickle (defaults to the class-level
                ``include_keys`` plus ``'_output_dir'``).
            use_thread: if True, state dicts are first copied to the CPU and the
                file is written on a background thread; the method returns
                before the write finishes.
        """
        if path is None:
            path = pathlib.Path(self.output_dir).joinpath('checkpoints', f'{tag}.ckpt')
        else:
            path = pathlib.Path(path)
        if exclude_keys is None:
            exclude_keys = tuple(self.exclude_keys)
        if include_keys is None:
            include_keys = tuple(self.include_keys) + ('_output_dir',)

        # parents=True so an explicit path / output_dir that does not exist yet still works.
        path.parent.mkdir(parents=True, exist_ok=True)

        # Finish any earlier background save first: two writers on the same file
        # would interleave and corrupt it. Joining before taking the new snapshot
        # also avoids holding two CPU copies of the state at once.
        previous_thread = getattr(self, '_saving_thread', None)
        if previous_thread is not None:
            previous_thread.join()

        payload = {
            'cfg': self.cfg,
            'state_dicts': dict(),
            'pickles': dict()
        }

        for key, value in self.__dict__.items():
            if hasattr(value, 'state_dict') and hasattr(value, 'load_state_dict'):
                # modules, optimizers and samplers etc
                if key not in exclude_keys:
                    if use_thread:
                        # Snapshot now so later training updates don't leak into
                        # the file the background thread is still writing.
                        payload['state_dicts'][key] = _copy_to_cpu(value.state_dict())
                    else:
                        payload['state_dicts'][key] = value.state_dict()
            elif key in include_keys:
                payload['pickles'][key] = dill.dumps(value)

        def write_payload():
            # Context manager guarantees the file is flushed and closed.
            with path.open('wb') as f:
                torch.save(payload, f, pickle_module=dill)

        if use_thread:
            self._saving_thread = threading.Thread(target=write_payload)
            self._saving_thread.start()
        else:
            write_payload()
        return str(path.absolute())

    def get_checkpoint_path(self, tag='latest'):
        """Return the default checkpoint location ``<output_dir>/checkpoints/<tag>.ckpt``."""
        return pathlib.Path(self.output_dir).joinpath('checkpoints', f'{tag}.ckpt')

    def load_payload(self, payload, exclude_keys=None, include_keys=None, **kwargs):
        """Restore state from an in-memory checkpoint payload.

        Every entry of ``payload['state_dicts']`` not in ``exclude_keys`` is loaded
        into the same-named attribute via ``load_state_dict(value, **kwargs)``.
        Entries of ``payload['pickles']`` named in ``include_keys`` (default: all
        of them) are unpickled with dill and assigned as attributes.
        """
        if exclude_keys is None:
            exclude_keys = tuple()
        if include_keys is None:
            include_keys = payload['pickles'].keys()

        for key, value in payload['state_dicts'].items():
            if key not in exclude_keys:
                self.__dict__[key].load_state_dict(value, **kwargs)
        for key in include_keys:
            if key in payload['pickles']:
                self.__dict__[key] = dill.loads(payload['pickles'][key])

    def load_checkpoint(self, path=None, tag='latest',
            exclude_keys=None,
            include_keys=None,
            **kwargs):
        """Load a checkpoint file into this workspace and return the raw payload.

        ``path`` defaults to ``get_checkpoint_path(tag)``. Extra ``kwargs`` are
        forwarded to ``torch.load`` (e.g. ``map_location``).
        """
        if path is None:
            path = self.get_checkpoint_path(tag=tag)
        else:
            path = pathlib.Path(path)
        with path.open('rb') as f:
            payload = torch.load(f, pickle_module=dill, **kwargs)
        self.load_payload(payload,
            exclude_keys=exclude_keys,
            include_keys=include_keys)
        return payload

    @classmethod
    def create_from_checkpoint(cls, path,
            exclude_keys=None,
            include_keys=None,
            **kwargs):
        """Build a new workspace from the checkpoint's ``cfg`` and restore its state.

        Extra ``kwargs`` are forwarded to ``load_payload`` (and from there to each
        ``load_state_dict`` call).
        """
        with open(path, 'rb') as f:
            payload = torch.load(f, pickle_module=dill)
        instance = cls(payload['cfg'])
        instance.load_payload(
            payload=payload,
            exclude_keys=exclude_keys,
            include_keys=include_keys,
            **kwargs)
        return instance

    def save_snapshot(self, tag='latest'):
        """
        Quick loading and saving for research; saves the full state of the workspace.

        However, loading a snapshot assumes the code stays exactly the same.
        Use save_checkpoint for long-term storage.
        """
        path = pathlib.Path(self.output_dir).joinpath('snapshots', f'{tag}.pkl')
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open('wb') as f:
            torch.save(self, f, pickle_module=dill)
        return str(path.absolute())

    @classmethod
    def create_from_snapshot(cls, path):
        """Load and return a workspace previously written by ``save_snapshot``."""
        with open(path, 'rb') as f:
            return torch.load(f, pickle_module=dill)
|
|
|
|
|
|
def _copy_to_cpu(x):
|
|
if isinstance(x, torch.Tensor):
|
|
return x.detach().to('cpu')
|
|
elif isinstance(x, dict):
|
|
result = dict()
|
|
for k, v in x.items():
|
|
result[k] = _copy_to_cpu(v)
|
|
return result
|
|
elif isinstance(x, list):
|
|
return [_copy_to_cpu(k) for k in x]
|
|
else:
|
|
return copy.deepcopy(x)
|