# grok-2 / convert_safetensors.py
# Author: huihui-ai (Hugging Face) - commit b63cb16 (verified): "Update convert_safetensors.py" - 6.05 kB
import os
import glob
from safetensors import safe_open
from safetensors.torch import save_file
import torch
import json
# Model directory: input checkpoint expected to contain pytorch_model-*.safetensors
# files, and the directory the converted per-layer shards are written to.
model_dir = "xai-org/grok-2"
output_dir = "huihui-ai/grok-2"
# Collect all safetensors files. glob's result order depends on the file system,
# so sort it to make the processing order (and the log output) reproducible.
print("Collecting safetensors files...", flush=True)
safetensors_files = sorted(glob.glob(os.path.join(model_dir, "pytorch_model-*.safetensors")))
if not safetensors_files:
    raise FileNotFoundError(f"No pytorch_model-*.safetensors files found in directory {model_dir}")
# Create the output directory only once there is something to convert, so a
# failed run does not leave an empty output directory behind.
os.makedirs(output_dir, exist_ok=True)
# Load every shard file fully into memory and record which files hold each key.
# A file that cannot be read is reported and skipped (best effort); its keys are
# then simply absent from key_to_files.
file_cache = {}  # shard path -> {tensor name: tensor}
key_to_files = {}  # tensor name -> [shard paths that contain it]
total_size = 0  # bytes across all loaded shards (a key stored in N files counts N times)
print("Loading safetensors files...", flush=True)
for shard_path in safetensors_files:
    try:
        with safe_open(shard_path, framework="pt", device="cpu") as handle:
            loaded = {}
            for name in handle.keys():
                loaded[name] = handle.get_tensor(name)
            file_cache[shard_path] = loaded
        for name, tensor in loaded.items():
            key_to_files.setdefault(name, []).append(shard_path)
            total_size += tensor.nelement() * tensor.element_size()
    except Exception as e:
        print(f"Warning: Failed to load {shard_path}: {e}")
print(f"Found {len(key_to_files)} unique keys, total size {total_size / 1e9:.2f} GB", flush=True)
# Merge TP shards.
# A key stored in more than one file is treated as tensor-parallel (TP) split:
# its shards are ordered by TP rank and concatenated into the full tensor. A key
# stored in exactly one file is taken as-is. Tensors are popped from file_cache
# as they are consumed, so the unmerged shards can be freed instead of staying
# resident next to their merged copy for the rest of the script.
tp_count = 8  # TP=8: number of shards expected per split tensor


def _tp_rank(path):
    """Return the TP rank encoded in a shard file name ("...-TP-<rank>.<ext>"), or -1."""
    name = os.path.basename(path)  # ignore any "TP-" that appears in directory names
    if "TP-" not in name:
        return -1
    return int(name.split("TP-")[1].split(".")[0])


def _concat_dim(key):
    """Return the dimension along which the TP shards of ``key`` are concatenated."""
    if "w1.weight" in key or "w3.weight" in key:
        return 0
    if "w2.weight" in key:
        return 1
    # NOTE(review): every other split tensor is joined along dim 0; a row-parallel
    # weight other than w2 would need dim 1 - confirm against the checkpoint layout.
    return 0


def _expected_expert_shape(key):
    """Return the expected merged shape of an MoE expert weight, or None if unchecked."""
    if "block_sparse_moe.experts" not in key:
        return None
    if "w1.weight" in key or "w3.weight" in key:
        return (16384, 8192)  # moe_intermediate_size, hidden_size
    if "w2.weight" in key:
        return (8192, 16384)  # hidden_size, moe_intermediate_size
    return None


merged_state_dict = {}
print("Merging TP shards...", flush=True)
for key, file_paths in key_to_files.items():
    if len(file_paths) < 2:
        # Non-TP tensor: take it from its single file.
        print(f"Processing {key} ...", flush=True)
        file_path = file_paths[0]
        tensor = file_cache.get(file_path, {}).pop(key, None)
        if tensor is None:
            print(f"Warning: Key {key} missing in {file_path}", flush=True)
        merged_state_dict[key] = tensor
        continue

    # TP-split tensor: gather shards in TP-rank order.
    print(f"Merging {key} shards...", flush=True)
    sorted_paths = sorted(file_paths, key=_tp_rank)
    if len(sorted_paths) > tp_count:
        print(
            f"Warning: {key} found in {len(sorted_paths)} files, "
            f"merging only the first {tp_count} by TP rank",
            flush=True,
        )
    # NOTE(review): a tensor replicated on every rank (rather than split) would also
    # land here and be concatenated tp_count times - confirm none exist.
    tensors = []
    for file_path in sorted_paths[:tp_count]:
        shard = file_cache.get(file_path, {}).pop(key, None)
        if shard is None:
            print(f"Warning: Key {key} missing in {file_path}", flush=True)
        else:
            tensors.append(shard)

    if len(tensors) != tp_count:
        # Best effort: keep one shard (not the full tensor) so the key is not lost.
        print(f"Warning: Found {len(tensors)} shards for {key}, expected {tp_count}, using first tensor", flush=True)
        merged_state_dict[key] = tensors[0] if tensors else None
        continue

    try:
        merged_tensor = torch.cat(tensors, dim=_concat_dim(key))
    except RuntimeError as e:  # e.g. mismatched shard shapes
        print(f"Failed to merge {key}: {e}", flush=True)
        merged_state_dict[key] = tensors[0]
        continue

    expected_shape = _expected_expert_shape(key)
    if expected_shape is not None and merged_tensor.shape != expected_shape:
        print(f"Warning: {key} merged shape {merged_tensor.shape} does not match expected {expected_shape}", flush=True)
    merged_state_dict[key] = merged_tensor
# Group by layer: move every "model.layers.<N>.*" tensor out of merged_state_dict
# into layer_dicts[N]. Special weights and keys whose value is None stay behind in
# merged_state_dict, as does any other key without "model.layers." in its name.
special_weights = ["lm_head.weight", "model.embed_tokens.weight", "model.norm.weight"]
layer_dicts = {}  # layer index -> {tensor name: tensor}
last_layer_idx = None  # highest layer index seen (None when there are no layer keys)
print("Grouping weights by layer...", flush=True)
layer_keys = [
    name
    for name, tensor in merged_state_dict.items()
    if tensor is not None and name not in special_weights and "model.layers." in name
]
for name in layer_keys:
    # Keys look like "model.layers.<N>.<...>", so the index is the third dotted field.
    idx = int(name.split(".")[2])
    layer_dicts.setdefault(idx, {})[name] = merged_state_dict.pop(name)
    last_layer_idx = max(last_layer_idx or 0, idx)
# Save weights for each layer: one shard file per decoder layer, each written
# exactly once. Tensors that belong to no layer (lm_head, embed_tokens, norm and
# any other non-None key left in merged_state_dict after grouping) are stored in
# the last layer's shard, or in pytorch_model-00001 when there are no layers.
print("Saving weight files...", flush=True)
final_layer_idx = max(layer_dicts, default=0)
tail_weights = {key: tensor for key, tensor in merged_state_dict.items() if tensor is not None}
# Keys that are neither layer weights nor special weights would otherwise be
# missing from the output; keep them and report them.
extra_keys = [key for key in tail_weights if key not in special_weights]
if extra_keys:
    print(f"Note: saving {len(extra_keys)} non-layer keys with the final shard: {extra_keys}", flush=True)

shards = {layer_num: layer_dicts[layer_num] for layer_num in sorted(layer_dicts)}
if tail_weights:
    shards.setdefault(final_layer_idx, {}).update(tail_weights)

# Generate the new index while saving, so weight_map and total_size describe
# exactly the tensors that were written.
new_index = {"metadata": {"total_size": 0}, "weight_map": {}}
for layer_num in sorted(shards):
    shard = shards[layer_num]
    file_name = f"pytorch_model-{layer_num + 1:05d}.safetensors"
    output_file = os.path.join(output_dir, file_name)
    # Hugging Face transformers checks safetensors shards for {"format": "pt"} metadata.
    save_file(shard, output_file, metadata={"format": "pt"})
    for key, tensor in shard.items():
        new_index["weight_map"][key] = file_name
        new_index["metadata"]["total_size"] += tensor.element_size() * tensor.nelement()
    if layer_num == final_layer_idx and tail_weights:
        print(f"Saved final layer (including lm_head, embed_tokens, norm) to {output_file}", flush=True)
    else:
        print(f"Saved layer {layer_num} to {output_file}")

# NOTE(review): the shards are safetensors but the index keeps the
# "pytorch_model.bin.index.json" name; confirm the consuming loader expects it.
index_file = os.path.join(output_dir, "pytorch_model.bin.index.json")
with open(index_file, "w", encoding="utf-8") as f:
    json.dump(new_index, f, indent=2)
print(f"Saved new index file to {index_file}", flush=True)