huihui-ai/grok-2
This Python script processes and merges sharded weight files (in safetensors format) for a machine learning model, specifically the xai-org/grok-2 model. It is only a simple merge: it contains no inference code and does not show whether the merged model is reasonable or correct.
Open question: do we need a custom MixtralForCausalLM?
The main functionalities include:
- Collecting safetensors files: Locates all `pytorch_model-*.safetensors` files in the specified model directory.
- Loading files into cache: Loads all safetensors files into memory and builds a key-to-file mapping.
- Merging Tensor Parallel (TP) shards: Merges shards for tensor parallelism (TP=8) along specific dimensions and verifies the merged tensor shapes.
- Grouping weights by layer: Organizes weights by model layer, with special weights (e.g., `lm_head.weight`, `model.embed_tokens.weight`, and `model.norm.weight`) handled separately.
- Saving merged weights: Saves the grouped weights as new safetensors files and generates a new index file, `pytorch_model.bin.index.json`.
Features
- Input: Safetensors files in the `xai-org/grok-2` model directory.
- Output: Layer-organized safetensors files and an index file in the `huihui-ai/grok-2` directory.
- Tensor Parallelism Support: Handles TP=8 shards, merging tensors along specific dimensions (`w1.weight` and `w3.weight` along `dim=0`, `w2.weight` along `dim=1`).
- Error Handling: Includes warnings and handling for missing files, shape mismatches, and other exceptions.
- Shape Validation: Verifies shapes for specific weights (e.g., MoE layer weights), ensuring merged tensors match expected shapes (e.g., `(16384, 8192)` or `(8192, 16384)`).
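A minimal sketch of the TP merge and shape check, assuming each weight arrives as a list of shard tensors ordered by rank. The rule table mirrors the dimensions listed above (`w1`/`w3` along `dim=0`, `w2` along `dim=1`); the function name `merge_tp_shards` and the treatment of weights without a rule (assumed replicated, so shard 0 is returned) are assumptions for illustration. On a shape mismatch it warns and keeps the merged tensor, in line with the "warn and proceed" behavior described in the Notes.

```python
import warnings

import torch

# Concatenation dimension per weight-name suffix, as listed in the features above.
MERGE_DIMS = {"w1.weight": 0, "w3.weight": 0, "w2.weight": 1}


def merge_tp_shards(name, shards, expected_shape=None):
    """Concatenate the rank-ordered TP shards of one weight and check its shape."""
    dim = next((d for suffix, d in MERGE_DIMS.items() if name.endswith(suffix)), None)
    if dim is None:
        # Assumption: weights without a merge rule are replicated on every rank.
        return shards[0]
    merged = torch.cat(shards, dim=dim)
    if expected_shape is not None and tuple(merged.shape) != tuple(expected_shape):
        # Warn and keep going rather than aborting the whole conversion.
        warnings.warn(
            f"{name}: merged shape {tuple(merged.shape)}, expected {tuple(expected_shape)}"
        )
    return merged


# Toy example with tp_count = 8 and small shapes.
tp_count = 8
w1 = merge_tp_shards("layers.0.w1.weight", [torch.zeros(2, 4) for _ in range(tp_count)], (16, 4))
w2 = merge_tp_shards("layers.0.w2.weight", [torch.zeros(4, 2) for _ in range(tp_count)], (4, 16))
print(tuple(w1.shape), tuple(w2.shape))  # (16, 4) (4, 16)
```

With the shapes quoted above, eight `w1` shards of shape `(2048, 8192)` concatenated along `dim=0` give `(16384, 8192)`, and eight `w2` shards of shape `(8192, 2048)` along `dim=1` give `(8192, 16384)`.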
Usage
- Install the required Python libraries:
pip install torch safetensors
- Place the script in an environment with access to the `xai-org/grok-2` model directory.
- Run the script:
python convert_safetensors.py
- Output files will be saved in the `huihui-ai/grok-2` directory, including layer-organized safetensors files and an index file.
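How the layer-organized output and its index fit together can be sketched in plain Python. This is an illustration, not the script's actual code: it assumes per-layer weights are named `model.layers.<N>.…` (Hugging Face style), and the helper names `group_by_layer` and `build_index` as well as the output filename pattern are hypothetical. The index uses the common `{"metadata": ..., "weight_map": ...}` layout of sharded Hugging Face checkpoints, where `weight_map` maps each weight name to the file that holds it.

```python
import json
import re
from collections import defaultdict

# Weights kept out of the per-layer files (named in the description above).
SPECIAL_KEYS = {"lm_head.weight", "model.embed_tokens.weight", "model.norm.weight"}

# Assumption: per-layer weights follow the "model.layers.<N>." naming pattern.
LAYER_RE = re.compile(r"^model\.layers\.(\d+)\.")


def group_by_layer(keys):
    """Map each layer index (or "special") to the sorted list of its weight names."""
    groups = defaultdict(list)
    for key in keys:
        match = LAYER_RE.match(key)
        if key in SPECIAL_KEYS or match is None:
            # Special weights, plus anything unrecognized, go into one separate group.
            groups["special"].append(key)
        else:
            groups[int(match.group(1))].append(key)
    return {group: sorted(names) for group, names in groups.items()}


def build_index(groups, total_size):
    """Build an index dict mapping every weight name to its (hypothetical) output file.

    total_size is the combined size in bytes of all tensors, computed by the caller.
    """
    weight_map = {}
    for group, names in groups.items():
        if group == "special":
            filename = "model-special.safetensors"
        else:
            filename = f"model-layer-{group:03d}.safetensors"
        for name in names:
            weight_map[name] = filename
    return {"metadata": {"total_size": total_size}, "weight_map": weight_map}


keys = [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.1.block_sparse_moe.experts.0.w1.weight",
    "model.norm.weight",
    "lm_head.weight",
]
print(json.dumps(build_index(group_by_layer(keys), total_size=0), indent=2))
```

Writing the returned dict with `json.dump` to `pytorch_model.bin.index.json` would produce an index of the kind described above; the actual tensors for each group would be saved separately with `safetensors.torch.save_file`.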
Notes
- Ensure the input directory `xai-org/grok-2` contains valid `pytorch_model-*.safetensors` files.
- The script assumes a tensor parallelism degree of 8 (`tp_count = 8`). Modify the `tp_count` value in the script if needed.
- Memory requirements may be high because all safetensors files are loaded into memory at once; run on a machine with sufficient memory.
- If shards are missing or shapes mismatch, the script prints warnings and attempts to proceed.
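The missing-shard handling and the role of `tp_count` can be sketched as a pre-merge check. This is an illustration under assumptions: it takes a hypothetical mapping from weight name to `{rank: shard}` built from the input files, and it warns about, then skips, any weight that lacks one of the `tp_count` ranks. The function name `complete_shard_sets` is hypothetical, and skipping is only one possible way to "attempt to proceed".

```python
import warnings


def complete_shard_sets(shards_by_name, tp_count=8):
    """Return {name: [shard for ranks 0..tp_count-1]} for weights with every rank present.

    Weights with missing ranks trigger a warning and are left out so the caller can proceed.
    """
    complete = {}
    for name, by_rank in shards_by_name.items():
        missing = [rank for rank in range(tp_count) if rank not in by_rank]
        if missing:
            warnings.warn(f"{name}: missing TP ranks {missing}; skipping")
            continue
        # Order shards by rank so they can be concatenated directly.
        complete[name] = [by_rank[rank] for rank in range(tp_count)]
    return complete
```

Changing `tp_count` here plays the same role as editing `tp_count = 8` in the script: it sets how many ranks each weight must have before it is merged.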
Base model: xai-org/grok-2