File size: 1,213 Bytes
33a482f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from pathlib import Path
import os
import sys
from safetensors.torch import load_file, save_file
import torch  # noqa: F401  (needed so tensors load onto CPU)

# Tensor-name prefixes used by rename_keys(): any key that starts with
# PREFIX_OLD has that prefix replaced by PREFIX_NEW, e.g.
# "lm_head.weight" -> "language_model.lm_head.weight". Other keys are kept.
PREFIX_OLD = "lm_head."
PREFIX_NEW = "language_model.lm_head."


def rename_keys(
    tensor_dict: dict,
    *,
    old_prefix: "str | None" = None,
    new_prefix: "str | None" = None,
) -> dict:
    """Return a new dict with renamed keys.

    Every key that starts with *old_prefix* (default: ``PREFIX_OLD``) has
    that prefix replaced by *new_prefix* (default: ``PREFIX_NEW``); all
    other keys are copied unchanged. Values are carried over as-is (not
    copied), insertion order is preserved, and *tensor_dict* itself is not
    modified.

    Raises:
        ValueError: if two input keys would map to the same output key
            (e.g. both ``lm_head.weight`` and
            ``language_model.lm_head.weight`` are present). Without this
            check one of the two values would be silently dropped.
    """
    # None defaults resolve to the module constants at call time.
    if old_prefix is None:
        old_prefix = PREFIX_OLD
    if new_prefix is None:
        new_prefix = PREFIX_NEW

    out = {}
    for name, tensor in tensor_dict.items():
        if name.startswith(old_prefix):
            new_name = new_prefix + name[len(old_prefix):]
        else:
            new_name = name
        if new_name in out:
            raise ValueError(
                f"Renaming {name!r} to {new_name!r} collides with an "
                f"existing key; refusing to drop a tensor"
            )
        out[new_name] = tensor
    return out


def process_file(path: Path) -> None:
    """Rename tensor keys in one safetensors file, replacing it in place.

    The file is loaded onto the CPU and its keys are rewritten with
    ``rename_keys``. If no key changes, the file is left untouched.
    Otherwise the result is written to a sibling ``*.safetensors.tmp`` file,
    which then replaces *path* via ``os.replace``.

    The original header metadata (``__metadata__``) is carried over.
    ``load_file`` does not return it, and ``save_file`` would otherwise write
    none, dropping entries such as ``{"format": "pt"}`` that loaders
    (e.g. transformers) inspect.

    If writing or replacing fails, the temporary file is removed and the
    exception is re-raised; the original file is not modified in that case.
    """
    # load_file/save_file do not expose header metadata; safe_open does.
    from safetensors import safe_open

    print(f"Processing {path}")
    data = load_file(str(path), device="cpu")
    renamed = rename_keys(data)

    if renamed.keys() == data.keys():
        print("  No keys needed renaming. Skipping.")
        return

    with safe_open(str(path), framework="pt", device="cpu") as handle:
        metadata = handle.metadata()  # may be None if the file has none

    tmp_path = path.with_suffix(".safetensors.tmp")
    try:
        save_file(renamed, str(tmp_path), metadata=metadata)
        # Atomic rename on POSIX when tmp_path and path share a filesystem
        # (they are siblings here).
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a partially written temp file behind.
        tmp_path.unlink(missing_ok=True)
        raise
    print("  Updated.")


def _update_index(index_path: Path) -> None:
    """Rename the keys of a sharded-checkpoint index's ``weight_map``, if present.

    The shards' tensor names are rewritten by ``process_file``. The index
    maps tensor names to shard files, so it must be renamed the same way or
    it would reference names that no longer exist. Missing index files and
    indexes without a ``weight_map`` dict are skipped. The rewrite goes
    through a temp file plus ``os.replace``, like the shards.
    """
    import json

    if not index_path.is_file():
        return

    print(f"Processing {index_path}")
    with open(index_path, encoding="utf-8") as fh:
        index = json.load(fh)

    weight_map = index.get("weight_map")
    if not isinstance(weight_map, dict):
        print("  No weight_map found. Skipping.")
        return

    renamed = rename_keys(weight_map)
    if renamed.keys() == weight_map.keys():
        print("  No keys needed renaming. Skipping.")
        return

    index["weight_map"] = renamed
    tmp_path = index_path.with_suffix(".json.tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as fh:
            json.dump(index, fh, indent=2)
            fh.write("\n")
        os.replace(tmp_path, index_path)
    except BaseException:
        # Don't leave a partially written temp file behind.
        tmp_path.unlink(missing_ok=True)
        raise
    print("  Updated.")


def main() -> None:
    """Rename tensor keys in every ``model-*.safetensors`` file in the CWD.

    Files are processed in sorted filename order. Afterwards, if a
    ``model.safetensors.index.json`` is present, its ``weight_map`` keys are
    renamed the same way so the index stays consistent with the shards.
    Exits with status 1 when no ``model-*.safetensors`` file is found.
    """
    files = sorted(Path(".").glob("model-*.safetensors"))
    if not files:
        print("No model-*.safetensors files found.", file=sys.stderr)
        sys.exit(1)

    for f in files:
        process_file(f)

    _update_index(Path("model.safetensors.index.json"))


# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()