|
import torch |
|
from diffusers import AutoencoderKL |
|
from safetensors.torch import save_file |
|
|
|
|
|
def _build_key_map():
    """Build the Diffusers -> A1111 (LDM) module-prefix table for this VAE layout.

    Layout covered: 4 encoder down stages with 2 resnets each (downsampler on
    every stage but the last), 4 decoder up stages with 3 resnets each
    (upsampler on every stage but the last), plus the in/out convolutions,
    the final norm and the mid block on both sides.

    The two layouts number decoder stages in opposite order, so Diffusers
    ``up_blocks.i`` maps to A1111 ``up.{3 - i}``.

    Returns:
        dict mapping a Diffusers key prefix to the matching A1111 key prefix,
        in a stable insertion order (encoder first, then decoder).
    """
    table = {}

    def add_io_layers(side):
        # Entry/exit convolutions keep their names; the final GroupNorm is renamed.
        table[f"{side}.conv_in"] = f"{side}.conv_in"
        table[f"{side}.conv_norm_out"] = f"{side}.norm_out"
        table[f"{side}.conv_out"] = f"{side}.conv_out"

    def add_mid_block(side):
        table[f"{side}.mid_block.resnets.0"] = f"{side}.mid.block_1"
        table[f"{side}.mid_block.attentions.0"] = f"{side}.mid.attn_1"
        table[f"{side}.mid_block.resnets.1"] = f"{side}.mid.block_2"

    # Encoder: I/O layers, then down stages, then the mid block.
    add_io_layers("encoder")
    for stage in range(4):
        for res in range(2):
            table[f"encoder.down_blocks.{stage}.resnets.{res}"] = (
                f"encoder.down.{stage}.block.{res}"
            )
        if stage < 3:
            table[f"encoder.down_blocks.{stage}.downsamplers.0"] = (
                f"encoder.down.{stage}.downsample"
            )
    add_mid_block("encoder")

    # Decoder: I/O layers, then the mid block, then up stages (index reversed).
    add_io_layers("decoder")
    add_mid_block("decoder")
    for stage in range(4):
        ldm_stage = 3 - stage
        for res in range(3):
            table[f"decoder.up_blocks.{stage}.resnets.{res}"] = (
                f"decoder.up.{ldm_stage}.block.{res}"
            )
        if stage < 3:
            table[f"decoder.up_blocks.{stage}.upsamplers.0"] = (
                f"decoder.up.{ldm_stage}.upsample"
            )

    return table


# Diffusers key prefix -> A1111 (LDM) key prefix.
KEY_MAP = _build_key_map()
|
|
|
|
|
# Sub-module renames applied after the block prefix has been mapped:
# Diffusers name -> A1111 (LDM) name. Insertion order is the order in which
# the renames are tried.
LAYER_RENAMES = dict(
    [
        ("conv_shortcut", "nin_shortcut"),
        ("group_norm", "norm"),
        # Attention projections.
        ("to_q", "q"),
        ("to_k", "k"),
        ("to_v", "v"),
        ("to_out.0", "proj_out"),
    ]
)
|
|
|
def convert_key(key, key_map=None, layer_renames=None):
    """Convert a parameter name from the Diffusers layout to the A1111 (LDM) layout.

    The first entry of ``key_map`` whose Diffusers prefix covers ``key`` is
    used. "Covers" means the key equals the prefix or continues it after a
    ``.``. That prefix is swapped for its A1111 counterpart. ``layer_renames``
    is then applied to the remainder of the key only (never to the mapped
    prefix), matching whole dotted segments rather than arbitrary substrings.

    Args:
        key: Diffusers state-dict key, e.g.
            ``"encoder.mid_block.attentions.0.to_q.weight"``.
        key_map: Diffusers-prefix -> A1111-prefix mapping. Defaults to the
            module-level ``KEY_MAP``.
        layer_renames: Diffusers sub-module name -> A1111 name mapping, tried
            in insertion order. Defaults to the module-level ``LAYER_RENAMES``.

    Returns:
        The converted key, or ``key`` unchanged when no prefix in ``key_map``
        covers it.
    """
    if key_map is None:
        key_map = KEY_MAP
    if layer_renames is None:
        layer_renames = LAYER_RENAMES

    for diffusers_prefix, a1111_prefix in key_map.items():
        # Require a "." boundary so e.g. "resnets.1" cannot claim "resnets.10...".
        if key != diffusers_prefix and not key.startswith(diffusers_prefix + "."):
            continue

        # Pad with a trailing "." so every segment, including the last one, is
        # dot-delimited on both sides. Renames then only hit whole segments
        # (e.g. "to_q" is not rewritten inside "to_qkv").
        rest = key[len(diffusers_prefix):] + "."
        for old, new in layer_renames.items():
            rest = rest.replace(f".{old}.", f".{new}.")
        return a1111_prefix + rest[:-1]

    return key
|
|
|
|
|
# Load the Diffusers-format VAE from ./vae and rename every parameter.
vae = AutoencoderKL.from_pretrained("./vae")

state_dict = vae.state_dict()

# Converted attention projection weights that must become 4-D. The loop below
# reshapes 2-D (Linear-style) weights to (out, in, 1, 1), i.e. 1x1-conv shape.
_ATTN_PROJ_SUFFIXES = (".q.weight", ".k.weight", ".v.weight", ".proj_out.weight")

converted_state_dict = {}
# (diffusers_key, a1111_key) pairs in state_dict order, used for the report below.
renamed_pairs = []
# Keys that convert_key returned unchanged (no KEY_MAP prefix covered them).
unchanged_keys = []

for key, value in state_dict.items():
    new_key = convert_key(key)

    if new_key == key:
        unchanged_keys.append(key)

    if "attn_1" in new_key and new_key.endswith(_ATTN_PROJ_SUFFIXES):
        if value.dim() == 2:
            value = value.unsqueeze(-1).unsqueeze(-1)

    # Two source keys landing on the same target would silently drop a tensor.
    if new_key in converted_state_dict:
        raise ValueError(f"Duplicate converted key {new_key!r} (from {key!r})")

    converted_state_dict[new_key] = value
    renamed_pairs.append((key, new_key))

save_file(converted_state_dict, "sdxl_vae_a1111.safetensors")

print(f"Конвертировано {len(converted_state_dict)} ключей")

print("\nПримеры конвертированных ключей:")
for old, new in renamed_pairs[:5]:
    print(f"{old} -> {new}")

# Surface pass-through keys so an incomplete KEY_MAP does not go unnoticed.
if unchanged_keys:
    print(f"\nКлючи без изменений ({len(unchanged_keys)}):")
    for key in unchanged_keys:
        print(f"  {key}")

print("\nAttention веса после конвертации:")
for key, value in converted_state_dict.items():
    if "attn_1" in key and "weight" in key:
        print(f"{key}: {value.shape}")