Update convert_safetensors.py
Browse files- convert_safetensors.py +128 -128
convert_safetensors.py
CHANGED
@@ -1,129 +1,129 @@
|
|
1 |
-
import os
|
2 |
-
import glob
|
3 |
-
from safetensors import safe_open
|
4 |
-
from safetensors.torch import save_file
|
5 |
-
import torch
|
6 |
-
import json
|
7 |
-
|
8 |
-
# Model directory
|
9 |
-
model_dir = "xai-org/grok-2"
|
10 |
-
output_dir =
|
11 |
-
os.makedirs(output_dir, exist_ok=True)
|
12 |
-
|
13 |
-
# Collect all safetensors files
|
14 |
-
print("Collecting safetensors files...", flush=True)
|
15 |
-
safetensors_files = glob.glob(os.path.join(model_dir, "pytorch_model-*.safetensors"))
|
16 |
-
if not safetensors_files:
|
17 |
-
raise FileNotFoundError(f"No pytorch_model-*.safetensors files found in directory {model_dir}")
|
18 |
-
|
19 |
-
# Load all files into cache and build key-to-file mapping
|
20 |
-
file_cache = {} # file path -> {key: tensor}
|
21 |
-
key_to_files = {} # key -> [file paths]
|
22 |
-
total_size = 0
|
23 |
-
print("Loading safetensors files...", flush=True)
|
24 |
-
for file_path in safetensors_files:
|
25 |
-
try:
|
26 |
-
with safe_open(file_path, framework="pt", device="cpu") as f:
|
27 |
-
file_cache[file_path] = {key: f.get_tensor(key) for key in f.keys()}
|
28 |
-
for key, tensor in file_cache[file_path].items():
|
29 |
-
if key not in key_to_files:
|
30 |
-
key_to_files[key] = []
|
31 |
-
key_to_files[key].append(file_path)
|
32 |
-
total_size += tensor.element_size() * tensor.nelement()
|
33 |
-
except Exception as e:
|
34 |
-
print(f"Warning: Failed to load {file_path}: {e}")
|
35 |
-
print(f"Found {len(key_to_files)} unique keys, total size {total_size / 1e9:.2f} GB", flush=True)
|
36 |
-
|
37 |
-
# Merge TP shards
|
38 |
-
tp_count = 8 # TP=8
|
39 |
-
merged_state_dict = {}
|
40 |
-
print("Merging TP shards...", flush=True)
|
41 |
-
for key, file_paths in key_to_files.items():
|
42 |
-
if len(file_paths) > 1: # TP shards
|
43 |
-
print(f"Merging {key} shards...", flush=True)
|
44 |
-
# Sort by TP number
|
45 |
-
sorted_paths = sorted(file_paths, key=lambda x: int(x.split("TP-")[1].split(".")[0]) if "TP-" in x else -1)
|
46 |
-
tensors = []
|
47 |
-
for file_path in sorted_paths[:tp_count]:
|
48 |
-
if file_path in file_cache and key in file_cache[file_path]:
|
49 |
-
tensors.append(file_cache[file_path][key])
|
50 |
-
else:
|
51 |
-
print(f"Warning: Key {key} missing in {file_path}")
|
52 |
-
if len(tensors) == tp_count:
|
53 |
-
try:
|
54 |
-
# Determine concatenation dimension
|
55 |
-
dim = 0 if "w1.weight" in key or "w3.weight" in key else 1 if "w2.weight" in key else 0
|
56 |
-
merged_tensor = torch.cat(tensors, dim=dim)
|
57 |
-
# Verify shape
|
58 |
-
if "block_sparse_moe.experts" in key:
|
59 |
-
if "w1.weight" in key or "w3.weight" in key:
|
60 |
-
expected_shape = (16384, 8192) # moe_intermediate_size, hidden_size
|
61 |
-
if merged_tensor.shape != expected_shape:
|
62 |
-
print(f"Warning: {key} merged shape {merged_tensor.shape} does not match expected {expected_shape}")
|
63 |
-
elif "w2.weight" in key:
|
64 |
-
expected_shape = (8192, 16384) # hidden_size, moe_intermediate_size
|
65 |
-
if merged_tensor.shape != expected_shape:
|
66 |
-
print(f"Warning: {key} merged shape {merged_tensor.shape} does not match expected {expected_shape}")
|
67 |
-
merged_state_dict[key] = merged_tensor
|
68 |
-
except Exception as e:
|
69 |
-
print(f"Failed to merge {key}: {e}")
|
70 |
-
merged_state_dict[key] = tensors[0] if tensors else None
|
71 |
-
else:
|
72 |
-
print(f"Warning: Found {len(tensors)} shards for {key}, expected {tp_count}, using first tensor")
|
73 |
-
merged_state_dict[key] = tensors[0] if tensors else None
|
74 |
-
else:
|
75 |
-
print(f"Processing {key} ...", flush=True)
|
76 |
-
# Non-TP shard
|
77 |
-
file_path = file_paths[0]
|
78 |
-
if file_path in file_cache and key in file_cache[file_path]:
|
79 |
-
merged_state_dict[key] = file_cache[file_path][key]
|
80 |
-
else:
|
81 |
-
print(f"Warning: Key {key} missing in {file_path}")
|
82 |
-
merged_state_dict[key] = None
|
83 |
-
|
84 |
-
# Group by layer
|
85 |
-
layer_dicts = {}
|
86 |
-
special_weights = ["lm_head.weight", "model.embed_tokens.weight", "model.norm.weight"]
|
87 |
-
last_layer_idx = None
|
88 |
-
print("Grouping weights by layer...", flush=True)
|
89 |
-
for key in list(merged_state_dict.keys()):
|
90 |
-
if merged_state_dict[key] is None:
|
91 |
-
continue
|
92 |
-
if key in special_weights:
|
93 |
-
continue
|
94 |
-
if "model.layers." in key:
|
95 |
-
layer_num = int(key.split(".")[2])
|
96 |
-
if layer_num not in layer_dicts:
|
97 |
-
layer_dicts[layer_num] = {}
|
98 |
-
layer_dicts[layer_num][key] = merged_state_dict.pop(key)
|
99 |
-
last_layer_idx = max(last_layer_idx or 0, layer_num)
|
100 |
-
|
101 |
-
# Save weights for each layer
|
102 |
-
print("Saving weight files...", flush=True)
|
103 |
-
for layer_num in sorted(layer_dicts.keys()):
|
104 |
-
output_file = os.path.join(output_dir, f"pytorch_model-{layer_num + 1:05d}.safetensors")
|
105 |
-
save_file(layer_dicts[layer_num], output_file)
|
106 |
-
print(f"Saved layer {layer_num} to {output_file}")
|
107 |
-
|
108 |
-
# Save final layer (including special weights)
|
109 |
-
last_layer_file = os.path.join(output_dir, f"pytorch_model-{last_layer_idx + 1:05d}.safetensors")
|
110 |
-
last_layer_dict = layer_dicts.get(last_layer_idx, {})
|
111 |
-
for key in special_weights:
|
112 |
-
if key in merged_state_dict and merged_state_dict[key] is not None:
|
113 |
-
last_layer_dict[key] = merged_state_dict[key]
|
114 |
-
save_file(last_layer_dict, last_layer_file)
|
115 |
-
print(f"Saved final layer (including lm_head, embed_tokens, norm) to {last_layer_file}", flush=True)
|
116 |
-
|
117 |
-
# Generate new index
|
118 |
-
new_index = {"metadata": {"total_size": total_size}, "weight_map": {}}
|
119 |
-
for layer_num in sorted(layer_dicts.keys()):
|
120 |
-
file_name = f"pytorch_model-{layer_num + 1:05d}.safetensors"
|
121 |
-
for key in layer_dicts[layer_num]:
|
122 |
-
new_index["weight_map"][key] = file_name
|
123 |
-
for key in special_weights:
|
124 |
-
if key in merged_state_dict and merged_state_dict[key] is not None:
|
125 |
-
new_index["weight_map"][key] = f"pytorch_model-{last_layer_idx + 1:05d}.safetensors"
|
126 |
-
|
127 |
-
with open(os.path.join(output_dir, "pytorch_model.bin.index.json"), "w") as f:
|
128 |
-
json.dump(new_index, f, indent=2)
|
129 |
print(f"Saved new index file to {os.path.join(output_dir, 'pytorch_model.bin.index.json')}", flush=True)
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
from safetensors import safe_open
|
4 |
+
from safetensors.torch import save_file
|
5 |
+
import torch
|
6 |
+
import json
|
7 |
+
|
8 |
+
# Model directory
|
9 |
+
model_dir = "xai-org/grok-2"
|
10 |
+
output_dir = "huihui-ai/grok-2"
|
11 |
+
os.makedirs(output_dir, exist_ok=True)
|
12 |
+
|
13 |
+
# Collect all safetensors files
|
14 |
+
print("Collecting safetensors files...", flush=True)
|
15 |
+
safetensors_files = glob.glob(os.path.join(model_dir, "pytorch_model-*.safetensors"))
|
16 |
+
if not safetensors_files:
|
17 |
+
raise FileNotFoundError(f"No pytorch_model-*.safetensors files found in directory {model_dir}")
|
18 |
+
|
19 |
+
# Load all files into cache and build key-to-file mapping
|
20 |
+
file_cache = {} # file path -> {key: tensor}
|
21 |
+
key_to_files = {} # key -> [file paths]
|
22 |
+
total_size = 0
|
23 |
+
print("Loading safetensors files...", flush=True)
|
24 |
+
for file_path in safetensors_files:
|
25 |
+
try:
|
26 |
+
with safe_open(file_path, framework="pt", device="cpu") as f:
|
27 |
+
file_cache[file_path] = {key: f.get_tensor(key) for key in f.keys()}
|
28 |
+
for key, tensor in file_cache[file_path].items():
|
29 |
+
if key not in key_to_files:
|
30 |
+
key_to_files[key] = []
|
31 |
+
key_to_files[key].append(file_path)
|
32 |
+
total_size += tensor.element_size() * tensor.nelement()
|
33 |
+
except Exception as e:
|
34 |
+
print(f"Warning: Failed to load {file_path}: {e}")
|
35 |
+
print(f"Found {len(key_to_files)} unique keys, total size {total_size / 1e9:.2f} GB", flush=True)
|
36 |
+
|
37 |
+
# Merge TP shards
|
38 |
+
tp_count = 8 # TP=8
|
39 |
+
merged_state_dict = {}
|
40 |
+
print("Merging TP shards...", flush=True)
|
41 |
+
for key, file_paths in key_to_files.items():
|
42 |
+
if len(file_paths) > 1: # TP shards
|
43 |
+
print(f"Merging {key} shards...", flush=True)
|
44 |
+
# Sort by TP number
|
45 |
+
sorted_paths = sorted(file_paths, key=lambda x: int(x.split("TP-")[1].split(".")[0]) if "TP-" in x else -1)
|
46 |
+
tensors = []
|
47 |
+
for file_path in sorted_paths[:tp_count]:
|
48 |
+
if file_path in file_cache and key in file_cache[file_path]:
|
49 |
+
tensors.append(file_cache[file_path][key])
|
50 |
+
else:
|
51 |
+
print(f"Warning: Key {key} missing in {file_path}")
|
52 |
+
if len(tensors) == tp_count:
|
53 |
+
try:
|
54 |
+
# Determine concatenation dimension
|
55 |
+
dim = 0 if "w1.weight" in key or "w3.weight" in key else 1 if "w2.weight" in key else 0
|
56 |
+
merged_tensor = torch.cat(tensors, dim=dim)
|
57 |
+
# Verify shape
|
58 |
+
if "block_sparse_moe.experts" in key:
|
59 |
+
if "w1.weight" in key or "w3.weight" in key:
|
60 |
+
expected_shape = (16384, 8192) # moe_intermediate_size, hidden_size
|
61 |
+
if merged_tensor.shape != expected_shape:
|
62 |
+
print(f"Warning: {key} merged shape {merged_tensor.shape} does not match expected {expected_shape}")
|
63 |
+
elif "w2.weight" in key:
|
64 |
+
expected_shape = (8192, 16384) # hidden_size, moe_intermediate_size
|
65 |
+
if merged_tensor.shape != expected_shape:
|
66 |
+
print(f"Warning: {key} merged shape {merged_tensor.shape} does not match expected {expected_shape}")
|
67 |
+
merged_state_dict[key] = merged_tensor
|
68 |
+
except Exception as e:
|
69 |
+
print(f"Failed to merge {key}: {e}")
|
70 |
+
merged_state_dict[key] = tensors[0] if tensors else None
|
71 |
+
else:
|
72 |
+
print(f"Warning: Found {len(tensors)} shards for {key}, expected {tp_count}, using first tensor")
|
73 |
+
merged_state_dict[key] = tensors[0] if tensors else None
|
74 |
+
else:
|
75 |
+
print(f"Processing {key} ...", flush=True)
|
76 |
+
# Non-TP shard
|
77 |
+
file_path = file_paths[0]
|
78 |
+
if file_path in file_cache and key in file_cache[file_path]:
|
79 |
+
merged_state_dict[key] = file_cache[file_path][key]
|
80 |
+
else:
|
81 |
+
print(f"Warning: Key {key} missing in {file_path}")
|
82 |
+
merged_state_dict[key] = None
|
83 |
+
|
84 |
+
# Group by layer
|
85 |
+
layer_dicts = {}
|
86 |
+
special_weights = ["lm_head.weight", "model.embed_tokens.weight", "model.norm.weight"]
|
87 |
+
last_layer_idx = None
|
88 |
+
print("Grouping weights by layer...", flush=True)
|
89 |
+
for key in list(merged_state_dict.keys()):
|
90 |
+
if merged_state_dict[key] is None:
|
91 |
+
continue
|
92 |
+
if key in special_weights:
|
93 |
+
continue
|
94 |
+
if "model.layers." in key:
|
95 |
+
layer_num = int(key.split(".")[2])
|
96 |
+
if layer_num not in layer_dicts:
|
97 |
+
layer_dicts[layer_num] = {}
|
98 |
+
layer_dicts[layer_num][key] = merged_state_dict.pop(key)
|
99 |
+
last_layer_idx = max(last_layer_idx or 0, layer_num)
|
100 |
+
|
101 |
+
# Save weights for each layer
|
102 |
+
print("Saving weight files...", flush=True)
|
103 |
+
for layer_num in sorted(layer_dicts.keys()):
|
104 |
+
output_file = os.path.join(output_dir, f"pytorch_model-{layer_num + 1:05d}.safetensors")
|
105 |
+
save_file(layer_dicts[layer_num], output_file)
|
106 |
+
print(f"Saved layer {layer_num} to {output_file}")
|
107 |
+
|
108 |
+
# Save final layer (including special weights)
|
109 |
+
last_layer_file = os.path.join(output_dir, f"pytorch_model-{last_layer_idx + 1:05d}.safetensors")
|
110 |
+
last_layer_dict = layer_dicts.get(last_layer_idx, {})
|
111 |
+
for key in special_weights:
|
112 |
+
if key in merged_state_dict and merged_state_dict[key] is not None:
|
113 |
+
last_layer_dict[key] = merged_state_dict[key]
|
114 |
+
save_file(last_layer_dict, last_layer_file)
|
115 |
+
print(f"Saved final layer (including lm_head, embed_tokens, norm) to {last_layer_file}", flush=True)
|
116 |
+
|
117 |
+
# Generate new index
|
118 |
+
new_index = {"metadata": {"total_size": total_size}, "weight_map": {}}
|
119 |
+
for layer_num in sorted(layer_dicts.keys()):
|
120 |
+
file_name = f"pytorch_model-{layer_num + 1:05d}.safetensors"
|
121 |
+
for key in layer_dicts[layer_num]:
|
122 |
+
new_index["weight_map"][key] = file_name
|
123 |
+
for key in special_weights:
|
124 |
+
if key in merged_state_dict and merged_state_dict[key] is not None:
|
125 |
+
new_index["weight_map"][key] = f"pytorch_model-{last_layer_idx + 1:05d}.safetensors"
|
126 |
+
|
127 |
+
with open(os.path.join(output_dir, "pytorch_model.bin.index.json"), "w") as f:
|
128 |
+
json.dump(new_index, f, indent=2)
|
129 |
print(f"Saved new index file to {os.path.join(output_dir, 'pytorch_model.bin.index.json')}", flush=True)
|