"""Rename ``lm_head.*`` tensors to ``language_model.lm_head.*`` in sharded checkpoints.

Run from a directory containing ``model-*.safetensors`` shards. Each shard is
rewritten in place (through a sibling ``.tmp`` file and ``os.replace``) only if
it holds at least one tensor whose name starts with ``lm_head.``. If a
``model.safetensors.index.json`` file is present, the keys of its
``weight_map`` are renamed the same way so the index matches the shards.

Shards and indexes that need no renaming are skipped, so the script can be
re-run after a partial failure.
"""

from pathlib import Path
import json
import os
import sys

from safetensors import safe_open
from safetensors.torch import load_file, save_file
import torch  # noqa: F401  (kept explicit: safetensors.torch requires torch)

PREFIX_OLD = "lm_head."
PREFIX_NEW = "language_model.lm_head."
INDEX_NAME = "model.safetensors.index.json"


def rename_keys(tensor_dict: dict) -> dict:
    """Return a new dict with every ``PREFIX_OLD`` key moved under ``PREFIX_NEW``.

    Values are carried over untouched and insertion order is preserved. Only
    the keys are inspected, so this works for a tensor dict and for an index
    ``weight_map`` (tensor name -> shard file name) alike.

    Raises:
        ValueError: if a renamed key collides with a key that is already
            present (e.g. both ``lm_head.weight`` and
            ``language_model.lm_head.weight`` exist). Keeping only one of the
            two would silently drop a tensor from the checkpoint.
    """
    out = {}
    for name, tensor in tensor_dict.items():
        if name.startswith(PREFIX_OLD):
            name = PREFIX_NEW + name[len(PREFIX_OLD):]
        if name in out:
            raise ValueError(f"Renaming would create duplicate key {name!r}")
        out[name] = tensor
    return out


def _replace_atomically(path: Path, write) -> None:
    """Call ``write(tmp_path)`` on a sibling temp file, then move it over *path*.

    The temp file is removed if writing or replacing fails, so no stray
    ``*.tmp`` file is left behind.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        write(tmp_path)
        os.replace(tmp_path, path)  # atomic on POSIX when on the same filesystem
    except BaseException:
        if tmp_path.exists():
            tmp_path.unlink()
        raise


def process_file(path: Path) -> None:
    """Rename ``lm_head.*`` tensors in one safetensors shard, in place.

    ``load_file`` returns only tensors. The shard's header metadata is
    therefore read separately with ``safe_open`` and written back. Calling
    ``save_file`` without it would produce a file with no metadata, which
    loaders such as transformers may reject.

    Raises:
        ValueError: propagated from ``rename_keys`` on a key collision; the
            shard is left unmodified in that case.
    """
    print(f"Processing {path}")
    data = load_file(str(path), device="cpu")
    renamed = rename_keys(data)
    if renamed.keys() == data.keys():
        print(" No keys needed renaming. Skipping.")
        return

    with safe_open(str(path), framework="pt", device="cpu") as handle:
        metadata = handle.metadata()  # may be None if the header has none

    def _write(tmp_path: Path) -> None:
        save_file(renamed, str(tmp_path), metadata=metadata)

    _replace_atomically(path, _write)
    print(" Updated.")


def process_index(path: Path) -> None:
    """Apply the same renaming to the keys of an index file's ``weight_map``.

    Assumes the Hugging Face sharded-index layout (a JSON object with a
    ``weight_map`` dict of tensor name -> shard file). A missing
    ``weight_map`` is treated as nothing to rename. Other top-level entries
    are written back unchanged.
    """
    print(f"Processing {path}")
    with open(path, encoding="utf-8") as fh:
        index = json.load(fh)
    weight_map = index.get("weight_map", {})
    renamed = rename_keys(weight_map)
    if renamed.keys() == weight_map.keys():
        print(" No keys needed renaming. Skipping.")
        return
    index["weight_map"] = renamed

    def _write(tmp_path: Path) -> None:
        with open(tmp_path, "w", encoding="utf-8") as fh:
            json.dump(index, fh, indent=2)
            fh.write("\n")

    _replace_atomically(path, _write)
    print(" Updated.")


def main() -> None:
    """Rename keys in every ``model-*.safetensors`` shard in the current directory.

    Exits with status 1 if no shard is found. If a ``model.safetensors.index.json``
    exists next to the shards, it is updated after all of them are done.
    """
    files = sorted(Path(".").glob("model-*.safetensors"))
    if not files:
        print("No model-*.safetensors files found.", file=sys.stderr)
        sys.exit(1)
    for f in files:
        process_file(f)

    # The index is updated last. If a shard fails above, the index is left
    # untouched, and re-running the script skips the already-renamed shards.
    index_path = Path(".") / INDEX_NAME
    if index_path.is_file():
        process_index(index_path)


if __name__ == "__main__":
    main()