import argparse import torch import hubconf parser = argparse.ArgumentParser() upstreams = [attr for attr in dir(hubconf) if callable(getattr(hubconf, attr)) and attr[0] != '_'] parser.add_argument('--mode', choices=['list', 'help', 'load'], required=True) parser.add_argument('--upstream', choices=upstreams) parser.add_argument('--ckpt', help='The PATH/URL/GOOGLE_DRIVE_ID of upstream checkpoint, not always needed') parser.add_argument('--config', help='The PATH of upstream config, not always needed') parser.add_argument('--refresh', action='store_true', help='Whether to re-download upstream contents') args = parser.parse_args() if args.mode == 'list': print(torch.hub.list('s3prl/s3prl', force_reload=args.refresh)) elif args.mode == 'help': print(torch.hub.help('s3prl/s3prl', args.upstream, force_reload=args.refresh)) elif args.mode == 'load': print(torch.hub.load( 's3prl/s3prl', args.upstream, force_reload=args.refresh, ckpt=args.ckpt, config=args.config, refresh=args.refresh ))