sdxl_vae / convert_a1111_asymm.py
recoilme's picture
asymmetric
30358db
import torch
from diffusers import AsymmetricAutoencoderKL
from safetensors.torch import save_file
# Маппинг ключей Diffusers -> A1111
KEY_MAP = {
# Encoder (без изменений)
"encoder.conv_in": "encoder.conv_in",
"encoder.conv_norm_out": "encoder.norm_out",
"encoder.conv_out": "encoder.conv_out",
# Encoder blocks (без изменений)
"encoder.down_blocks.0.resnets.0": "encoder.down.0.block.0",
"encoder.down_blocks.0.resnets.1": "encoder.down.0.block.1",
"encoder.down_blocks.0.downsamplers.0": "encoder.down.0.downsample",
"encoder.down_blocks.1.resnets.0": "encoder.down.1.block.0",
"encoder.down_blocks.1.resnets.1": "encoder.down.1.block.1",
"encoder.down_blocks.1.downsamplers.0": "encoder.down.1.downsample",
"encoder.down_blocks.2.resnets.0": "encoder.down.2.block.0",
"encoder.down_blocks.2.resnets.1": "encoder.down.2.block.1",
"encoder.down_blocks.2.downsamplers.0": "encoder.down.2.downsample",
"encoder.down_blocks.3.resnets.0": "encoder.down.3.block.0",
"encoder.down_blocks.3.resnets.1": "encoder.down.3.block.1",
# Encoder middle
"encoder.mid_block.resnets.0": "encoder.mid.block_1",
"encoder.mid_block.attentions.0": "encoder.mid.attn_1",
"encoder.mid_block.resnets.1": "encoder.mid.block_2",
# Decoder
"decoder.conv_in": "decoder.conv_in",
"decoder.conv_norm_out": "decoder.norm_out",
"decoder.conv_out": "decoder.conv_out",
# Decoder middle
"decoder.mid_block.resnets.0": "decoder.mid.block_1",
"decoder.mid_block.attentions.0": "decoder.mid.attn_1",
"decoder.mid_block.resnets.1": "decoder.mid.block_2",
# Decoder blocks - ИСПРАВЛЕНО для 4 блоков
# up_blocks.0 -> up.3 (самый глубокий)
"decoder.up_blocks.0.resnets.0": "decoder.up.3.block.0",
"decoder.up_blocks.0.resnets.1": "decoder.up.3.block.1",
"decoder.up_blocks.0.resnets.2": "decoder.up.3.block.2",
"decoder.up_blocks.0.resnets.3": "decoder.up.3.block.3",
"decoder.up_blocks.0.upsamplers.0": "decoder.up.3.upsample",
# up_blocks.1 -> up.2
"decoder.up_blocks.1.resnets.0": "decoder.up.2.block.0",
"decoder.up_blocks.1.resnets.1": "decoder.up.2.block.1",
"decoder.up_blocks.1.resnets.2": "decoder.up.2.block.2",
"decoder.up_blocks.1.resnets.3": "decoder.up.2.block.3",
"decoder.up_blocks.1.upsamplers.0": "decoder.up.2.upsample",
# up_blocks.2 -> up.1
"decoder.up_blocks.2.resnets.0": "decoder.up.1.block.0",
"decoder.up_blocks.2.resnets.1": "decoder.up.1.block.1",
"decoder.up_blocks.2.resnets.2": "decoder.up.1.block.2",
"decoder.up_blocks.2.resnets.3": "decoder.up.1.block.3",
"decoder.up_blocks.2.upsamplers.0": "decoder.up.1.upsample",
# up_blocks.3 -> up.0 (самый верхний)
"decoder.up_blocks.3.resnets.0": "decoder.up.0.block.0",
"decoder.up_blocks.3.resnets.1": "decoder.up.0.block.1",
"decoder.up_blocks.3.resnets.2": "decoder.up.0.block.2",
"decoder.up_blocks.3.resnets.3": "decoder.up.0.block.3",
}
# Дополнительные замены для конкретных слоев
LAYER_RENAMES = {
"conv_shortcut": "nin_shortcut",
"group_norm": "norm",
"to_q": "q",
"to_k": "k",
"to_v": "v",
"to_out.0": "proj_out",
}
def convert_key(key):
"""Конвертирует ключ из формата Diffusers в формат A1111"""
# Пропускаем специфичные для AsymmetricVAE компоненты
if "condition_encoder" in key:
return None # A1111 не поддерживает condition_encoder
# Сначала проверяем прямые маппинги
for diffusers_prefix, a1111_prefix in KEY_MAP.items():
if key.startswith(diffusers_prefix):
new_key = key.replace(diffusers_prefix, a1111_prefix, 1)
# Применяем дополнительные замены
for old, new in LAYER_RENAMES.items():
new_key = new_key.replace(old, new)
return new_key
# Если не нашли в маппинге, возвращаем как есть
return key
# Загружаем VAE
vae = AsymmetricAutoencoderKL.from_pretrained("./asymmetric_vae")
state_dict = vae.state_dict()
# Конвертируем ключи
converted_state_dict = {}
skipped_keys = []
for key, value in state_dict.items():
new_key = convert_key(key)
if new_key is None:
skipped_keys.append(key)
continue
# Проверяем, нужно ли изменить форму для attention весов
if "attn_1" in new_key and any(x in new_key for x in ["q.weight", "k.weight", "v.weight", "proj_out.weight"]):
# Преобразуем из [out_features, in_features] в [out_features, in_features, 1, 1]
if value.dim() == 2:
value = value.unsqueeze(-1).unsqueeze(-1)
converted_state_dict[new_key] = value
# Сохраняем
save_file(converted_state_dict, "sdxl_vae_asymm_a1111.safetensors")
print(f"Конвертировано {len(converted_state_dict)} ключей")
print(f"Пропущено {len(skipped_keys)} ключей (condition_encoder и др.)")
if skipped_keys:
print("\nПропущенные ключи:")
for key in skipped_keys[:10]: # Показываем первые 10
print(f" - {key}")
print("\nПримеры конвертированных ключей:")
for i, (old, new) in enumerate(zip(list(state_dict.keys())[:5], list(converted_state_dict.keys())[:5])):
if old not in skipped_keys:
print(f"{old} -> {new}")
# Проверяем attention веса
print("\nAttention веса после конвертации:")
for key, value in converted_state_dict.items():
if "attn_1" in key and "weight" in key:
print(f"{key}: {value.shape}")