"""Merge the tensor-parallel (TP) shard files of a grok-2 checkpoint.

Reads every ``pytorch_model-*.safetensors`` file in ``model_dir``, concatenates
each tensor that appears in more than one file (treating those copies as TP
shards ordered by the ``TP-<n>`` number in the file name), regroups the result
into one safetensors file per decoder layer and writes a matching
``pytorch_model.bin.index.json`` into ``output_dir``. ``lm_head``,
``embed_tokens``, the final norm and every other tensor outside
``model.layers.*`` are stored in the file of the highest-numbered layer.
"""

import glob
import json
import os
import re
from collections import defaultdict

import torch

# Model directory
model_dir = "xai-org/grok-2"
output_dir = "huihui-ai/grok-2"

# Number of tensor-parallel shards a split tensor is expected to have (TP=8).
tp_count = 8

# Non-layer weights that are stored together with the last decoder layer.
special_weights = ["lm_head.weight", "model.embed_tokens.weight", "model.norm.weight"]

# Expected merged shapes of the MoE expert projections; a mismatch only warns.
_EXPECTED_EXPERT_SHAPES = {
    "w1.weight": (16384, 8192),  # moe_intermediate_size, hidden_size
    "w3.weight": (16384, 8192),  # moe_intermediate_size, hidden_size
    "w2.weight": (8192, 16384),  # hidden_size, moe_intermediate_size
}

# Numeric TP rank in file names such as "...-TP-003.safetensors".
_TP_RANK_RE = re.compile(r"TP-(\d+)\.")


def _nbytes(tensor):
    """Return the storage size of ``tensor`` in bytes."""
    return tensor.element_size() * tensor.nelement()


def _shard_name(layer_num):
    """Return the output file name for 0-based decoder layer ``layer_num``."""
    return f"pytorch_model-{layer_num + 1:05d}.safetensors"


def _tp_rank(file_path):
    """Return the TP rank encoded as ``TP-<n>.`` in the file name, or -1 if absent.

    Non-numeric suffixes (e.g. ``TP-common``) also map to -1 instead of raising.
    """
    match = _TP_RANK_RE.search(os.path.basename(file_path))
    return int(match.group(1)) if match else -1


def _concat_dim(key):
    """Return the dimension along which the TP shards of ``key`` are concatenated.

    ``w1``/``w3`` are joined along dim 0 and ``w2`` along dim 1; every other key
    defaults to dim 0.
    NOTE(review): any other row-parallel tensor that is TP-sharded would also
    need dim 1 -- confirm against the checkpoint's key list.
    """
    if "w1.weight" in key or "w3.weight" in key:
        return 0
    return 1 if "w2.weight" in key else 0


def _take(file_cache, file_path, key):
    """Pop ``key`` from the cached tensors of ``file_path``; warn and return None if absent."""
    tensor = file_cache.get(file_path, {}).pop(key, None)
    if tensor is None:
        print(f"Warning: Key {key} missing in {file_path}")
    return tensor


def _warn_on_unexpected_expert_shape(key, tensor):
    """Print a warning if a merged MoE expert projection has an unexpected shape."""
    if "block_sparse_moe.experts" not in key:
        return
    for suffix, expected_shape in _EXPECTED_EXPERT_SHAPES.items():
        if suffix in key:
            if tensor.shape != expected_shape:
                print(f"Warning: {key} merged shape {tensor.shape} does not match expected {expected_shape}")
            return


def _merge_key(key, file_paths, file_cache, tp_count):
    """Concatenate the TP shards of ``key`` (ordered by TP rank) into one tensor.

    Falls back to the first available shard, with a warning, when the shard
    count is not ``tp_count`` or concatenation fails. Returns None if no shard
    was found at all.
    NOTE(review): every key present in several files is treated as TP-split; a
    tensor replicated across files would be concatenated too.
    """
    tensors = []
    for file_path in sorted(file_paths, key=_tp_rank)[:tp_count]:
        tensor = _take(file_cache, file_path, key)
        if tensor is not None:
            tensors.append(tensor)

    if len(tensors) != tp_count:
        print(f"Warning: Found {len(tensors)} shards for {key}, expected {tp_count}, using first tensor")
        return tensors[0] if tensors else None

    try:
        merged_tensor = torch.cat(tensors, dim=_concat_dim(key))
    except (RuntimeError, IndexError) as e:  # shape mismatch / dim out of range
        print(f"Failed to merge {key}: {e}")
        return tensors[0]
    _warn_on_unexpected_expert_shape(key, merged_tensor)
    return merged_tensor


def load_shards(files):
    """Load every tensor of every safetensors file in ``files`` onto the CPU.

    Returns ``(file_cache, key_to_files)``:
    ``file_cache`` maps file path -> {key: tensor};
    ``key_to_files`` maps key -> list of file paths containing it.
    A file that cannot be read is reported and skipped (best effort).
    """
    # Local import keeps the pure merge/grouping helpers importable without safetensors.
    from safetensors import safe_open

    file_cache = {}
    key_to_files = defaultdict(list)
    for file_path in files:
        try:
            with safe_open(file_path, framework="pt", device="cpu") as f:
                tensors = {key: f.get_tensor(key) for key in f.keys()}
        except Exception as e:  # unreadable/corrupt file: warn and continue
            print(f"Warning: Failed to load {file_path}: {e}", flush=True)
            continue
        file_cache[file_path] = tensors
        for key in tensors:
            key_to_files[key].append(file_path)
    return file_cache, dict(key_to_files)


def merge_tp_shards(file_cache, key_to_files, tp_count=tp_count):
    """Build one state dict from the per-file tensors.

    Keys found in several files are concatenated shard by shard; keys found in
    a single file are taken as-is. Consumed tensors are popped from
    ``file_cache`` so the cache no longer keeps a second reference to them.
    A key whose tensor could not be found maps to None.
    """
    merged_state_dict = {}
    for key, file_paths in key_to_files.items():
        if len(file_paths) > 1:
            print(f"Merging {key} shards...", flush=True)
            merged_state_dict[key] = _merge_key(key, file_paths, file_cache, tp_count)
        else:
            print(f"Processing {key} ...", flush=True)
            merged_state_dict[key] = _take(file_cache, file_paths[0], key)
    return merged_state_dict


def plan_output_files(merged_state_dict, special_keys=None):
    """Assign every non-None tensor to an output file.

    ``model.layers.<n>.*`` tensors go to ``pytorch_model-<n+1>.safetensors``.
    ``special_keys`` (default: ``special_weights``) and any other non-layer
    tensors go into the file of the highest layer, or the first file if there
    are no layer tensors. Returns {file name: {key: tensor}} in layer order.
    """
    special = set(special_weights if special_keys is None else special_keys)
    layer_dicts = {}
    trailing = {}  # tensors written together with the last layer
    for key, tensor in merged_state_dict.items():
        if tensor is None:
            continue
        if key not in special and "model.layers." in key:
            layer_num = int(key.split(".")[2])
            layer_dicts.setdefault(layer_num, {})[key] = tensor
            continue
        if key not in special:
            # Previously such keys were silently dropped from the output.
            print(f"Note: non-layer weight {key} will be stored in the final file", flush=True)
        trailing[key] = tensor

    if trailing:
        last_layer_idx = max(layer_dicts, default=0)
        layer_dicts.setdefault(last_layer_idx, {}).update(trailing)
    return {_shard_name(n): layer_dicts[n] for n in sorted(layer_dicts)}


def build_index(shards):
    """Return the index dict (total saved bytes + key -> file map) for ``shards``."""
    weight_map = {key: file_name for file_name, tensors in shards.items() for key in tensors}
    total_size = sum(_nbytes(t) for tensors in shards.values() for t in tensors.values())
    return {"metadata": {"total_size": total_size}, "weight_map": weight_map}


def save_shards(shards, out_dir):
    """Write each planned file in ``shards`` (name -> {key: tensor}) into ``out_dir``."""
    from safetensors.torch import save_file

    for file_name, tensors in shards.items():
        output_file = os.path.join(out_dir, file_name)
        # transformers' safetensors loader inspects the "format" metadata field.
        save_file(tensors, output_file, metadata={"format": "pt"})
        print(f"Saved {len(tensors)} tensors to {output_file}", flush=True)


def main():
    """Run the full load -> merge -> regroup -> save -> index pipeline."""
    os.makedirs(output_dir, exist_ok=True)

    # Collect all safetensors files
    print("Collecting safetensors files...", flush=True)
    safetensors_files = sorted(glob.glob(os.path.join(model_dir, "pytorch_model-*.safetensors")))
    if not safetensors_files:
        raise FileNotFoundError(f"No pytorch_model-*.safetensors files found in directory {model_dir}")

    print("Loading safetensors files...", flush=True)
    file_cache, key_to_files = load_shards(safetensors_files)
    loaded_size = sum(_nbytes(t) for tensors in file_cache.values() for t in tensors.values())
    print(f"Found {len(key_to_files)} unique keys, total size {loaded_size / 1e9:.2f} GB", flush=True)

    print("Merging TP shards...", flush=True)
    merged_state_dict = merge_tp_shards(file_cache, key_to_files)
    del file_cache  # drop shards not consumed by the merge

    print("Grouping weights by layer...", flush=True)
    shards = plan_output_files(merged_state_dict)
    if not shards:
        raise RuntimeError("No tensors were loaded; nothing to save")

    print("Saving weight files...", flush=True)
    save_shards(shards, output_dir)

    # NOTE(review): transformers names the safetensors index "model.safetensors.index.json";
    # the original file name is kept here for compatibility with existing consumers.
    index_path = os.path.join(output_dir, "pytorch_model.bin.index.json")
    with open(index_path, "w", encoding="utf-8") as f:
        json.dump(build_index(shards), f, indent=2)
    print(f"Saved new index file to {index_path}", flush=True)


if __name__ == "__main__":
    main()