import os
import shlex
import hashlib

from filelock import FileLock
import torch
import gdown

def _download(filename, url, refresh, agent):
    """Download url into the s3prl cache directory and return the local path.

    A file lock guards the download so concurrent processes do not clobber
    each other; an existing file is reused unless refresh is True.
    """
    dirpath = f'{torch.hub.get_dir()}/s3prl_cache'
    os.makedirs(dirpath, exist_ok=True)
    filepath = f'{dirpath}/{filename}'
    with FileLock(filepath + ".lock"):
        if not os.path.isfile(filepath) or refresh:
            if agent == 'wget':
                # Quote the arguments so URLs containing shell
                # metacharacters (e.g. '&') do not break the command.
                os.system(f'wget {shlex.quote(url)} -O {shlex.quote(filepath)}')
            elif agent == 'gdown':
                gdown.download(url, filepath, use_cookies=False)
            else:
                raise NotImplementedError(
                    "[Download] - Unknown download agent. "
                    "Only 'wget' and 'gdown' are supported."
                )
        else:
            print(f'Using cache found in {filepath}\nfor {url}')
    return filepath
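
# A minimal usage sketch for _download (the filename and URL below are
# hypothetical; assumes wget is on PATH and the machine has network access):
#
#     path = _download('model.ckpt', 'https://example.com/model.ckpt',
#                      refresh=False, agent='wget')
#     # path == f'{torch.hub.get_dir()}/s3prl_cache/model.ckpt'
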
def _urls_to_filepaths(*args, refresh=False, agent='wget'):
    """
    Download the URLs given in *args and return their local file paths.

    Args:
        Any number of URLs (one or more)
    Return:
        The same number of local file paths: a single path when one URL
        is given, otherwise a list of paths
    """
    def url_to_filename(url):
        # Hash the URL so every distinct URL maps to a unique,
        # filesystem-safe cache filename.
        assert isinstance(url, str)
        m = hashlib.sha256()
        m.update(str.encode(url))
        return str(m.hexdigest())

    def url_to_path(url, refresh):
        if isinstance(url, str) and len(url) > 0:
            return _download(url_to_filename(url), url, refresh, agent=agent)
        else:
            return None

    paths = [url_to_path(url, refresh) for url in args]
    return paths if len(paths) > 1 else paths[0]
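
# A usage sketch (hypothetical URLs; assumes the files are reachable):
#
#     ckpt = _urls_to_filepaths('https://example.com/a.ckpt')       # -> str
#     a, b = _urls_to_filepaths('https://example.com/a.ckpt',
#                               'https://example.com/b.ckpt')       # -> list
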
def _gdriveids_to_filepaths(*args, refresh=False):
    """
    Download the Google Drive ids given in *args and return their local
    file paths.

    Args:
        Any number of Google Drive ids (one or more)
    Return:
        The same number of local file paths, following the return
        convention of _urls_to_filepaths
    """
    def gdriveid_to_url(gdriveid):
        if isinstance(gdriveid, str) and len(gdriveid) > 0:
            return f'https://drive.google.com/uc?id={gdriveid}'
        else:
            return None

    # Google Drive links must be fetched with gdown rather than wget.
    return _urls_to_filepaths(*[gdriveid_to_url(gid) for gid in args], refresh=refresh, agent='gdown')
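
# A usage sketch (the id below is a made-up placeholder, not a real file):
#
#     path = _gdriveids_to_filepaths('1AbCdEfGhIjKlMnOpQrStUvWxYz')
#     # Builds the drive.google.com/uc?id=... URL, downloads it via gdown,
#     # and returns the cached local path.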