lmzjms's picture
Upload 1162 files
0b32ad6 verified
raw
history blame
1.03 kB
import argparse
import torch
import hubconf
def list_upstreams():
    """Return the public callable names exposed by the local ``hubconf`` module.

    These names are the allowed values of ``--upstream``.
    """
    # NOTE(review): choices come from the *local* hubconf, while the hub calls
    # below target the remote 's3prl/s3prl' repo; presumably the two mirror
    # each other — confirm.
    return [
        attr for attr in dir(hubconf)
        if callable(getattr(hubconf, attr)) and not attr.startswith('_')
    ]


def build_parser(upstreams):
    """Build the command-line parser.

    Args:
        upstreams: iterable of names accepted by ``--upstream``.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['list', 'help', 'load'], required=True)
    parser.add_argument('--upstream', choices=upstreams)
    parser.add_argument('--ckpt', help='The PATH/URL/GOOGLE_DRIVE_ID of upstream checkpoint, not always needed')
    parser.add_argument('--config', help='The PATH of upstream config, not always needed')
    parser.add_argument('--refresh', action='store_true', help='Whether to re-download upstream contents')
    return parser


def parse_args(argv=None, upstreams=None):
    """Parse ``argv`` (default: ``sys.argv[1:]``) and validate cross-field rules.

    Args:
        argv: argument list to parse, or None for the process arguments.
        upstreams: allowed ``--upstream`` values; defaults to ``list_upstreams()``.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        SystemExit: via ``parser.error`` (exit status 2) on invalid arguments,
            including ``--mode help``/``load`` given without ``--upstream``.
    """
    if upstreams is None:
        upstreams = list_upstreams()
    parser = build_parser(upstreams)
    args = parser.parse_args(argv)
    # 'help' and 'load' need an entry-point name; fail with a usage error
    # here rather than letting torch.hub fail on a None entry point.
    if args.mode in ('help', 'load') and args.upstream is None:
        parser.error(f'--upstream is required when --mode is {args.mode!r}')
    return args


def main(argv=None):
    """Run the requested torch.hub operation against 's3prl/s3prl' and print the result."""
    args = parse_args(argv)
    if args.mode == 'list':
        print(torch.hub.list('s3prl/s3prl', force_reload=args.refresh))
    elif args.mode == 'help':
        print(torch.hub.help('s3prl/s3prl', args.upstream, force_reload=args.refresh))
    elif args.mode == 'load':
        # --refresh is used twice: as torch.hub's force_reload (re-fetch the
        # hub repo) and as the entry point's own `refresh` keyword argument.
        print(torch.hub.load(
            's3prl/s3prl', args.upstream, force_reload=args.refresh,
            ckpt=args.ckpt, config=args.config, refresh=args.refresh
        ))


if __name__ == '__main__':
    main()