from typing import Optional
import os
import pathlib
import hydra
import copy
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf
import dill
import torch
import threading


class BaseWorkspace:
    """
    Base class for training/eval workspaces.

    Subclasses hold their modules/optimizers/etc. as instance attributes;
    this class provides generic checkpoint (portable, per-attribute) and
    snapshot (whole-object pickle) save/load on top of ``torch.save`` with
    ``dill`` as the pickle module.

    Class attributes:
        include_keys: extra attribute names to pickle into checkpoints
            (in addition to anything with a ``state_dict``).
        exclude_keys: attribute names to skip when collecting state dicts.
    """
    include_keys = tuple()
    exclude_keys = tuple()

    def __init__(self, cfg: OmegaConf, output_dir: Optional[str] = None):
        """
        Args:
            cfg: resolved configuration for this run.
            output_dir: where checkpoints/snapshots go; if None, the Hydra
                runtime output directory is used lazily (see ``output_dir``).
        """
        self.cfg = cfg
        self._output_dir = output_dir
        self._saving_thread = None

    @property
    def output_dir(self) -> str:
        # Fall back to Hydra's runtime output dir when none was provided.
        # NOTE: HydraConfig.get() raises if called outside a Hydra run.
        output_dir = self._output_dir
        if output_dir is None:
            output_dir = HydraConfig.get().runtime.output_dir
        return output_dir

    def run(self):
        """
        Create any resource shouldn't be serialized as local variables
        """
        pass

    def save_checkpoint(self, path=None, tag='latest',
                        exclude_keys=None, include_keys=None,
                        use_thread=True):
        """
        Save a portable checkpoint: cfg + per-attribute state dicts + pickles.

        Any attribute exposing both ``state_dict``/``load_state_dict``
        (modules, optimizers, schedulers, ...) is saved via its state dict
        unless listed in ``exclude_keys``; attributes in ``include_keys``
        are dill-pickled wholesale.

        Args:
            path: explicit checkpoint path; defaults to
                ``<output_dir>/checkpoints/<tag>.ckpt``.
            tag: checkpoint tag used when ``path`` is None.
            exclude_keys: overrides ``self.exclude_keys`` when given.
            include_keys: overrides ``self.include_keys`` when given
                ('_output_dir' is always included in the default).
            use_thread: write to disk on a background thread; state dicts
                are deep-copied to CPU first so training can continue.

        Returns:
            Absolute path of the written checkpoint as ``str``.
        """
        if path is None:
            path = pathlib.Path(self.output_dir).joinpath(
                'checkpoints', f'{tag}.ckpt')
        else:
            path = pathlib.Path(path)
        if exclude_keys is None:
            exclude_keys = tuple(self.exclude_keys)
        if include_keys is None:
            include_keys = tuple(self.include_keys) + ('_output_dir',)

        # parents=False: output_dir itself is expected to already exist.
        path.parent.mkdir(parents=False, exist_ok=True)
        payload = {
            'cfg': self.cfg,
            'state_dicts': dict(),
            'pickles': dict()
        }

        for key, value in self.__dict__.items():
            if hasattr(value, 'state_dict') and hasattr(value, 'load_state_dict'):
                # modules, optimizers and samplers etc
                if key not in exclude_keys:
                    if use_thread:
                        # Snapshot onto CPU so the background writer sees a
                        # stable copy while training keeps mutating params.
                        payload['state_dicts'][key] = _copy_to_cpu(
                            value.state_dict())
                    else:
                        payload['state_dicts'][key] = value.state_dict()
            elif key in include_keys:
                payload['pickles'][key] = dill.dumps(value)

        if use_thread:
            # NOTE(review): a previous in-flight save is not joined before
            # starting a new one; concurrent saves to the same tag could
            # interleave — confirm callers serialize checkpointing.
            def _write():
                # Use a context manager so the file handle is closed
                # (the original leaked path.open('wb')).
                with path.open('wb') as f:
                    torch.save(payload, f, pickle_module=dill)
            self._saving_thread = threading.Thread(target=_write)
            self._saving_thread.start()
        else:
            with path.open('wb') as f:
                torch.save(payload, f, pickle_module=dill)
        return str(path.absolute())

    def get_checkpoint_path(self, tag='latest'):
        """Return the default checkpoint path for ``tag`` (may not exist)."""
        return pathlib.Path(self.output_dir).joinpath(
            'checkpoints', f'{tag}.ckpt')

    def load_payload(self, payload, exclude_keys=None, include_keys=None,
                     **kwargs):
        """
        Restore attributes from a checkpoint payload (see save_checkpoint).

        Args:
            payload: dict with 'state_dicts' and 'pickles' entries.
            exclude_keys: state-dict keys to skip (default: none).
            include_keys: pickle keys to restore (default: all present).
            **kwargs: forwarded to each ``load_state_dict`` call
                (e.g. ``strict=False``).
        """
        if exclude_keys is None:
            exclude_keys = tuple()
        if include_keys is None:
            include_keys = payload['pickles'].keys()

        for key, value in payload['state_dicts'].items():
            if key not in exclude_keys:
                self.__dict__[key].load_state_dict(value, **kwargs)
        for key in include_keys:
            if key in payload['pickles']:
                self.__dict__[key] = dill.loads(payload['pickles'][key])

    def load_checkpoint(self, path=None, tag='latest',
                        exclude_keys=None, include_keys=None,
                        **kwargs):
        """
        Load a checkpoint into this workspace and return the raw payload.

        Args:
            path: explicit checkpoint path; defaults to
                ``get_checkpoint_path(tag)``.
            tag: checkpoint tag used when ``path`` is None.
            exclude_keys / include_keys: forwarded to ``load_payload``.
            **kwargs: forwarded to ``torch.load`` (NOT to load_state_dict,
                matching the original behavior).
        """
        if path is None:
            path = self.get_checkpoint_path(tag=tag)
        else:
            path = pathlib.Path(path)
        # Context manager closes the handle (the original leaked it).
        with path.open('rb') as f:
            payload = torch.load(f, pickle_module=dill, **kwargs)
        self.load_payload(payload,
                          exclude_keys=exclude_keys,
                          include_keys=include_keys)
        return payload

    @classmethod
    def create_from_checkpoint(cls, path,
                               exclude_keys=None,
                               include_keys=None,
                               **kwargs):
        """
        Construct a fresh workspace from a checkpoint's cfg, then load the
        checkpoint's state into it. ``**kwargs`` go to ``load_payload``.
        """
        with open(path, 'rb') as f:
            payload = torch.load(f, pickle_module=dill)
        instance = cls(payload['cfg'])
        instance.load_payload(
            payload=payload,
            exclude_keys=exclude_keys,
            include_keys=include_keys,
            **kwargs)
        return instance

    def save_snapshot(self, tag='latest'):
        """
        Quick loading and saving for research, saves full state of the
        workspace.

        However, loading a snapshot assumes the code stays exactly the same.
        Use save_checkpoint for long-term storage.
        """
        path = pathlib.Path(self.output_dir).joinpath(
            'snapshots', f'{tag}.pkl')
        path.parent.mkdir(parents=False, exist_ok=True)
        with path.open('wb') as f:
            torch.save(self, f, pickle_module=dill)
        return str(path.absolute())

    @classmethod
    def create_from_snapshot(cls, path):
        """Unpickle a whole workspace saved by ``save_snapshot``.

        WARNING: like any pickle load, only use on trusted files.
        """
        with open(path, 'rb') as f:
            return torch.load(f, pickle_module=dill)


def _copy_to_cpu(x):
    """
    Recursively deep-copy a (possibly nested) state-dict structure to CPU.

    Fix: ``to('cpu', copy=True)`` forces a real copy even when the tensor
    already lives on CPU — the original ``to('cpu')`` returned the same
    storage in that case, letting the background saving thread race with
    training updates.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().to('cpu', copy=True)
    elif isinstance(x, dict):
        result = dict()
        for k, v in x.items():
            result[k] = _copy_to_cpu(v)
        return result
    elif isinstance(x, list):
        return [_copy_to_cpu(k) for k in x]
    else:
        # Non-tensor leaves (ints, floats, strings, ...) are deep-copied.
        return copy.deepcopy(x)