recoilme committed
Commit 30358db · 1 Parent(s): 24b99af

asymmetric
README.md CHANGED
@@ -14,12 +14,18 @@ library_name: diffusers
  |----------------------------|-------------|-----------|------------|
  | madebyollin/sdxl-vae-fp16-fix | 3.680e-03 | 25.2100 | 0.1314 |
  | KBlueLeaf/EQ-SDXL-VAE | 3.530e-03 | 25.2827 | 0.1298 |
- | **AiArtLab/sdxl_vae** | <span style="color:red">**3.321e-03**</span> | <span style="color:red">**25.6389**</span> | <span style="color:red">**0.1251**</span> |
+ | **AiArtLab/sdxl_vae** | **3.321e-03** | **25.6389** | **0.1251** |
+
+
+ [![Click it](vae.png)](https://imgsli.com/NDA3OTgz)
+
+
+ ![zoomed](result.png)


  ### Train status, in progress:

- ![result](result.png)
+ We are currently testing whether the SDXL VAE decoder can be improved by increasing its depth (an asymmetric VAE). This increases the model size slightly (roughly 20 percent), but we expect it to improve reconstruction quality without modifying the encoder, so SDXL itself does not need retraining. Unfortunately, our resources are quite limited (we train on consumer GPUs and are currently training three models: SDXL VAE, Simple Diffusion, and Simple VAE), so please be patient: model training is a meticulous and time-consuming process.

  ## VAE Training Process

@@ -47,6 +53,8 @@ library_name: diffusers

  ## Compare

+ https://imgsli.com/NDA3OTgz
+
  https://imgsli.com/NDA3Njgw/2/3

  ## Donations
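To try the asymmetric decoder added in this commit, a minimal reconstruction round-trip mirrors the test cell of create_asymmetric.ipynb below. This is a sketch, not part of the training code: the local folder ./asymmetric_vae (as committed here) and the input file name are assumptions, and the image sides should be divisible by 8 (the notebook crops for this).

import torch
from diffusers import AsymmetricAutoencoderKL
from diffusers.utils import load_image
from torchvision import transforms, utils

device = "cuda"

# Load the asymmetric VAE from this commit (local folder assumed).
vae = AsymmetricAutoencoderKL.from_pretrained(
    "./asymmetric_vae", torch_dtype=torch.float16
).to(device).eval()

# Any RGB image with sides divisible by 8; the path is a placeholder.
image = transforms.ToTensor()(load_image("input.jpg")).unsqueeze(0).to(device, dtype=torch.float16)

with torch.no_grad():
    # Full encode -> decode round-trip; .sample holds the reconstruction.
    recon = vae(image).sample

utils.save_image(recon, "reconstruction.png")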
asymmetric_vae/config.json ADDED
@@ -0,0 +1,45 @@
+ {
+   "_class_name": "AsymmetricAutoencoderKL",
+   "_diffusers_version": "0.34.0",
+   "_name_or_path": "asymmetric_vae_empty",
+   "act_fn": "silu",
+   "block_out_channels": [
+     128,
+     256,
+     512,
+     512
+   ],
+   "down_block_out_channels": [
+     128,
+     256,
+     512,
+     512
+   ],
+   "down_block_types": [
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D"
+   ],
+   "force_upcast": false,
+   "in_channels": 3,
+   "latent_channels": 4,
+   "layers_per_down_block": 2,
+   "layers_per_up_block": 3,
+   "norm_num_groups": 32,
+   "out_channels": 3,
+   "sample_size": 1024,
+   "scaling_factor": 0.13025,
+   "up_block_out_channels": [
+     128,
+     256,
+     512,
+     512
+   ],
+   "up_block_types": [
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D"
+   ]
+ }
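The asymmetry lives entirely in the decoder: layers_per_down_block stays at 2 while layers_per_up_block is 3, which in diffusers gives 4 ResnetBlock2D per UpDecoderBlock2D instead of the usual 3 (see the notebook printout later in this commit), plus the MaskConditionDecoder's condition_encoder. A rough sketch for checking where the extra parameters sit; it assumes a local checkout containing the vae/ and asymmetric_vae/ folders from this commit.

import torch
from diffusers import AutoencoderKL, AsymmetricAutoencoderKL

def count_params(module: torch.nn.Module) -> int:
    # Total number of parameters in a submodule.
    return sum(p.numel() for p in module.parameters())

# Local folders from this commit are assumed ("vae" is the symmetric baseline).
sym = AutoencoderKL.from_pretrained("./vae")
asym = AsymmetricAutoencoderKL.from_pretrained("./asymmetric_vae")

# The encoders should match; the asymmetric decoder carries the extra resnets
# (4 instead of 3 per up block) plus the mask-condition encoder.
print("encoder params (sym / asym):", count_params(sym.encoder), count_params(asym.encoder))
print("decoder params (sym / asym):", count_params(sym.decoder), count_params(asym.decoder))

The decoder growth is consistent with the checkpoint sizes committed here (about 334 MB for vae/ versus about 421 MB for asymmetric_vae/).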
asymmetric_vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8de44e4f21835eb457785a63f7e96c7ddba34b9b812bdeee79012d8bd0dae199
+ size 421473052
asymmetric_vae_new/config.json ADDED
@@ -0,0 +1,45 @@
+ {
+   "_class_name": "AsymmetricAutoencoderKL",
+   "_diffusers_version": "0.34.0",
+   "_name_or_path": "asymmetric_vae",
+   "act_fn": "silu",
+   "block_out_channels": [
+     128,
+     256,
+     512,
+     512
+   ],
+   "down_block_out_channels": [
+     128,
+     256,
+     512,
+     512
+   ],
+   "down_block_types": [
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D"
+   ],
+   "force_upcast": false,
+   "in_channels": 3,
+   "latent_channels": 4,
+   "layers_per_down_block": 2,
+   "layers_per_up_block": 3,
+   "norm_num_groups": 32,
+   "out_channels": 3,
+   "sample_size": 1024,
+   "scaling_factor": 0.13025,
+   "up_block_out_channels": [
+     128,
+     256,
+     512,
+     512
+   ],
+   "up_block_types": [
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D"
+   ]
+ }
asymmetric_vae_new/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2b0689cd2f3a6f81c14a95e1f2a7c4cee6b97b51f34700c5983ee2f28df17ef6
+ size 421473052
convert_a1111.py ADDED
@@ -0,0 +1,117 @@
+ import torch
+ from diffusers import AutoencoderKL
+ from safetensors.torch import save_file
+
+ # Key mapping: Diffusers -> A1111
+ KEY_MAP = {
+     # Encoder
+     "encoder.conv_in": "encoder.conv_in",
+     "encoder.conv_norm_out": "encoder.norm_out",
+     "encoder.conv_out": "encoder.conv_out",
+
+     # Encoder blocks
+     "encoder.down_blocks.0.resnets.0": "encoder.down.0.block.0",
+     "encoder.down_blocks.0.resnets.1": "encoder.down.0.block.1",
+     "encoder.down_blocks.0.downsamplers.0": "encoder.down.0.downsample",
+
+     "encoder.down_blocks.1.resnets.0": "encoder.down.1.block.0",
+     "encoder.down_blocks.1.resnets.1": "encoder.down.1.block.1",
+     "encoder.down_blocks.1.downsamplers.0": "encoder.down.1.downsample",
+
+     "encoder.down_blocks.2.resnets.0": "encoder.down.2.block.0",
+     "encoder.down_blocks.2.resnets.1": "encoder.down.2.block.1",
+     "encoder.down_blocks.2.downsamplers.0": "encoder.down.2.downsample",
+
+     "encoder.down_blocks.3.resnets.0": "encoder.down.3.block.0",
+     "encoder.down_blocks.3.resnets.1": "encoder.down.3.block.1",
+
+     # Encoder middle
+     "encoder.mid_block.resnets.0": "encoder.mid.block_1",
+     "encoder.mid_block.attentions.0": "encoder.mid.attn_1",
+     "encoder.mid_block.resnets.1": "encoder.mid.block_2",
+
+     # Decoder
+     "decoder.conv_in": "decoder.conv_in",
+     "decoder.conv_norm_out": "decoder.norm_out",
+     "decoder.conv_out": "decoder.conv_out",
+
+     # Decoder middle
+     "decoder.mid_block.resnets.0": "decoder.mid.block_1",
+     "decoder.mid_block.attentions.0": "decoder.mid.attn_1",
+     "decoder.mid_block.resnets.1": "decoder.mid.block_2",
+
+     # Decoder blocks
+     "decoder.up_blocks.0.resnets.0": "decoder.up.3.block.0",
+     "decoder.up_blocks.0.resnets.1": "decoder.up.3.block.1",
+     "decoder.up_blocks.0.resnets.2": "decoder.up.3.block.2",
+     "decoder.up_blocks.0.upsamplers.0": "decoder.up.3.upsample",
+
+     "decoder.up_blocks.1.resnets.0": "decoder.up.2.block.0",
+     "decoder.up_blocks.1.resnets.1": "decoder.up.2.block.1",
+     "decoder.up_blocks.1.resnets.2": "decoder.up.2.block.2",
+     "decoder.up_blocks.1.upsamplers.0": "decoder.up.2.upsample",
+
+     "decoder.up_blocks.2.resnets.0": "decoder.up.1.block.0",
+     "decoder.up_blocks.2.resnets.1": "decoder.up.1.block.1",
+     "decoder.up_blocks.2.resnets.2": "decoder.up.1.block.2",
+     "decoder.up_blocks.2.upsamplers.0": "decoder.up.1.upsample",
+
+     "decoder.up_blocks.3.resnets.0": "decoder.up.0.block.0",
+     "decoder.up_blocks.3.resnets.1": "decoder.up.0.block.1",
+     "decoder.up_blocks.3.resnets.2": "decoder.up.0.block.2",
+ }
+
+ # Additional renames for specific layers
+ LAYER_RENAMES = {
+     "conv_shortcut": "nin_shortcut",
+     "group_norm": "norm",
+     "to_q": "q",
+     "to_k": "k",
+     "to_v": "v",
+     "to_out.0": "proj_out",
+ }
+
+ def convert_key(key):
+     """Convert a key from Diffusers format to A1111 format."""
+     # Check the direct prefix mappings first
+     for diffusers_prefix, a1111_prefix in KEY_MAP.items():
+         if key.startswith(diffusers_prefix):
+             new_key = key.replace(diffusers_prefix, a1111_prefix, 1)
+             # Apply the additional layer renames
+             for old, new in LAYER_RENAMES.items():
+                 new_key = new_key.replace(old, new)
+             return new_key
+
+     # Not found in the mapping: return the key unchanged
+     return key
+
+ # Load the VAE
+ vae = AutoencoderKL.from_pretrained("./vae")
+ state_dict = vae.state_dict()
+
+ # Convert the keys
+ converted_state_dict = {}
+ for key, value in state_dict.items():
+     new_key = convert_key(key)
+
+     # Check whether the attention weights need reshaping
+     if "attn_1" in new_key and any(x in new_key for x in ["q.weight", "k.weight", "v.weight", "proj_out.weight"]):
+         # Reshape from [out_features, in_features] to [out_features, in_features, 1, 1]
+         if value.dim() == 2:
+             value = value.unsqueeze(-1).unsqueeze(-1)
+
+     converted_state_dict[new_key] = value
+
+ # Save
+ save_file(converted_state_dict, "sdxl_vae_a1111.safetensors")
+
+ print(f"Converted {len(converted_state_dict)} keys")
+ print("\nSample converted keys:")
+ for i, (old, new) in enumerate(zip(list(state_dict.keys())[:5], list(converted_state_dict.keys())[:5])):
+     print(f"{old} -> {new}")
+
+ # Check the attention weights
+ print("\nAttention weights after conversion:")
+ for key, value in converted_state_dict.items():
+     if "attn_1" in key and "weight" in key:
+         print(f"{key}: {value.shape}")
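The script writes sdxl_vae_a1111.safetensors with the flat LDM-style key names that A1111-type UIs expect for a standalone VAE. A quick sanity check of the exported file might look like the sketch below; it only assumes the file produced by the script above exists locally, and it does not prove that every downstream loader will accept it.

from safetensors.torch import load_file

# Load the converted checkpoint produced by convert_a1111.py.
sd = load_file("sdxl_vae_a1111.safetensors")

# Every key should now use the A1111/LDM-style prefixes produced by KEY_MAP.
leftovers = [k for k in sd if ".down_blocks." in k or ".up_blocks." in k or ".mid_block." in k]
print(f"{len(sd)} tensors, {len(leftovers)} unconverted Diffusers-style keys")

# Mid-block attention projection weights should be 4-D after the reshape step.
for k, v in sd.items():
    if "attn_1" in k and k.endswith("weight") and "norm" not in k:
        print(k, tuple(v.shape))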
convert_a1111_asymm.py ADDED
@@ -0,0 +1,143 @@
+ import torch
+ from diffusers import AsymmetricAutoencoderKL
+ from safetensors.torch import save_file
+
+ # Key mapping: Diffusers -> A1111
+ KEY_MAP = {
+     # Encoder (unchanged)
+     "encoder.conv_in": "encoder.conv_in",
+     "encoder.conv_norm_out": "encoder.norm_out",
+     "encoder.conv_out": "encoder.conv_out",
+
+     # Encoder blocks (unchanged)
+     "encoder.down_blocks.0.resnets.0": "encoder.down.0.block.0",
+     "encoder.down_blocks.0.resnets.1": "encoder.down.0.block.1",
+     "encoder.down_blocks.0.downsamplers.0": "encoder.down.0.downsample",
+
+     "encoder.down_blocks.1.resnets.0": "encoder.down.1.block.0",
+     "encoder.down_blocks.1.resnets.1": "encoder.down.1.block.1",
+     "encoder.down_blocks.1.downsamplers.0": "encoder.down.1.downsample",
+
+     "encoder.down_blocks.2.resnets.0": "encoder.down.2.block.0",
+     "encoder.down_blocks.2.resnets.1": "encoder.down.2.block.1",
+     "encoder.down_blocks.2.downsamplers.0": "encoder.down.2.downsample",
+
+     "encoder.down_blocks.3.resnets.0": "encoder.down.3.block.0",
+     "encoder.down_blocks.3.resnets.1": "encoder.down.3.block.1",
+
+     # Encoder middle
+     "encoder.mid_block.resnets.0": "encoder.mid.block_1",
+     "encoder.mid_block.attentions.0": "encoder.mid.attn_1",
+     "encoder.mid_block.resnets.1": "encoder.mid.block_2",
+
+     # Decoder
+     "decoder.conv_in": "decoder.conv_in",
+     "decoder.conv_norm_out": "decoder.norm_out",
+     "decoder.conv_out": "decoder.conv_out",
+
+     # Decoder middle
+     "decoder.mid_block.resnets.0": "decoder.mid.block_1",
+     "decoder.mid_block.attentions.0": "decoder.mid.attn_1",
+     "decoder.mid_block.resnets.1": "decoder.mid.block_2",
+
+     # Decoder blocks - FIXED for 4 resnets per block
+     # up_blocks.0 -> up.3 (deepest)
+     "decoder.up_blocks.0.resnets.0": "decoder.up.3.block.0",
+     "decoder.up_blocks.0.resnets.1": "decoder.up.3.block.1",
+     "decoder.up_blocks.0.resnets.2": "decoder.up.3.block.2",
+     "decoder.up_blocks.0.resnets.3": "decoder.up.3.block.3",
+     "decoder.up_blocks.0.upsamplers.0": "decoder.up.3.upsample",
+
+     # up_blocks.1 -> up.2
+     "decoder.up_blocks.1.resnets.0": "decoder.up.2.block.0",
+     "decoder.up_blocks.1.resnets.1": "decoder.up.2.block.1",
+     "decoder.up_blocks.1.resnets.2": "decoder.up.2.block.2",
+     "decoder.up_blocks.1.resnets.3": "decoder.up.2.block.3",
+     "decoder.up_blocks.1.upsamplers.0": "decoder.up.2.upsample",
+
+     # up_blocks.2 -> up.1
+     "decoder.up_blocks.2.resnets.0": "decoder.up.1.block.0",
+     "decoder.up_blocks.2.resnets.1": "decoder.up.1.block.1",
+     "decoder.up_blocks.2.resnets.2": "decoder.up.1.block.2",
+     "decoder.up_blocks.2.resnets.3": "decoder.up.1.block.3",
+     "decoder.up_blocks.2.upsamplers.0": "decoder.up.1.upsample",
+
+     # up_blocks.3 -> up.0 (topmost)
+     "decoder.up_blocks.3.resnets.0": "decoder.up.0.block.0",
+     "decoder.up_blocks.3.resnets.1": "decoder.up.0.block.1",
+     "decoder.up_blocks.3.resnets.2": "decoder.up.0.block.2",
+     "decoder.up_blocks.3.resnets.3": "decoder.up.0.block.3",
+ }
+
+ # Additional renames for specific layers
+ LAYER_RENAMES = {
+     "conv_shortcut": "nin_shortcut",
+     "group_norm": "norm",
+     "to_q": "q",
+     "to_k": "k",
+     "to_v": "v",
+     "to_out.0": "proj_out",
+ }
+
+ def convert_key(key):
+     """Convert a key from Diffusers format to A1111 format."""
+     # Skip components specific to the asymmetric VAE
+     if "condition_encoder" in key:
+         return None  # A1111 does not support the condition_encoder
+
+     # Check the direct prefix mappings first
+     for diffusers_prefix, a1111_prefix in KEY_MAP.items():
+         if key.startswith(diffusers_prefix):
+             new_key = key.replace(diffusers_prefix, a1111_prefix, 1)
+             # Apply the additional layer renames
+             for old, new in LAYER_RENAMES.items():
+                 new_key = new_key.replace(old, new)
+             return new_key
+
+     # Not found in the mapping: return the key unchanged
+     return key
+
+ # Load the VAE
+ vae = AsymmetricAutoencoderKL.from_pretrained("./asymmetric_vae")
+ state_dict = vae.state_dict()
+
+ # Convert the keys
+ converted_state_dict = {}
+ skipped_keys = []
+
+ for key, value in state_dict.items():
+     new_key = convert_key(key)
+
+     if new_key is None:
+         skipped_keys.append(key)
+         continue
+
+     # Check whether the attention weights need reshaping
+     if "attn_1" in new_key and any(x in new_key for x in ["q.weight", "k.weight", "v.weight", "proj_out.weight"]):
+         # Reshape from [out_features, in_features] to [out_features, in_features, 1, 1]
+         if value.dim() == 2:
+             value = value.unsqueeze(-1).unsqueeze(-1)
+
+     converted_state_dict[new_key] = value
+
+ # Save
+ save_file(converted_state_dict, "sdxl_vae_asymm_a1111.safetensors")
+
+ print(f"Converted {len(converted_state_dict)} keys")
+ print(f"Skipped {len(skipped_keys)} keys (condition_encoder etc.)")
+
+ if skipped_keys:
+     print("\nSkipped keys:")
+     for key in skipped_keys[:10]:  # show the first 10
+         print(f" - {key}")
+
+ print("\nSample converted keys:")
+ for i, (old, new) in enumerate(zip(list(state_dict.keys())[:5], list(converted_state_dict.keys())[:5])):
+     if old not in skipped_keys:
+         print(f"{old} -> {new}")
+
+ # Check the attention weights
+ print("\nAttention weights after conversion:")
+ for key, value in converted_state_dict.items():
+     if "attn_1" in key and "weight" in key:
+         print(f"{key}: {value.shape}")
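Compared with the symmetric export, this version also maps the fourth resnet (block.3) of every decoder up block and silently drops the condition_encoder, which has no counterpart in the A1111 layout. A hedged sketch for inspecting which keys are new relative to the symmetric export; it assumes both converted files from the two scripts above exist locally, and it says nothing about whether a given UI accepts the deeper decoder.

from safetensors.torch import load_file

# Both files are produced by the two conversion scripts in this commit.
sym = load_file("sdxl_vae_a1111.safetensors")
asym = load_file("sdxl_vae_asymm_a1111.safetensors")

# Keys that exist only in the asymmetric export: the extra fourth resnet
# ("block.3") of each decoder up block introduced by layers_per_up_block=3.
extra = sorted(set(asym) - set(sym))
print(f"{len(extra)} extra keys in the asymmetric export")
for k in extra[:12]:
    print(" ", k)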
create_asymmetric.ipynb ADDED
@@ -0,0 +1,516 @@
Three code cells; sources and stored outputs follow.

Cell 1 — create an empty asymmetric VAE with the SDXL VAE channel layout and save it:

from diffusers.models import AsymmetricAutoencoderKL
import torch

config = {
    "_class_name": "AsymmetricAutoencoderKL",
    "act_fn": "silu",
    "down_block_out_channels": [128, 256, 512, 512],
    "down_block_types": [
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
    ],
    "in_channels": 3,
    "latent_channels": 4,
    "norm_num_groups": 32,
    "out_channels": 3,
    "sample_size": 1024,
    "scaling_factor": 0.13025,
    "shift_factor": 0,
    "up_block_out_channels": [128, 256, 512, 512],
    "up_block_types": [
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
    ],
}

# Create the model
vae = AsymmetricAutoencoderKL(
    act_fn=config["act_fn"],
    down_block_out_channels=config["down_block_out_channels"],
    down_block_types=config["down_block_types"],
    in_channels=config["in_channels"],
    latent_channels=config["latent_channels"],
    norm_num_groups=config["norm_num_groups"],
    out_channels=config["out_channels"],
    sample_size=config["sample_size"],
    scaling_factor=config["scaling_factor"],
    up_block_out_channels=config["up_block_out_channels"],
    up_block_types=config["up_block_types"],
    layers_per_down_block=2,
    layers_per_up_block=3,
)

vae.save_pretrained("asymmetric_vae_empty")
print(vae)

Output: the full AsymmetricAutoencoderKL module tree. The encoder is the standard SDXL VAE encoder: conv_in, four DownEncoderBlock2D with 2 ResnetBlock2D each at 128/256/512/512 channels, a UNetMidBlock2D with one Attention block, and conv_out to 8 channels. The decoder is a MaskConditionDecoder: four UpDecoderBlock2D with 4 ResnetBlock2D each (512/512/256/128 channels), the same UNetMidBlock2D, plus a MaskConditionEncoder (condition_encoder) of five Conv2d layers, followed by conv_norm_out and conv_out back to 3 channels. quant_conv is 8->8 and post_quant_conv is 4->4.

Cell 2 — transfer the weights of the trained symmetric VAE ("vae") into the empty asymmetric model and save it as "asymmetric_vae":

import torch
from diffusers import AsymmetricAutoencoderKL, AutoencoderKL
from tqdm import tqdm
import torch.nn.init as init

def log(message):
    print(message)

def main():
    checkpoint_path_old = "vae"
    checkpoint_path_new = "asymmetric_vae_empty"
    device = "cuda"
    dtype = torch.float32

    # Load the models
    old_unet = AutoencoderKL.from_pretrained(checkpoint_path_old).to(device, dtype=dtype)
    new_unet = AsymmetricAutoencoderKL.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)

    old_state_dict = old_unet.state_dict()
    new_state_dict = new_unet.state_dict()

    transferred_state_dict = {}
    transfer_stats = {
        "transferred": 0,
        "shape_mismatch": 0,
        "skipped": 0,
    }

    transferred_keys = set()

    # Process every key of the old model
    for old_key in tqdm(old_state_dict.keys(), desc="Transferring weights"):
        new_key = old_key

        if new_key in new_state_dict:
            if old_state_dict[old_key].shape == new_state_dict[new_key].shape:
                transferred_state_dict[new_key] = old_state_dict[old_key].clone()
                transferred_keys.add(new_key)
                transfer_stats["transferred"] += 1
            else:
                log(f"✗ Shape mismatch: {old_key} ({old_state_dict[old_key].shape}) -> {new_key} ({new_state_dict[new_key].shape})")
                transfer_stats["shape_mismatch"] += 1
        else:
            log(f"? Key not found in the new model: {old_key} -> {old_state_dict[old_key].shape}")
            transfer_stats["skipped"] += 1

    # Update the new model's state dict with the transferred weights
    new_state_dict.update(transferred_state_dict)

    # Initialize the weights of the new mid block
    #new_state_dict = initialize_mid_block_weights(new_state_dict, device, dtype)

    new_unet.load_state_dict(new_state_dict)
    new_unet.save_pretrained("asymmetric_vae")

    # Collect the keys that were not transferred
    non_transferred_keys = sorted(set(new_state_dict.keys()) - transferred_keys)

    print("Transfer stats:", transfer_stats)
    print("Keys in the new model that were not transferred:")
    for key in non_transferred_keys:
        print(key)

if __name__ == "__main__":
    main()

Output (after a diffusers warning that the config attributes block_out_channels and force_upcast are not expected by AsymmetricAutoencoderKL and are ignored): all 248 keys of the symmetric VAE are transferred, with no shape mismatches and none skipped. The keys left at their fresh initialization are decoder.condition_encoder.layers.0-4 (weight and bias) and, for each decoder up block 0-3, the added fourth resnet decoder.up_blocks.N.resnets.3 (conv1, conv2, norm1, norm2 weights and biases).

Cell 3 — smoke-test the assembled asymmetric VAE by reconstructing an image:

import torch

from torchvision import transforms, utils

import diffusers
from diffusers import AsymmetricAutoencoderKL

from diffusers.utils import load_image

def crop_image_to_nearest_divisible_by_8(img):
    # Check if the image height and width are divisible by 8
    if img.shape[1] % 8 == 0 and img.shape[2] % 8 == 0:
        return img
    else:
        # Calculate the closest lower resolution divisible by 8
        new_height = img.shape[1] - (img.shape[1] % 8)
        new_width = img.shape[2] - (img.shape[2] % 8)

        # Use CenterCrop to crop the image
        transform = transforms.CenterCrop((new_height, new_width))
        img = transform(img).to(torch.float32).clamp(-1, 1)

        return img

to_tensor = transforms.ToTensor()

device = "cuda"
dtype = torch.float16
vae = AsymmetricAutoencoderKL.from_pretrained("asymmetric_vae", torch_dtype=dtype).to(device).eval()

image = load_image("123456789.jpg")

image = crop_image_to_nearest_divisible_by_8(to_tensor(image)).unsqueeze(0).to(device, dtype=dtype)

upscaled_image = vae(image).sample
#vae.config.scaled_factor
# Save the reconstructed image
utils.save_image(upscaled_image, "test.png")
print('ok')

Output: "ok" (with the same diffusers config warning repeated); the reconstruction is saved as test.png.
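Because the transfer copies every encoder and quant_conv weight one-to-one, the assembled asymmetric VAE should encode images to the same latents as the original symmetric model; only decoding changes. A small sketch for checking that invariant, assuming the local vae/ and asymmetric_vae/ folders produced by the cells above:

import torch
from diffusers import AutoencoderKL, AsymmetricAutoencoderKL

# Folder names follow the notebook cells above.
sym = AutoencoderKL.from_pretrained("vae").eval()
asym = AsymmetricAutoencoderKL.from_pretrained("asymmetric_vae").eval()

x = torch.randn(1, 3, 256, 256)  # any RGB-shaped tensor is enough for this check
with torch.no_grad():
    # Compare the deterministic posterior means; they should match because
    # the encoder and quant_conv weights were copied without modification.
    mean_sym = sym.encode(x).latent_dist.mean
    mean_asym = asym.encode(x).latent_dist.mean

print("max |Δ| between latent means:", (mean_sym - mean_asym).abs().max().item())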
samples/sample_0_0.jpg ADDED

Git LFS Details

  • SHA256: fa157903dd5a4118d9c38e32c25c5a02a3eeaddb59d3a1c9d8fe7e9eb57e3f14
  • Pointer size: 130 Bytes
  • Size of remote file: 98 kB
samples/sample_0_1.jpg ADDED

Git LFS Details

  • SHA256: 7cba73cbeeb41f97f6247043e00a5346cf10f6bf67f4ffa4ac8a736c6841a2be
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB
samples/sample_0_2.jpg ADDED

Git LFS Details

  • SHA256: 2cdfd5107c48e41eb4d9475b9360f2c5a98b25509649e37df9eac75065ffbd96
  • Pointer size: 130 Bytes
  • Size of remote file: 93.4 kB
samples/sample_673_0.jpg ADDED

Git LFS Details

  • SHA256: ecb6610fe8119c402581c2181181aea871f7a6f3a211b48c1927cea878d9babb
  • Pointer size: 130 Bytes
  • Size of remote file: 95.5 kB
samples/sample_673_1.jpg ADDED

Git LFS Details

  • SHA256: e370fb4119a38245baad69f7e243506d69e40437878253e91d683ebba1f443af
  • Pointer size: 131 Bytes
  • Size of remote file: 103 kB
samples/sample_673_2.jpg ADDED

Git LFS Details

  • SHA256: ff7edcb0dbc7a36cd3a5a344e4a47b6e13ea1153455c115b738025beb2d45fbc
  • Pointer size: 130 Bytes
  • Size of remote file: 90.3 kB
sdxl_vae_a1111.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ebe642d26e14851e98eb3d06575009e0d1a669704a1c9c8dcf06573d82233a21
+ size 334640988
test.png ADDED

Git LFS Details

  • SHA256: c05b66cf82ccaa12d60e97b3c1898e2c52cd815fb9315473a15b1693a0227799
  • Pointer size: 131 Bytes
  • Size of remote file: 947 kB
train_sdxl_vae.py CHANGED
@@ -23,21 +23,22 @@ import lpips # pip install lpips

  # --------------------------- Parameters ---------------------------
  ds_path = "/workspace/png"
- project = "sdxl_vae"
- batch_size = 1
+ project = "asymmetric_vae"
+ batch_size = 2
  base_learning_rate = 1e-6
  min_learning_rate = 8e-7
  num_epochs = 8
- sample_interval_share = 20
+ sample_interval_share = 10
  use_wandb = True
  save_model = True
  use_decay = True
+ asymmetric = True
  optimizer_type = "adam8bit"
  dtype = torch.float32
  # model_resolution — what is fed to the VAE (the low resolution)
- model_resolution = 768  # formerly `resolution`
+ model_resolution = 512  # formerly `resolution`
  # high_resolution — the actual "high" crop used for metrics and saved samples
- high_resolution = 768  # >>> CHANGED: train on 1024 inputs -> downsample to 512 for the model
+ high_resolution = 512  # >>> CHANGED: train on 1024 inputs -> downsample to 512 for the model
  limit = 0
  save_barrier = 1.03
  warmup_percent = 0.01
@@ -46,9 +47,9 @@ beta2 = 0.97
  eps = 1e-6
  clip_grad_norm = 1.0
  mixed_precision = "no"  # or "fp16"/"bf16" if supported
- gradient_accumulation_steps = 16
+ gradient_accumulation_steps = 8
  generated_folder = "samples"
- save_as = "sdxl_vae_new"
+ save_as = "asymmetric_vae_new"
  perceptual_loss_weight = 0.03  # initial weight value (overwritten at every step)
  num_workers = 0
  device = None  # the accelerator will set the device
@@ -91,8 +92,10 @@ if use_wandb and accelerator.is_main_process:
  })

  # --------------------------- VAE ---------------------------
- vae = AutoencoderKL.from_pretrained(project).to(dtype)
- #vae = AsymmetricAutoencoderKL.from_pretrained(project).to(dtype)
+ if model_resolution == high_resolution and not asymmetric:
+     vae = AutoencoderKL.from_pretrained(project).to(dtype)
+ else:
+     vae = AsymmetricAutoencoderKL.from_pretrained(project).to(dtype)

  # >>> CHANGED: freeze all parameters, then unfreeze mid_block + up_blocks[-2:]
  for p in vae.parameters():
@@ -109,7 +112,7 @@ if not hasattr(decoder, "up_blocks"):

  # >>> CHANGED: unfreeze the last 2 up_blocks (as requested) and the mid_block
  n_up = len(decoder.up_blocks)
- start_idx = 0  #max(0, n_up - 2)
+ start_idx = 0  #max(0, n_up - 2)  # all
  for idx in range(start_idx, n_up):
      block = decoder.up_blocks[idx]
      for name, p in block.named_parameters():
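In short, the updated script points project at the asymmetric checkpoint, halves gradient accumulation, trains at 512 px, and picks AsymmetricAutoencoderKL whenever asymmetric is set (or the model and metric resolutions differ); everything is frozen except the decoder, with all up_blocks unfrozen (start_idx = 0) and, per the script's comment, the mid_block as well. A condensed, hedged sketch of just that load-and-freeze step, not the full training loop; variable names follow the script.

import torch
from diffusers import AutoencoderKL, AsymmetricAutoencoderKL

project = "asymmetric_vae"   # checkpoint folder from this commit
asymmetric = True
model_resolution = high_resolution = 512
dtype = torch.float32

# Pick the VAE class exactly as the updated script does.
if model_resolution == high_resolution and not asymmetric:
    vae = AutoencoderKL.from_pretrained(project).to(dtype)
else:
    vae = AsymmetricAutoencoderKL.from_pretrained(project).to(dtype)

# Freeze everything, then train only the decoder (all up_blocks + mid_block);
# the encoder keeps the original SDXL latent space.
for p in vae.parameters():
    p.requires_grad = False
for block in list(vae.decoder.up_blocks) + [vae.decoder.mid_block]:
    for p in block.parameters():
        p.requires_grad = True

trainable = sum(p.numel() for p in vae.parameters() if p.requires_grad)
print(f"trainable parameters: {trainable:,}")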
vae.png ADDED

Git LFS Details

  • SHA256: 70f3a3c4e9c5e51947ed3529e6e2ab62e513b91b102b11c4e742b11736c14f13
  • Pointer size: 132 Bytes
  • Size of remote file: 2.26 MB
vae/config.json ADDED
@@ -0,0 +1,38 @@
+ {
+   "_class_name": "AutoencoderKL",
+   "_diffusers_version": "0.34.0",
+   "_name_or_path": "sdxl_vae",
+   "act_fn": "silu",
+   "block_out_channels": [
+     128,
+     256,
+     512,
+     512
+   ],
+   "down_block_types": [
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D"
+   ],
+   "force_upcast": false,
+   "in_channels": 3,
+   "latent_channels": 4,
+   "latents_mean": null,
+   "latents_std": null,
+   "layers_per_block": 2,
+   "mid_block_add_attention": true,
+   "norm_num_groups": 32,
+   "out_channels": 3,
+   "sample_size": 512,
+   "scaling_factor": 0.13025,
+   "shift_factor": null,
+   "up_block_types": [
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D"
+   ],
+   "use_post_quant_conv": true,
+   "use_quant_conv": true
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f17d5c9503862b25a273b8874851a99de817dbfae6094432f51381bb1cdd60c8
+ size 334643268