from pathlib import Path | |
import os | |
import sys | |
from safetensors.torch import load_file, save_file | |
import torch # noqa: F401 (needed so tensors load onto CPU) | |
# Key prefix as it appears in the checkpoint before conversion.
PREFIX_OLD = "lm_head."
# Replacement prefix: the same head, nested under the "language_model." namespace.
PREFIX_NEW = "language_model." + PREFIX_OLD
def rename_keys(
    tensor_dict: dict,
    old_prefix: "str | None" = None,
    new_prefix: "str | None" = None,
) -> dict:
    """Return a new dict in which keys starting with *old_prefix* are re-prefixed.

    Keys that do not start with *old_prefix* are copied unchanged. The input
    mapping is not modified, and the iteration order of the input is kept.

    Args:
        tensor_dict: Mapping from tensor name to tensor (values are passed
            through untouched, so any value type works).
        old_prefix: Prefix to replace. Defaults to the module-level
            ``PREFIX_OLD`` when ``None``.
        new_prefix: Prefix to substitute. Defaults to the module-level
            ``PREFIX_NEW`` when ``None``.

    Returns:
        A new ``dict`` with the renamed keys.

    Raises:
        ValueError: If a renamed key would collide with another key in the
            result, which would otherwise silently drop one of the tensors.
    """
    # Resolve the defaults lazily so explicit prefixes never touch the globals.
    if old_prefix is None:
        old_prefix = PREFIX_OLD
    if new_prefix is None:
        new_prefix = PREFIX_NEW

    out = {}
    for name, tensor in tensor_dict.items():
        if name.startswith(old_prefix):
            target = new_prefix + name[len(old_prefix):]
        else:
            target = name
        # Whichever of the two colliding keys is visited second lands here,
        # so the check is independent of the input's ordering.
        if target in out:
            raise ValueError(
                f"Renaming {name!r} would produce duplicate key {target!r}"
            )
        out[target] = tensor
    return out
def process_file(path: Path) -> None:
    """Rewrite one safetensors file in place with renamed tensor keys.

    The file is loaded onto the CPU, its keys are passed through
    ``rename_keys``, and if any key changed the result is written to a
    sibling ``.safetensors.tmp`` file that then replaces the original.
    The header metadata of the original file is carried over to the
    rewritten file. Files whose keys are unchanged are left untouched.

    Args:
        path: Path of the ``.safetensors`` file to process.

    Raises:
        Exception: Whatever loading, saving or replacing the file raises is
            propagated after the temporary file has been removed.
    """
    # Local import: only this function needs direct access to the file header.
    from safetensors import safe_open

    print(f"Processing {path}")

    # load_file() returns tensors only; read the header metadata separately
    # so it can be written back instead of being dropped on save.
    with safe_open(str(path), framework="pt", device="cpu") as handle:
        metadata = handle.metadata()

    data = load_file(str(path), device="cpu")
    renamed = rename_keys(data)
    if renamed.keys() == data.keys():
        print(" No keys needed renaming. Skipping.")
        return

    tmp_path = path.with_suffix(".safetensors.tmp")
    try:
        save_file(renamed, str(tmp_path), metadata=metadata)
        os.replace(tmp_path, path)  # atomic on POSIX
    except BaseException:
        # Do not leave a partial temp file behind; a cleanup failure must not
        # mask the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    print(" Updated.")
def main() -> None:
    """Process every ``model-*.safetensors`` file in the current directory.

    Matching files are handled one at a time in sorted path order via
    ``process_file``. When nothing matches, a message is printed to stderr
    and the process exits with status 1.
    """
    # NOTE(review): only the shard files themselves are rewritten here. If the
    # checkpoint also has an index file that lists tensor names, it is not
    # updated by this script - confirm whether it needs the same rename.
    shard_paths = sorted(Path(".").glob("model-*.safetensors"))

    if shard_paths:
        for shard in shard_paths:
            process_file(shard)
        return

    print("No model-*.safetensors files found.", file=sys.stderr)
    sys.exit(1)


# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()