import os

import torch
from safetensors.torch import load_file, save_file


# Shard whose multi_modal_projector entries are replaced by this script.
wrong_file_path = "model-00001-of-00006.safetensors"

# Shard that supplies the reference multi_modal_projector tensors.
correct_file_path = "/root/autodl-tmp/HF/hub/models--mistralai--Mistral-Small-3.2-24B-Instruct-2506/snapshots/1483f238dc0527ce28022b0c5252515b2552334c/model-00001-of-00010.safetensors"

# Destination of the repaired shard; identical to wrong_file_path, so the
# shard is fixed in place.
output_file_path = "model-00001-of-00006.safetensors"

# Tensors copied from the reference shard into the repaired shard.
keys_to_restore = [
    "multi_modal_projector.linear_1.weight",
    "multi_modal_projector.linear_2.weight",
    "multi_modal_projector.norm.weight",
    "multi_modal_projector.patch_merger.merging_layer.weight",
]

_PROJECTOR_PREFIX = "multi_modal_projector."
# NOTE(review): these suffixes look like the per-layer entries written by a
# weight-compression/quantization pass -- confirm against the producing tool.
_PACKED_SUFFIXES = ("weight_packed", "weight_scale", "weight_shape")


def find_packed_projector_keys(keys):
    """Return the projector entries in *keys* that must be removed.

    A name is selected when it starts with ``multi_modal_projector.`` and
    ends with ``weight_packed``, ``weight_scale`` or ``weight_shape``.

    Args:
        keys: Iterable of tensor names (a weights dict works as well).

    Returns:
        A new list with the selected names, in iteration order.
    """
    return [
        name
        for name in keys
        if name.startswith(_PROJECTOR_PREFIX) and name.endswith(_PACKED_SUFFIXES)
    ]


def repair_projector_weights(weights, reference, restore_keys=None):
    """Drop the packed projector entries and copy in the reference tensors.

    Args:
        weights: Mapping of tensor name to tensor; modified in place.
        reference: Mapping that provides the tensors to restore.
        restore_keys: Names to copy from *reference*; defaults to the
            module-level ``keys_to_restore``.

    Returns:
        The same *weights* mapping, after modification.

    Raises:
        KeyError: If *reference* lacks any of the names to restore. The
            check runs before anything is changed, so *weights* is left
            untouched in that case.
    """
    if restore_keys is None:
        restore_keys = keys_to_restore

    # Validate first: failing half-way would leave the packed entries
    # deleted without their replacements.
    missing = [name for name in restore_keys if name not in reference]
    if missing:
        raise KeyError(
            f"reference weights lack required key(s): {', '.join(missing)}"
        )

    # The list is built before deleting, so the dict is never resized
    # while it is being iterated.
    for key in find_packed_projector_keys(weights):
        print(f"❌ Remove: {key}")
        del weights[key]

    for key in restore_keys:
        print(f"✅ Restore: {key}")
        weights[key] = reference[key]

    return weights


def _save_atomically(tensors, path):
    """Write *tensors* to a sibling temp file, then move it over *path*.

    The destination is also the input shard, so writing it directly would
    destroy the only copy if the save failed part-way.
    NOTE(review): the loaded tensors may additionally still be backed by
    the input file (memory-mapped) -- confirm for the installed
    safetensors version; not truncating that file while saving is safe
    either way.
    NOTE(review): no ``metadata`` is passed, matching the previous
    behavior; check whether downstream loaders expect a ``format`` entry.
    """
    tmp_path = f"{path}.tmp"
    try:
        save_file(tensors, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp file is gone; this only
        # cleans up a partially written file when something failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def main():
    """Load both shards, repair the projector weights and save the result."""
    wrong_weights = load_file(wrong_file_path)
    correct_weights = load_file(correct_file_path)

    repair_projector_weights(wrong_weights, correct_weights)

    _save_atomically(wrong_weights, output_file_path)
    print(f"✅ 修复后的权重保存至: {output_file_path}")


if __name__ == "__main__":
    main()