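"""Merge the tensor-parallel (TP) shards of xai-org/grok-2 into per-layer
safetensors files plus a weight index.

A working sketch rather than a hardened tool: it assumes the whole checkpoint
fits in CPU RAM, that sharded tensors carry a "TP-<rank>" tag in their
filenames, and that the checkpoint was saved across 8 TP ranks.
"""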
import glob
import json
import os

import torch
from safetensors import safe_open
from safetensors.torch import save_file

model_dir = "xai-org/grok-2"
output_dir = "huihui-ai/grok-2"
os.makedirs(output_dir, exist_ok=True)
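
# Step 1: discover the shard files. The script assumes shards are named
# pytorch_model-*.safetensors, with sharded tensors tagged "TP-<rank>" in the
# filename (used for ordering in step 3); anything else in model_dir is ignored.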
print("Collecting safetensors files...", flush=True)
safetensors_files = glob.glob(os.path.join(model_dir, "pytorch_model-*.safetensors"))
if not safetensors_files:
    raise FileNotFoundError(f"No pytorch_model-*.safetensors files found in directory {model_dir}")
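
# Step 2: load every tensor from every shard into CPU RAM, recording which
# file(s) each key appears in. NOTE: this keeps the entire checkpoint resident
# in memory at once, so the host needs RAM on the order of the full unmerged
# model size.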
file_cache = {}
key_to_files = {}
total_size = 0
print("Loading safetensors files...", flush=True)
for file_path in safetensors_files:
    try:
        with safe_open(file_path, framework="pt", device="cpu") as f:
            file_cache[file_path] = {key: f.get_tensor(key) for key in f.keys()}
            for key, tensor in file_cache[file_path].items():
                key_to_files.setdefault(key, []).append(file_path)
                total_size += tensor.element_size() * tensor.nelement()
    except Exception as e:
        print(f"Warning: Failed to load {file_path}: {e}")
print(f"Found {len(key_to_files)} unique keys, total size {total_size / 1e9:.2f} GB", flush=True)
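
# Step 3: reassemble tensor-parallel shards. A key that appears in more than
# one file is treated as TP-sharded and its pieces are concatenated in rank
# order. tp_count = 8 is an assumption about how the upstream checkpoint was
# saved, not something read from the files.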
tp_count = 8
merged_state_dict = {}
print("Merging TP shards...", flush=True)
for key, file_paths in key_to_files.items():
    if len(file_paths) > 1:
        print(f"Merging {key} shards...", flush=True)

        # Order the shards by the TP rank embedded in the filename ("TP-<rank>").
        sorted_paths = sorted(file_paths, key=lambda x: int(x.split("TP-")[1].split(".")[0]) if "TP-" in x else -1)
        tensors = []
        for file_path in sorted_paths[:tp_count]:
            if file_path in file_cache and key in file_cache[file_path]:
                tensors.append(file_cache[file_path][key])
            else:
                print(f"Warning: Key {key} missing in {file_path}")
        if len(tensors) == tp_count:
            try:
                # Expert w1/w3 are column-parallel (concatenate along dim 0);
                # w2 is row-parallel (concatenate along dim 1); everything
                # else defaults to dim 0.
                dim = 1 if "w2.weight" in key else 0
                merged_tensor = torch.cat(tensors, dim=dim)

                # Sanity-check the merged MoE expert shapes against the
                # sizes this script expects.
                if "block_sparse_moe.experts" in key:
                    if "w1.weight" in key or "w3.weight" in key:
                        expected_shape = (16384, 8192)
                    elif "w2.weight" in key:
                        expected_shape = (8192, 16384)
                    else:
                        expected_shape = None
                    if expected_shape is not None and merged_tensor.shape != expected_shape:
                        print(f"Warning: {key} merged shape {merged_tensor.shape} does not match expected {expected_shape}")
                merged_state_dict[key] = merged_tensor
            except Exception as e:
                print(f"Failed to merge {key}: {e}")
                # Fall back to the first shard rather than dropping the key.
                merged_state_dict[key] = tensors[0] if tensors else None
        else:
            print(f"Warning: Found {len(tensors)} shards for {key}, expected {tp_count}, using first tensor")
            merged_state_dict[key] = tensors[0] if tensors else None
    else:
        print(f"Processing {key} ...", flush=True)

        # The key lives in a single shard; copy it through unchanged.
        file_path = file_paths[0]
        if file_path in file_cache and key in file_cache[file_path]:
            merged_state_dict[key] = file_cache[file_path][key]
        else:
            print(f"Warning: Key {key} missing in {file_path}")
            merged_state_dict[key] = None
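
# Step 4: bucket the merged weights by transformer layer; lm_head,
# embed_tokens and the final norm are held back and attached to the last
# layer's file in step 6.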
layer_dicts = {}
special_weights = ["lm_head.weight", "model.embed_tokens.weight", "model.norm.weight"]
last_layer_idx = None
print("Grouping weights by layer...", flush=True)
for key in list(merged_state_dict.keys()):
    if merged_state_dict[key] is None:
        continue
    if key in special_weights:
        continue
    if "model.layers." in key:
        layer_num = int(key.split(".")[2])
        layer_dicts.setdefault(layer_num, {})[key] = merged_state_dict.pop(key)
        last_layer_idx = max(last_layer_idx or 0, layer_num)
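
# Step 5: write one safetensors file per layer, numbered from 00001.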
print("Saving weight files...", flush=True)
for layer_num in sorted(layer_dicts.keys()):
    output_file = os.path.join(output_dir, f"pytorch_model-{layer_num + 1:05d}.safetensors")
    save_file(layer_dicts[layer_num], output_file)
    print(f"Saved layer {layer_num} to {output_file}")
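
# Step 6: fold the shared weights into the last layer's dict and rewrite that
# file (the per-layer copy saved above is overwritten with the extended one).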
if last_layer_idx is None:
    raise RuntimeError("No model.layers.* weights were found; nothing to attach the shared weights to")
last_layer_file = os.path.join(output_dir, f"pytorch_model-{last_layer_idx + 1:05d}.safetensors")
last_layer_dict = layer_dicts.get(last_layer_idx, {})
for key in special_weights:
    if key in merged_state_dict and merged_state_dict[key] is not None:
        last_layer_dict[key] = merged_state_dict[key]
save_file(last_layer_dict, last_layer_file)
print(f"Saved final layer (including lm_head, embed_tokens, norm) to {last_layer_file}", flush=True)
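
# Step 7: write a pytorch_model.bin.index.json so transformers-style loaders
# can map each weight to its file. total_size is the byte count of all tensors
# as loaded; it matches the merged model only when every TP group merged cleanly.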
new_index = {"metadata": {"total_size": total_size}, "weight_map": {}}
for layer_num in sorted(layer_dicts.keys()):
    file_name = f"pytorch_model-{layer_num + 1:05d}.safetensors"
    for key in layer_dicts[layer_num]:
        new_index["weight_map"][key] = file_name
for key in special_weights:
    if key in merged_state_dict and merged_state_dict[key] is not None:
        new_index["weight_map"][key] = f"pytorch_model-{last_layer_idx + 1:05d}.safetensors"

index_path = os.path.join(output_dir, "pytorch_model.bin.index.json")
with open(index_path, "w") as f:
    json.dump(new_index, f, indent=2)
print(f"Saved new index file to {index_path}", flush=True)