File size: 1,213 Bytes
33a482f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
from pathlib import Path
import os
import sys
from safetensors.torch import load_file, save_file
import torch # noqa: F401 (needed so tensors load onto CPU)
# Key prefix that rename_keys() looks for at the start of a tensor name.
PREFIX_OLD = "lm_head."
# Prefix substituted for PREFIX_OLD in the renamed key.
PREFIX_NEW = "language_model.lm_head."


def rename_keys(tensor_dict: dict) -> dict:
    """Return a new dict with ``PREFIX_OLD`` keys moved under ``PREFIX_NEW``.

    Keys that start with ``PREFIX_OLD`` have that prefix replaced by
    ``PREFIX_NEW``; every other key is copied as-is. Values are passed
    through untouched and the input mapping is not modified. Insertion
    order of the input is preserved in the result.

    Args:
        tensor_dict: Mapping of tensor name to tensor.

    Returns:
        A new dict holding the same values under their (possibly renamed)
        keys.

    Raises:
        ValueError: If two input keys end up with the same output name,
            e.g. the input holds both ``lm_head.weight`` and
            ``language_model.lm_head.weight``. Overwriting silently would
            drop one of the tensors.
    """
    out = {}
    for name, tensor in tensor_dict.items():
        new_name = name
        if name.startswith(PREFIX_OLD):
            new_name = PREFIX_NEW + name[len(PREFIX_OLD):]
        if new_name in out:
            # Refuse to lose a tensor: the caller writes `out` back to disk.
            raise ValueError(
                f"Key collision while renaming: {name!r} maps to "
                f"{new_name!r}, which is already present."
            )
        out[new_name] = tensor
    return out
def process_file(path: Path) -> None:
    """Rename keys in one safetensors file, rewriting it in place if needed.

    The file is loaded onto the CPU and passed through ``rename_keys``. If
    no key changed, the file is left untouched. Otherwise the renamed
    tensors are written to a sibling ``*.safetensors.tmp`` file, which then
    replaces the original via ``os.replace``.

    The header metadata of the original file is read and written back with
    the renamed tensors, so the rewrite changes key names only.

    Args:
        path: Path of the ``.safetensors`` file to process.
    """
    # Function-local import: only this function needs the low-level reader
    # that exposes the header metadata.
    from safetensors import safe_open

    print(f"Processing {path}")
    data = load_file(str(path), device="cpu")
    renamed = rename_keys(data)
    if renamed.keys() == data.keys():
        print(" No keys needed renaming. Skipping.")
        return

    # load_file() returns tensors only; fetch the header metadata separately
    # so that save_file() does not silently drop it.
    with safe_open(str(path), framework="pt", device="cpu") as handle:
        metadata = handle.metadata()

    tmp_path = path.with_suffix(".safetensors.tmp")
    try:
        save_file(renamed, str(tmp_path), metadata=metadata)
        os.replace(tmp_path, path)  # atomic on POSIX
    finally:
        # After a successful replace the temp name no longer exists; it is
        # only still present when saving or replacing failed part-way.
        if tmp_path.exists():
            tmp_path.unlink()
    print(" Updated.")
def main() -> None:
    """Process every ``model-*.safetensors`` file in the current directory.

    Files are handled in sorted name order. When nothing matches the
    pattern, a message is written to stderr and the process exits with
    status 1.
    """
    # NOTE(review): only the shard files themselves are rewritten here. If
    # the checkpoint also ships an index file listing tensor names, it is
    # not updated by this script - confirm whether it needs the same rename.
    shards = sorted(Path(".").glob("model-*.safetensors"))
    if shards:
        for shard in shards:
            process_file(shard)
        return

    print("No model-*.safetensors files found.", file=sys.stderr)
    sys.exit(1)


if __name__ == "__main__":
    main()
|