File size: 1,206 Bytes
0e37bb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import random
import time
from typing import Any

import torch
from torch.serialization import MAP_LOCATION

from src.utils.logging import get_logger

logger = get_logger(os.path.basename(__file__))


def robust_checkpoint_loader(r_path: str, map_location: MAP_LOCATION = "cpu", max_retries: int = 3) -> Any:
    """
    Loads a checkpoint from a path, retrying up to max_retries times if the checkpoint is not found.
    """
    retries = 0

    while retries < max_retries:
        try:
            return torch.load(r_path, map_location=map_location)
        except Exception as e:
            logger.warning(f"Encountered exception when loading checkpoint {e}")
            retries += 1
            if retries < max_retries:
                sleep_time_s = (2**retries) * random.uniform(1.0, 1.1)
                logger.warning(f"Sleeping {sleep_time_s}s and trying again, count {retries}/{max_retries}")
                time.sleep(sleep_time_s)
                continue
            else:
                raise e