"""Extract one sub-module's weights from a checkpoint into a standalone file.

Usage: python <script> INPATH OUTPATH [SUBMODEL]

SUBMODEL defaults to "cond_stage_model". The output file contains
``{"state_dict": {...}}`` with the ``"<SUBMODEL>."`` prefix stripped from
every key.
"""
import sys

import torch


def extract_submodel_state_dict(state_dict, submodel="cond_stage_model"):
    """Return the entries of *state_dict* that belong to *submodel*.

    Only keys of the form ``"<submodel>.<rest>"`` are kept. Requiring the
    trailing dot avoids matching sibling modules that merely share a name
    prefix (e.g. ``cond_stage_model_ema``). The whole ``"<submodel>."``
    prefix is stripped, so dotted sub-module paths such as
    ``"model.diffusion_model"`` also work.

    Args:
        state_dict: mapping from parameter names to values.
        submodel: name (possibly dotted) of the sub-module to extract.

    Returns:
        A new dict mapping ``<rest>`` to the original values.
    """
    prefix = submodel + "."
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


if __name__ == "__main__":
    if len(sys.argv) < 3:
        sys.exit("Usage: {} INPATH OUTPATH [SUBMODEL]".format(sys.argv[0]))
    inpath = sys.argv[1]
    outpath = sys.argv[2]
    submodel = sys.argv[3] if len(sys.argv) > 3 else "cond_stage_model"
    print("Extracting {} from {} to {}.".format(submodel, inpath, outpath))
    # NOTE(review): torch.load unpickles the file; only run this on trusted
    # checkpoints.
    sd = torch.load(inpath, map_location="cpu")
    # Training checkpoints usually nest the weights under "state_dict";
    # fall back to the top-level mapping for bare state-dict files.
    full_sd = sd.get("state_dict", sd)
    extracted = extract_submodel_state_dict(full_sd, submodel)
    if not extracted:
        print(
            "Warning: no keys starting with '{}.' found in {}.".format(
                submodel, inpath
            ),
            file=sys.stderr,
        )
    torch.save({"state_dict": extracted}, outpath)