File size: 1,206 Bytes
0e37bb2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import random
import time
from typing import Any
import torch
from torch.serialization import MAP_LOCATION
from src.utils.logging import get_logger
# Module-level logger, named after this source file's base name
# (os.path.split(...)[1] is exactly what os.path.basename returns).
logger = get_logger(os.path.split(__file__)[1])
def robust_checkpoint_loader(
    r_path: str,
    map_location: MAP_LOCATION = "cpu",
    max_retries: int = 3,
) -> Any:
    """Load a checkpoint with ``torch.load``, retrying on failure with jittered backoff.

    Any ``Exception`` raised by ``torch.load`` triggers a retry, not only a
    missing file. After the k-th failed attempt (when attempts remain), the
    function sleeps ``2**k * uniform(1.0, 1.1)`` seconds before trying again.

    NOTE: ``torch.load`` unpickles the file (unless ``weights_only=True`` is in
    effect for the installed PyTorch version), so only load trusted checkpoints.

    Args:
        r_path: Path to the checkpoint file.
        map_location: Forwarded to ``torch.load`` to remap tensor storage.
        max_retries: Total number of load attempts; must be >= 1.

    Returns:
        Whatever ``torch.load`` returns for the checkpoint.

    Raises:
        ValueError: If ``max_retries`` is less than 1 (previously this silently
            returned ``None`` without attempting a load).
        Exception: The exception from the final attempt, re-raised unchanged,
            once all ``max_retries`` attempts have failed.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    for attempt in range(1, max_retries + 1):
        try:
            return torch.load(r_path, map_location=map_location)
        except Exception as e:
            logger.warning("Encountered exception when loading checkpoint %s", e)
            if attempt >= max_retries:
                # Bare ``raise`` preserves the original traceback.
                raise
            # Exponential backoff with a small random jitter so concurrent
            # workers hitting the same file do not retry in lockstep.
            sleep_time_s = (2**attempt) * random.uniform(1.0, 1.1)
            logger.warning(
                "Sleeping %ss and trying again, count %s/%s",
                sleep_time_s,
                attempt,
                max_retries,
            )
            time.sleep(sleep_time_s)

    # Unreachable: every iteration either returns or (on the last attempt) raises.
    raise RuntimeError("robust_checkpoint_loader exited retry loop unexpectedly")
|