File size: 4,825 Bytes
30358db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
from diffusers import AutoencoderKL
from safetensors.torch import save_file

# Diffusers -> A1111 key-prefix mapping.
#
# The table is generated instead of spelled out entry by entry; the insertion
# order matches the hand-written layout (encoder stem, encoder down blocks,
# encoder mid, decoder stem, decoder mid, decoder up blocks).
_STAGE_COUNT = 4  # number of down/up stages enumerated in the mapping

# (Diffusers name, A1111 name) of the layers that sit directly on each side.
_STEM_LAYERS = (
    ("conv_in", "conv_in"),
    ("conv_norm_out", "norm_out"),
    ("conv_out", "conv_out"),
)

# (Diffusers name, A1111 name) of the mid-block members, shared by both sides.
_MID_LAYERS = (
    ("resnets.0", "block_1"),
    ("attentions.0", "attn_1"),
    ("resnets.1", "block_2"),
)


def _build_key_map():
    """Return the ordered Diffusers-prefix -> A1111-prefix dictionary."""
    table = {}
    last_stage = _STAGE_COUNT - 1

    def add_stem(side):
        for source, target in _STEM_LAYERS:
            table[f"{side}.{source}"] = f"{side}.{target}"

    def add_mid(side):
        for source, target in _MID_LAYERS:
            table[f"{side}.mid_block.{source}"] = f"{side}.mid.{target}"

    # Encoder: two resnets per stage, a downsampler on every stage but the last.
    add_stem("encoder")
    for stage in range(_STAGE_COUNT):
        for res in range(2):
            table[f"encoder.down_blocks.{stage}.resnets.{res}"] = (
                f"encoder.down.{stage}.block.{res}"
            )
        if stage < last_stage:
            table[f"encoder.down_blocks.{stage}.downsamplers.0"] = (
                f"encoder.down.{stage}.downsample"
            )
    add_mid("encoder")

    # Decoder: three resnets per stage, an upsampler on every stage but the
    # last. The stage index is mirrored: up_blocks.0 maps to up.3, and so on.
    add_stem("decoder")
    add_mid("decoder")
    for stage in range(_STAGE_COUNT):
        mirrored = last_stage - stage
        for res in range(3):
            table[f"decoder.up_blocks.{stage}.resnets.{res}"] = (
                f"decoder.up.{mirrored}.block.{res}"
            )
        if stage < last_stage:
            table[f"decoder.up_blocks.{stage}.upsamplers.0"] = (
                f"decoder.up.{mirrored}.upsample"
            )

    return table


KEY_MAP = _build_key_map()

# Additional substring renames for specific layers, applied after the prefix
# mapping. Kept as ordered pairs so the replacement order stays explicit.
_LAYER_RENAME_PAIRS = (
    ("conv_shortcut", "nin_shortcut"),
    ("group_norm", "norm"),
    ("to_q", "q"),
    ("to_k", "k"),
    ("to_v", "v"),
    ("to_out.0", "proj_out"),
)
LAYER_RENAMES = dict(_LAYER_RENAME_PAIRS)

def convert_key(key):
    """Convert a state-dict key from the Diffusers format to the A1111 format.

    The first ``KEY_MAP`` prefix (in insertion order) that ``key`` starts with
    is swapped for its A1111 counterpart, then every ``LAYER_RENAMES``
    substring is replaced throughout the resulting key.

    Args:
        key: Parameter name as it appears in the Diffusers state dict.

    Returns:
        The converted key, or ``key`` unchanged when no ``KEY_MAP`` prefix
        matches (in that case ``LAYER_RENAMES`` is not applied either).
    """
    # Look for a direct prefix mapping first.
    matched = next(
        (pair for pair in KEY_MAP.items() if key.startswith(pair[0])),
        None,
    )
    if matched is None:
        # Not covered by the mapping: hand the key back as is.
        return key

    source_prefix, target_prefix = matched
    renamed = target_prefix + key[len(source_prefix):]

    # Apply the additional layer-level renames.
    for diffusers_name, a1111_name in LAYER_RENAMES.items():
        renamed = renamed.replace(diffusers_name, a1111_name)
    return renamed

# Load the VAE from the local "./vae" directory.
vae = AutoencoderKL.from_pretrained("./vae")
state_dict = vae.state_dict()

# Substrings identifying the attention projection weights that get reshaped.
_ATTN_WEIGHT_MARKERS = ("q.weight", "k.weight", "v.weight", "proj_out.weight")
# How many old -> new key pairs to show in the summary below.
_PREVIEW_COUNT = 5

# Convert the keys.
converted_state_dict = {}
preview_pairs = []
for key, value in state_dict.items():
    new_key = convert_key(key)

    # Two source keys landing on the same target would silently drop a tensor
    # from the saved file, so fail loudly before anything is written.
    if new_key in converted_state_dict:
        raise ValueError(
            f"Key collision: {key!r} converts to {new_key!r}, "
            "which is already present in the converted state dict"
        )

    # Attention weights may need a shape change:
    # [out_features, in_features] -> [out_features, in_features, 1, 1].
    is_attn_weight = "attn_1" in new_key and any(
        marker in new_key for marker in _ATTN_WEIGHT_MARKERS
    )
    if is_attn_weight and value.dim() == 2:
        value = value.unsqueeze(-1).unsqueeze(-1)

    converted_state_dict[new_key] = value

    # Record the pair at conversion time so the preview always shows the
    # actual source key for each converted key.
    if len(preview_pairs) < _PREVIEW_COUNT:
        preview_pairs.append((key, new_key))

# Save the converted weights.
save_file(converted_state_dict, "sdxl_vae_a1111.safetensors")

print(f"Конвертировано {len(converted_state_dict)} ключей")
print("\nПримеры конвертированных ключей:")
for old, new in preview_pairs:
    print(f"{old} -> {new}")

# Show the attention weights so their post-conversion shapes can be checked.
print("\nAttention веса после конвертации:")
for key, value in converted_state_dict.items():
    if "attn_1" in key and "weight" in key:
        print(f"{key}: {value.shape}")