huihui-ai committed
Commit b63cb16 · verified · 1 Parent(s): f13f38b

Update convert_safetensors.py

Files changed (1)
  1. convert_safetensors.py +128 -128
convert_safetensors.py CHANGED
@@ -1,129 +1,129 @@
-import os
-import glob
-from safetensors import safe_open
-from safetensors.torch import save_file
-import torch
-import json
-
-# Model directory
-model_dir = "xai-org/grok-2"
-output_dir = os.path.join(model_dir, "reorganized")
-os.makedirs(output_dir, exist_ok=True)
-
-# Collect all safetensors files
-print("Collecting safetensors files...", flush=True)
-safetensors_files = glob.glob(os.path.join(model_dir, "pytorch_model-*.safetensors"))
-if not safetensors_files:
-    raise FileNotFoundError(f"No pytorch_model-*.safetensors files found in directory {model_dir}")
-
-# Load all files into cache and build key-to-file mapping
-file_cache = {}  # file path -> {key: tensor}
-key_to_files = {}  # key -> [file paths]
-total_size = 0
-print("Loading safetensors files...", flush=True)
-for file_path in safetensors_files:
-    try:
-        with safe_open(file_path, framework="pt", device="cpu") as f:
-            file_cache[file_path] = {key: f.get_tensor(key) for key in f.keys()}
-            for key, tensor in file_cache[file_path].items():
-                if key not in key_to_files:
-                    key_to_files[key] = []
-                key_to_files[key].append(file_path)
-                total_size += tensor.element_size() * tensor.nelement()
-    except Exception as e:
-        print(f"Warning: Failed to load {file_path}: {e}")
-print(f"Found {len(key_to_files)} unique keys, total size {total_size / 1e9:.2f} GB", flush=True)
-
-# Merge TP shards
-tp_count = 8  # TP=8
-merged_state_dict = {}
-print("Merging TP shards...", flush=True)
-for key, file_paths in key_to_files.items():
-    if len(file_paths) > 1:  # TP shards
-        print(f"Merging {key} shards...", flush=True)
-        # Sort by TP number
-        sorted_paths = sorted(file_paths, key=lambda x: int(x.split("TP-")[1].split(".")[0]) if "TP-" in x else -1)
-        tensors = []
-        for file_path in sorted_paths[:tp_count]:
-            if file_path in file_cache and key in file_cache[file_path]:
-                tensors.append(file_cache[file_path][key])
-            else:
-                print(f"Warning: Key {key} missing in {file_path}")
-        if len(tensors) == tp_count:
-            try:
-                # Determine concatenation dimension
-                dim = 0 if "w1.weight" in key or "w3.weight" in key else 1 if "w2.weight" in key else 0
-                merged_tensor = torch.cat(tensors, dim=dim)
-                # Verify shape
-                if "block_sparse_moe.experts" in key:
-                    if "w1.weight" in key or "w3.weight" in key:
-                        expected_shape = (16384, 8192)  # moe_intermediate_size, hidden_size
-                        if merged_tensor.shape != expected_shape:
-                            print(f"Warning: {key} merged shape {merged_tensor.shape} does not match expected {expected_shape}")
-                    elif "w2.weight" in key:
-                        expected_shape = (8192, 16384)  # hidden_size, moe_intermediate_size
-                        if merged_tensor.shape != expected_shape:
-                            print(f"Warning: {key} merged shape {merged_tensor.shape} does not match expected {expected_shape}")
-                merged_state_dict[key] = merged_tensor
-            except Exception as e:
-                print(f"Failed to merge {key}: {e}")
-                merged_state_dict[key] = tensors[0] if tensors else None
-        else:
-            print(f"Warning: Found {len(tensors)} shards for {key}, expected {tp_count}, using first tensor")
-            merged_state_dict[key] = tensors[0] if tensors else None
-    else:
-        print(f"Processing {key} ...", flush=True)
-        # Non-TP shard
-        file_path = file_paths[0]
-        if file_path in file_cache and key in file_cache[file_path]:
-            merged_state_dict[key] = file_cache[file_path][key]
-        else:
-            print(f"Warning: Key {key} missing in {file_path}")
-            merged_state_dict[key] = None
-
-# Group by layer
-layer_dicts = {}
-special_weights = ["lm_head.weight", "model.embed_tokens.weight", "model.norm.weight"]
-last_layer_idx = None
-print("Grouping weights by layer...", flush=True)
-for key in list(merged_state_dict.keys()):
-    if merged_state_dict[key] is None:
-        continue
-    if key in special_weights:
-        continue
-    if "model.layers." in key:
-        layer_num = int(key.split(".")[2])
-        if layer_num not in layer_dicts:
-            layer_dicts[layer_num] = {}
-        layer_dicts[layer_num][key] = merged_state_dict.pop(key)
-        last_layer_idx = max(last_layer_idx or 0, layer_num)
-
-# Save weights for each layer
-print("Saving weight files...", flush=True)
-for layer_num in sorted(layer_dicts.keys()):
-    output_file = os.path.join(output_dir, f"pytorch_model-{layer_num + 1:05d}.safetensors")
-    save_file(layer_dicts[layer_num], output_file)
-    print(f"Saved layer {layer_num} to {output_file}")
-
-# Save final layer (including special weights)
-last_layer_file = os.path.join(output_dir, f"pytorch_model-{last_layer_idx + 1:05d}.safetensors")
-last_layer_dict = layer_dicts.get(last_layer_idx, {})
-for key in special_weights:
-    if key in merged_state_dict and merged_state_dict[key] is not None:
-        last_layer_dict[key] = merged_state_dict[key]
-save_file(last_layer_dict, last_layer_file)
-print(f"Saved final layer (including lm_head, embed_tokens, norm) to {last_layer_file}", flush=True)
-
-# Generate new index
-new_index = {"metadata": {"total_size": total_size}, "weight_map": {}}
-for layer_num in sorted(layer_dicts.keys()):
-    file_name = f"pytorch_model-{layer_num + 1:05d}.safetensors"
-    for key in layer_dicts[layer_num]:
-        new_index["weight_map"][key] = file_name
-for key in special_weights:
-    if key in merged_state_dict and merged_state_dict[key] is not None:
-        new_index["weight_map"][key] = f"pytorch_model-{last_layer_idx + 1:05d}.safetensors"
-
-with open(os.path.join(output_dir, "pytorch_model.bin.index.json"), "w") as f:
-    json.dump(new_index, f, indent=2)
  print(f"Saved new index file to {os.path.join(output_dir, 'pytorch_model.bin.index.json')}", flush=True)
 
+import os
+import glob
+from safetensors import safe_open
+from safetensors.torch import save_file
+import torch
+import json
+
+# Model directory
+model_dir = "xai-org/grok-2"
+output_dir = "huihui-ai/grok-2"
+os.makedirs(output_dir, exist_ok=True)
+
+# Collect all safetensors files
+print("Collecting safetensors files...", flush=True)
+safetensors_files = glob.glob(os.path.join(model_dir, "pytorch_model-*.safetensors"))
+if not safetensors_files:
+    raise FileNotFoundError(f"No pytorch_model-*.safetensors files found in directory {model_dir}")
+
+# Load all files into cache and build key-to-file mapping
+file_cache = {}  # file path -> {key: tensor}
+key_to_files = {}  # key -> [file paths]
+total_size = 0
+print("Loading safetensors files...", flush=True)
+for file_path in safetensors_files:
+    try:
+        with safe_open(file_path, framework="pt", device="cpu") as f:
+            file_cache[file_path] = {key: f.get_tensor(key) for key in f.keys()}
+            for key, tensor in file_cache[file_path].items():
+                if key not in key_to_files:
+                    key_to_files[key] = []
+                key_to_files[key].append(file_path)
+                total_size += tensor.element_size() * tensor.nelement()
+    except Exception as e:
+        print(f"Warning: Failed to load {file_path}: {e}")
+print(f"Found {len(key_to_files)} unique keys, total size {total_size / 1e9:.2f} GB", flush=True)
+
+# Merge TP shards
+tp_count = 8  # TP=8
+merged_state_dict = {}
+print("Merging TP shards...", flush=True)
+for key, file_paths in key_to_files.items():
+    if len(file_paths) > 1:  # TP shards
+        print(f"Merging {key} shards...", flush=True)
+        # Sort by TP number
+        sorted_paths = sorted(file_paths, key=lambda x: int(x.split("TP-")[1].split(".")[0]) if "TP-" in x else -1)
+        tensors = []
+        for file_path in sorted_paths[:tp_count]:
+            if file_path in file_cache and key in file_cache[file_path]:
+                tensors.append(file_cache[file_path][key])
+            else:
+                print(f"Warning: Key {key} missing in {file_path}")
+        if len(tensors) == tp_count:
+            try:
+                # Determine concatenation dimension
+                dim = 0 if "w1.weight" in key or "w3.weight" in key else 1 if "w2.weight" in key else 0
+                merged_tensor = torch.cat(tensors, dim=dim)
+                # Verify shape
+                if "block_sparse_moe.experts" in key:
+                    if "w1.weight" in key or "w3.weight" in key:
+                        expected_shape = (16384, 8192)  # moe_intermediate_size, hidden_size
+                        if merged_tensor.shape != expected_shape:
+                            print(f"Warning: {key} merged shape {merged_tensor.shape} does not match expected {expected_shape}")
+                    elif "w2.weight" in key:
+                        expected_shape = (8192, 16384)  # hidden_size, moe_intermediate_size
+                        if merged_tensor.shape != expected_shape:
+                            print(f"Warning: {key} merged shape {merged_tensor.shape} does not match expected {expected_shape}")
+                merged_state_dict[key] = merged_tensor
+            except Exception as e:
+                print(f"Failed to merge {key}: {e}")
+                merged_state_dict[key] = tensors[0] if tensors else None
+        else:
+            print(f"Warning: Found {len(tensors)} shards for {key}, expected {tp_count}, using first tensor")
+            merged_state_dict[key] = tensors[0] if tensors else None
+    else:
+        print(f"Processing {key} ...", flush=True)
+        # Non-TP shard
+        file_path = file_paths[0]
+        if file_path in file_cache and key in file_cache[file_path]:
+            merged_state_dict[key] = file_cache[file_path][key]
+        else:
+            print(f"Warning: Key {key} missing in {file_path}")
+            merged_state_dict[key] = None
+
+# Group by layer
+layer_dicts = {}
+special_weights = ["lm_head.weight", "model.embed_tokens.weight", "model.norm.weight"]
+last_layer_idx = None
+print("Grouping weights by layer...", flush=True)
+for key in list(merged_state_dict.keys()):
+    if merged_state_dict[key] is None:
+        continue
+    if key in special_weights:
+        continue
+    if "model.layers." in key:
+        layer_num = int(key.split(".")[2])
+        if layer_num not in layer_dicts:
+            layer_dicts[layer_num] = {}
+        layer_dicts[layer_num][key] = merged_state_dict.pop(key)
+        last_layer_idx = max(last_layer_idx or 0, layer_num)
+
+# Save weights for each layer
+print("Saving weight files...", flush=True)
+for layer_num in sorted(layer_dicts.keys()):
+    output_file = os.path.join(output_dir, f"pytorch_model-{layer_num + 1:05d}.safetensors")
+    save_file(layer_dicts[layer_num], output_file)
+    print(f"Saved layer {layer_num} to {output_file}")
+
+# Save final layer (including special weights)
+last_layer_file = os.path.join(output_dir, f"pytorch_model-{last_layer_idx + 1:05d}.safetensors")
+last_layer_dict = layer_dicts.get(last_layer_idx, {})
+for key in special_weights:
+    if key in merged_state_dict and merged_state_dict[key] is not None:
+        last_layer_dict[key] = merged_state_dict[key]
+save_file(last_layer_dict, last_layer_file)
+print(f"Saved final layer (including lm_head, embed_tokens, norm) to {last_layer_file}", flush=True)
+
+# Generate new index
+new_index = {"metadata": {"total_size": total_size}, "weight_map": {}}
+for layer_num in sorted(layer_dicts.keys()):
+    file_name = f"pytorch_model-{layer_num + 1:05d}.safetensors"
+    for key in layer_dicts[layer_num]:
+        new_index["weight_map"][key] = file_name
+for key in special_weights:
+    if key in merged_state_dict and merged_state_dict[key] is not None:
+        new_index["weight_map"][key] = f"pytorch_model-{last_layer_idx + 1:05d}.safetensors"
+
+with open(os.path.join(output_dir, "pytorch_model.bin.index.json"), "w") as f:
+    json.dump(new_index, f, indent=2)
 print(f"Saved new index file to {os.path.join(output_dir, 'pytorch_model.bin.index.json')}", flush=True)