sdxl_vae / convert_a1111.py
recoilme's picture
asymmetric
30358db
import torch
from diffusers import AutoencoderKL
from safetensors.torch import save_file
# Маппинг ключей Diffusers -> A1111
KEY_MAP = {
# Encoder
"encoder.conv_in": "encoder.conv_in",
"encoder.conv_norm_out": "encoder.norm_out",
"encoder.conv_out": "encoder.conv_out",
# Encoder blocks
"encoder.down_blocks.0.resnets.0": "encoder.down.0.block.0",
"encoder.down_blocks.0.resnets.1": "encoder.down.0.block.1",
"encoder.down_blocks.0.downsamplers.0": "encoder.down.0.downsample",
"encoder.down_blocks.1.resnets.0": "encoder.down.1.block.0",
"encoder.down_blocks.1.resnets.1": "encoder.down.1.block.1",
"encoder.down_blocks.1.downsamplers.0": "encoder.down.1.downsample",
"encoder.down_blocks.2.resnets.0": "encoder.down.2.block.0",
"encoder.down_blocks.2.resnets.1": "encoder.down.2.block.1",
"encoder.down_blocks.2.downsamplers.0": "encoder.down.2.downsample",
"encoder.down_blocks.3.resnets.0": "encoder.down.3.block.0",
"encoder.down_blocks.3.resnets.1": "encoder.down.3.block.1",
# Encoder middle
"encoder.mid_block.resnets.0": "encoder.mid.block_1",
"encoder.mid_block.attentions.0": "encoder.mid.attn_1",
"encoder.mid_block.resnets.1": "encoder.mid.block_2",
# Decoder
"decoder.conv_in": "decoder.conv_in",
"decoder.conv_norm_out": "decoder.norm_out",
"decoder.conv_out": "decoder.conv_out",
# Decoder middle
"decoder.mid_block.resnets.0": "decoder.mid.block_1",
"decoder.mid_block.attentions.0": "decoder.mid.attn_1",
"decoder.mid_block.resnets.1": "decoder.mid.block_2",
# Decoder blocks
"decoder.up_blocks.0.resnets.0": "decoder.up.3.block.0",
"decoder.up_blocks.0.resnets.1": "decoder.up.3.block.1",
"decoder.up_blocks.0.resnets.2": "decoder.up.3.block.2",
"decoder.up_blocks.0.upsamplers.0": "decoder.up.3.upsample",
"decoder.up_blocks.1.resnets.0": "decoder.up.2.block.0",
"decoder.up_blocks.1.resnets.1": "decoder.up.2.block.1",
"decoder.up_blocks.1.resnets.2": "decoder.up.2.block.2",
"decoder.up_blocks.1.upsamplers.0": "decoder.up.2.upsample",
"decoder.up_blocks.2.resnets.0": "decoder.up.1.block.0",
"decoder.up_blocks.2.resnets.1": "decoder.up.1.block.1",
"decoder.up_blocks.2.resnets.2": "decoder.up.1.block.2",
"decoder.up_blocks.2.upsamplers.0": "decoder.up.1.upsample",
"decoder.up_blocks.3.resnets.0": "decoder.up.0.block.0",
"decoder.up_blocks.3.resnets.1": "decoder.up.0.block.1",
"decoder.up_blocks.3.resnets.2": "decoder.up.0.block.2",
}
# Дополнительные замены для конкретных слоев
LAYER_RENAMES = {
"conv_shortcut": "nin_shortcut",
"group_norm": "norm",
"to_q": "q",
"to_k": "k",
"to_v": "v",
"to_out.0": "proj_out",
}
def convert_key(key):
"""Конвертирует ключ из формата Diffusers в формат A1111"""
# Сначала проверяем прямые маппинги
for diffusers_prefix, a1111_prefix in KEY_MAP.items():
if key.startswith(diffusers_prefix):
new_key = key.replace(diffusers_prefix, a1111_prefix, 1)
# Применяем дополнительные замены
for old, new in LAYER_RENAMES.items():
new_key = new_key.replace(old, new)
return new_key
# Если не нашли в маппинге, возвращаем как есть
return key
# Загружаем VAE
vae = AutoencoderKL.from_pretrained("./vae")
state_dict = vae.state_dict()
# Конвертируем ключи
converted_state_dict = {}
for key, value in state_dict.items():
new_key = convert_key(key)
# Проверяем, нужно ли изменить форму для attention весов
if "attn_1" in new_key and any(x in new_key for x in ["q.weight", "k.weight", "v.weight", "proj_out.weight"]):
# Преобразуем из [out_features, in_features] в [out_features, in_features, 1, 1]
if value.dim() == 2:
value = value.unsqueeze(-1).unsqueeze(-1)
converted_state_dict[new_key] = value
# Сохраняем
save_file(converted_state_dict, "sdxl_vae_a1111.safetensors")
print(f"Конвертировано {len(converted_state_dict)} ключей")
print("\nПримеры конвертированных ключей:")
for i, (old, new) in enumerate(zip(list(state_dict.keys())[:5], list(converted_state_dict.keys())[:5])):
print(f"{old} -> {new}")
# Проверяем attention веса
print("\nAttention веса после конвертации:")
for key, value in converted_state_dict.items():
if "attn_1" in key and "weight" in key:
print(f"{key}: {value.shape}")