import os
import hashlib
import subprocess

from filelock import FileLock

import torch
import gdown


def _download(filename, url, refresh, agent):
    dirpath = f'{torch.hub.get_dir()}/s3prl_cache'
    os.makedirs(dirpath, exist_ok=True)
    filepath = f'{dirpath}/{filename}'

    # Serialize concurrent downloads of the same file across processes.
    with FileLock(filepath + '.lock'):
        if not os.path.isfile(filepath) or refresh:
            if agent == 'wget':
                # Pass the command as an argument list so the URL is never
                # interpreted by a shell, and fail loudly on download errors.
                subprocess.run(['wget', url, '-O', filepath], check=True)
            elif agent == 'gdown':
                gdown.download(url, filepath, use_cookies=False)
            else:
                raise NotImplementedError(
                    "[Download] - Unknown download agent. Only 'wget' and 'gdown' are supported."
                )
        else:
            print(f'Using cache found in {filepath}\nfor {url}')

    return filepath
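
# For reference, a minimal sketch of the resulting cache layout, assuming the
# default torch hub directory (~/.cache/torch); the hash is a made-up
# placeholder for the sha256 filename produced below:
#
#     ~/.cache/torch/s3prl_cache/d2c4...e9a1        <- downloaded file
#     ~/.cache/torch/s3prl_cache/d2c4...e9a1.lock   <- its FileLock guard
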
def _urls_to_filepaths(*args, refresh=False, agent='wget'):
    """
    Download the URLs specified in *args and return their local file paths.

    Args:
        Any number of URLs (one or more)

    Return:
        The same number of local file paths; a single path (not a list) is
        returned when only one URL is given
    """

    def url_to_filename(url):
        # Hash the URL so the cached filename is unique and filesystem-safe.
        assert isinstance(url, str)
        m = hashlib.sha256()
        m.update(url.encode())
        return str(m.hexdigest())

    def url_to_path(url, refresh):
        if isinstance(url, str) and len(url) > 0:
            return _download(url_to_filename(url), url, refresh, agent=agent)
        else:
            return None

    paths = [url_to_path(url, refresh) for url in args]
    return paths if len(paths) > 1 else paths[0]
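
# A usage sketch, assuming placeholder URLs (not real checkpoints): one URL
# yields a single cached path, while several URLs yield a list of paths.
#
#     ckpt = _urls_to_filepaths('https://example.com/model.ckpt')
#     ckpt_a, ckpt_b = _urls_to_filepaths(
#         'https://example.com/a.ckpt',
#         'https://example.com/b.ckpt',
#     )
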
def _gdriveids_to_filepaths(*args, refresh=False):
    """
    Download the Google Drive IDs specified in *args and return their local
    file paths.

    Args:
        Any number of Google Drive IDs (one or more)

    Return:
        The same number of local file paths; a single path (not a list) is
        returned when only one ID is given
    """

    def gdriveid_to_url(gdriveid):
        if isinstance(gdriveid, str) and len(gdriveid) > 0:
            return f'https://drive.google.com/uc?id={gdriveid}'
        else:
            return None

    # Expand each ID into a full drive.google.com URL and reuse the generic
    # URL downloader with the gdown agent.
    return _urls_to_filepaths(*[gdriveid_to_url(gid) for gid in args], refresh=refresh, agent='gdown')
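
# Usage sketch for the Google Drive variant (the ID below is a made-up
# placeholder): the ID is expanded to a drive.google.com URL and fetched with
# gdown rather than wget.
#
#     ckpt = _gdriveids_to_filepaths('1AbCdEfGhIjKlMnOpQrStUvWx')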