MS3.2-24b-Angel_mlx-q8 / rename_tensors.py
heni86's picture
Upload folder using huggingface_hub
33a482f verified
import os
import sys
from pathlib import Path

import torch  # noqa: F401 (needed so tensors load onto CPU)
from safetensors import safe_open
from safetensors.torch import load_file, save_file
# Prefix rewritten by rename_keys(): only tensor names that start with this
# exact string (trailing dot included) are renamed.
PREFIX_OLD = "lm_head."
# Replacement prefix substituted for PREFIX_OLD. Presumably this matches the
# layout a downstream loader expects (lm_head nested under language_model) —
# NOTE(review): confirm against the target loader.
PREFIX_NEW = "language_model.lm_head."
def rename_keys(
    tensor_dict: dict,
    old_prefix: "str | None" = None,
    new_prefix: "str | None" = None,
) -> dict:
    """Return a new dict with renamed keys.

    Every key that starts with ``old_prefix`` has that prefix replaced by
    ``new_prefix``; all other keys are kept as-is. Values are re-referenced,
    not copied, insertion order is preserved, and the input dict is not
    modified.

    Args:
        tensor_dict: Mapping of tensor name -> tensor (values are opaque here).
        old_prefix: Prefix to replace. Defaults to the module-level
            ``PREFIX_OLD`` when omitted.
        new_prefix: Replacement prefix. Defaults to the module-level
            ``PREFIX_NEW`` when omitted.

    Returns:
        A new dict with the renamed keys.

    Raises:
        ValueError: If two entries would end up under the same key (e.g. the
            input already holds the renamed name). Without this check, one
            tensor would be silently overwritten and lost.
    """
    # Resolve defaults lazily so callers may override the prefixes without
    # touching module state.
    if old_prefix is None:
        old_prefix = PREFIX_OLD
    if new_prefix is None:
        new_prefix = PREFIX_NEW

    out = {}
    for name, tensor in tensor_dict.items():
        if name.startswith(old_prefix):
            new_name = new_prefix + name[len(old_prefix):]
        else:
            new_name = name
        # Input keys are unique, so a hit here means a renamed key clashed
        # with another entry (in either iteration order).
        if new_name in out:
            raise ValueError(
                f"Key collision: {name!r} maps to {new_name!r}, "
                "which is already present"
            )
        out[new_name] = tensor
    return out
def process_file(path: Path) -> None:
    """Rewrite one safetensors file in place with renamed tensor keys.

    Loads all tensors from *path* onto the CPU and renames them with
    ``rename_keys``. If nothing changed, the file is left untouched.
    Otherwise the result is written to a sibling ``.safetensors.tmp`` file,
    which then replaces *path* via ``os.replace``.

    The file's header metadata is carried over. ``load_file`` does not
    return it, and ``save_file`` writes none unless ``metadata=`` is passed,
    so without this step the rewrite would drop it.

    Args:
        path: The ``.safetensors`` file to rewrite.

    Raises:
        Whatever loading, renaming or saving raises. On a failed save or
        replace, the temporary file is removed before the exception
        propagates, and the original file is left intact.
    """
    print(f"Processing {path}")
    data = load_file(str(path), device="cpu")
    renamed = rename_keys(data)
    if renamed.keys() == data.keys():
        print(" No keys needed renaming. Skipping.")
        return

    # Preserve the header's __metadata__ (may be None if the file has none).
    with safe_open(str(path), framework="pt") as handle:
        metadata = handle.metadata()

    tmp_path = path.with_suffix(".safetensors.tmp")
    try:
        save_file(renamed, str(tmp_path), metadata=metadata)
        os.replace(tmp_path, path)  # atomic on POSIX (same directory)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        tmp_path.unlink(missing_ok=True)
        raise
    print(" Updated.")
def main() -> None:
    """Run ``process_file`` over every ``model-*.safetensors`` shard in the CWD.

    Shards are visited in sorted path order. If none match, a message is
    printed to stderr and the process exits with status 1. Only files
    matching that glob are processed; any index JSON alongside the shards is
    not updated here.
    """
    shards = list(Path(".").glob("model-*.safetensors"))
    shards.sort()
    if shards:
        for shard in shards:
            process_file(shard)
        return
    print("No model-*.safetensors files found.", file=sys.stderr)
    sys.exit(1)
# Script entry point: perform the rename only when run directly, not on import.
if __name__ == "__main__":
    main()