recoilme committed on
Commit bf7a967 · 1 Parent(s): f8cf9ec
README.md CHANGED
@@ -9,7 +9,7 @@ library_name: diffusers
 ---
 # SDXL-VAE finetuned
 
-
+Imagenet eval (256px)
 | Model | MSE | PSNR | LPIPS |
 |----------------------------|-------------|-----------|------------|
 | madebyollin/sdxl-vae-fp16-fix | 3.680e-03 | 25.2100 | 0.1314 |
@@ -17,6 +17,16 @@ library_name: diffusers
 | **AiArtLab/sdxl_vae** | **3.321e-03** | **25.6389** | **0.1251** |
 
 
+Alchemist eval (512px)
+
+| Model | MSE | PSNR | LPIPS |
+|--------------------------------|------------|------------|------------|
+| madebyollin/sdxl-vae-fp16 | 100% | 100% | 100% |
+| KBlueLeaf/EQ-SDXL-VAE | 107.8% | 100.1% | 95.5% |
+| AiArtLab/sdxl_vae | 112.3% | 101.8% | 106.6% |
+| AiArtLab/sdxl_vae_asym | 111.7% | 101.1% | 89.4% |
+| FLUX.1-schnell-vae | 324.0% | 119.8% | 292.0% |
+
 [![Click it](vae.png)](https://imgsli.com/NDA3OTgz)
 
 
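The Alchemist percentages are relative to the first row, with the lower-is-better metrics (MSE, LPIPS) inverted so that anything above 100% beats the madebyollin baseline (the exact computation is in eval_alchemist.py below). A minimal sketch of that convention, checked against the ImageNet numbers above:

```python
# Sketch of the relative-score convention used in the Alchemist table:
# lower-is-better metrics (MSE, LPIPS) use baseline/model, higher-is-better
# (PSNR) uses model/baseline, so >100% always means "better than baseline".
def relative_pct(baseline: float, value: float, lower_is_better: bool = True) -> float:
    return (baseline / value if lower_is_better else value / baseline) * 100.0

print(f"{relative_pct(3.680e-3, 3.321e-3):.1f}%")       # MSE   -> ~110.8%
print(f"{relative_pct(25.2100, 25.6389, False):.1f}%")  # PSNR  -> ~101.7%
print(f"{relative_pct(0.1314, 0.1251):.1f}%")           # LPIPS -> ~105.0%
```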
 
asymmetric_vae/config.json CHANGED
@@ -3,12 +3,6 @@
   "_diffusers_version": "0.34.0",
   "_name_or_path": "asymmetric_vae_empty",
   "act_fn": "silu",
-  "block_out_channels": [
-    128,
-    256,
-    512,
-    512
-  ],
   "down_block_out_channels": [
     128,
     256,
@@ -21,7 +15,6 @@
     "DownEncoderBlock2D",
     "DownEncoderBlock2D"
   ],
-  "force_upcast": false,
   "in_channels": 3,
   "latent_channels": 4,
   "layers_per_down_block": 2,
asymmetric_vae/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:8de44e4f21835eb457785a63f7e96c7ddba34b9b812bdeee79012d8bd0dae199
+oid sha256:ded3c30322578e3371f32a58423b6a3be3a2c3b81d3eb5d35433772be796a1ba
 size 421473052
asymmetric_vae_new/config.json CHANGED
@@ -1,6 +1,6 @@
 {
   "_class_name": "AsymmetricAutoencoderKL",
-  "_diffusers_version": "0.34.0",
+  "_diffusers_version": "0.35.0.dev0",
   "_name_or_path": "asymmetric_vae",
   "act_fn": "silu",
   "block_out_channels": [
asymmetric_vae_new/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ded3c30322578e3371f32a58423b6a3be3a2c3b81d3eb5d35433772be796a1ba
+oid sha256:df9380b1e8d8b1a36b3d0f9501a854717a911ae9b8d2aebe18809a6eefa9318b
 size 421473052
down.sh ADDED
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+TARGET_DIR="/workspace/d23"
+mkdir -p "$TARGET_DIR"
+
+BASE_URL="https://huggingface.co/datasets/AI-Art-Collab/dtasettar23/resolve/main/d23.tar."
+
+(
+    # Use `set -e` inside the subshell so it exits on the first curl error
+    set -e
+    # Try 'a' through 'z' for the first suffix character
+    for c1 in {a..z}; do
+        # Try 'a' through 'z' for the second suffix character
+        for c2 in {a..z}; do
+            suffix="${c1}${c2}"
+            url="${BASE_URL}${suffix}"
+            echo "Fetching: $url" >&2
+            # Download one archive part; --fail makes curl exit with an error if the file does not exist.
+            curl -LsS --fail "$url"
+        done
+    done
+) 2>/dev/null | tar -xv -C "$TARGET_DIR" --wildcards '*.png'
+# (1) subshell streams the parts  (2) discard stderr  (3) stream-extract only *.png
+
+echo "Extraction of PNG files finished. Check $TARGET_DIR"
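The parts are plain byte-slices of one tar archive, so streaming them in suffix order into `tar` reassembles it on the fly. A rough Python equivalent for reference, assuming the same d23.tar.aa…d23.tar.zz naming; it buffers the whole archive in memory, so it only suits small datasets:

```python
# Hedged sketch of the same reassembly idea in Python; BASE_URL and the
# two-letter suffix scheme are taken from down.sh, everything else is illustrative.
import io
import itertools
import string
import tarfile
import urllib.error
import urllib.request

BASE_URL = "https://huggingface.co/datasets/AI-Art-Collab/dtasettar23/resolve/main/d23.tar."

chunks = []
for c1, c2 in itertools.product(string.ascii_lowercase, repeat=2):
    try:
        with urllib.request.urlopen(BASE_URL + c1 + c2) as resp:
            chunks.append(resp.read())  # one raw slice of the concatenated tar
    except urllib.error.HTTPError:
        break  # the first missing suffix ends the sequence

with tarfile.open(fileobj=io.BytesIO(b"".join(chunks))) as tar:
    tar.extractall("/workspace/d23")  # unlike down.sh, extracts every file, not just *.png
```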
eval_alchemist.py ADDED
@@ -0,0 +1,186 @@
+import os
+import torch
+import torch.nn.functional as F
+import lpips
+from PIL import Image, UnidentifiedImageError
+from tqdm import tqdm
+from torch.utils.data import Dataset, DataLoader
+from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop
+from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
+import random
+
+# --------------------------- Parameters ---------------------------
+DEVICE = "cuda"
+DTYPE = torch.float16
+IMAGE_FOLDER = "/workspace/alchemist"
+MIN_SIZE = 1280
+CROP_SIZE = 512
+BATCH_SIZE = 4  # can be increased for speed
+MAX_IMAGES = None
+NUM_WORKERS = 4  # parallel loading
+
+# List of VAEs to test
+VAE_LIST = [
+    ("madebyollin/sdxl-vae-fp16", AutoencoderKL, "madebyollin/sdxl-vae-fp16-fix", None),
+    ("KBlueLeaf/EQ-SDXL-VAE", AutoencoderKL, "KBlueLeaf/EQ-SDXL-VAE", None),
+    ("AiArtLab/sdxl_vae", AutoencoderKL, "AiArtLab/sdxl_vae", None),
+    ("AiArtLab/sdxl_vae_asym", AsymmetricAutoencoderKL, "AiArtLab/sdxl_vae", "asymmetric_vae"),
+    ("FLUX.1-schnell-vae", AutoencoderKL, "black-forest-labs/FLUX.1-schnell", "vae"),
+]
+
+# --------------------------- Dataset ---------------------------
+class ImageFolderDataset(Dataset):
+    def __init__(self, root_dir, extensions=('.png',), min_size=1024, crop_size=512, limit=None):
+        self.root_dir = root_dir
+        self.min_size = min_size
+        self.crop_size = crop_size
+        self.paths = []
+
+        # Collect file paths
+        print("Scanning folder...")
+        for root, _, files in os.walk(root_dir):
+            for fname in files:
+                if fname.lower().endswith(extensions):
+                    self.paths.append(os.path.join(root, fname))
+
+        # Optional limit on the number of images
+        if limit:
+            self.paths = self.paths[:limit]
+
+        # Quick validity check (optional, can be dropped for speed)
+        print("Verifying images...")
+        valid = []
+        for p in tqdm(self.paths, desc="Verifying"):
+            try:
+                with Image.open(p) as im:
+                    im.verify()
+                valid.append(p)
+            except (OSError, UnidentifiedImageError):
+                continue
+        self.paths = valid
+
+        if len(self.paths) == 0:
+            raise RuntimeError(f"No valid images found in {root_dir}")
+
+        # Shuffle for randomness
+        random.shuffle(self.paths)
+        print(f"Found {len(self.paths)} images")
+
+        # Transforms
+        self.transform = Compose([
+            Resize(min_size, interpolation=Image.LANCZOS),
+            CenterCrop(crop_size),
+            ToTensor(),
+        ])
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, idx):
+        path = self.paths[idx]
+        with Image.open(path) as img:
+            img = img.convert("RGB")
+        return self.transform(img)
+
+# --------------------------- Helpers ---------------------------
+def process(x):
+    return x * 2 - 1
+
+def deprocess(x):
+    return x * 0.5 + 0.5
+
+# --------------------------- Main ---------------------------
+if __name__ == "__main__":
+    # Build the dataset and loader
+    dataset = ImageFolderDataset(
+        IMAGE_FOLDER,
+        extensions=('.png',),
+        min_size=MIN_SIZE,
+        crop_size=CROP_SIZE,
+        limit=MAX_IMAGES
+    )
+
+    dataloader = DataLoader(
+        dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=False,  # already shuffled in the dataset
+        num_workers=NUM_WORKERS,
+        pin_memory=True,
+        drop_last=False
+    )
+
+    # Initialize LPIPS
+    lpips_net = lpips.LPIPS(net="vgg").eval().to(DEVICE).requires_grad_(False)
+
+    # Load the VAE models
+    print("\nLoading VAE models...")
+    vaes = []
+    names = []
+
+    for name, vae_class, model_path, subfolder in VAE_LIST:
+        try:
+            print(f"  Loading {name}...")
+            vae = vae_class.from_pretrained(model_path, subfolder=subfolder)
+            vae = vae.to(DEVICE, DTYPE).eval()
+            vaes.append(vae)
+            names.append(name)
+        except Exception as e:
+            print(f"  ❌ Failed to load {name}: {e}")
+
+    # Evaluate metrics
+    print("\nEvaluating metrics...")
+    results = {name: {"mse": 0.0, "psnr": 0.0, "lpips": 0.0, "count": 0} for name in names}
+
+    with torch.no_grad():
+        for batch in tqdm(dataloader, desc="Processing batches"):
+            batch = batch.to(DEVICE)
+            test_inp = process(batch).to(DTYPE)
+
+            for vae, name in zip(vaes, names):
+                # Encode/decode
+                latent = vae.encode(test_inp).latent_dist.mode()
+                recon = deprocess(vae.decode(latent).sample.float())
+
+                # Per-image metrics for the batch
+                for i in range(batch.shape[0]):
+                    img_orig = batch[i:i+1]
+                    img_recon = recon[i:i+1]
+
+                    mse = F.mse_loss(img_orig, img_recon).item()
+                    psnr = 10 * torch.log10(1 / torch.tensor(mse)).item()
+                    lpips_val = lpips_net(img_orig, img_recon, normalize=True).mean().item()
+
+                    results[name]["mse"] += mse
+                    results[name]["psnr"] += psnr
+                    results[name]["lpips"] += lpips_val
+                    results[name]["count"] += 1
+
+    # Average the results
+    for name in names:
+        count = results[name]["count"]
+        results[name]["mse"] /= count
+        results[name]["psnr"] /= count
+        results[name]["lpips"] /= count
+
+    # Print absolute values
+    print("\n=== Absolute values ===")
+    for name in names:
+        print(f"{name:30s}: MSE: {results[name]['mse']:.3e}, PSNR: {results[name]['psnr']:.4f}, LPIPS: {results[name]['lpips']:.4f}")
+
+    # Print the percentage table
+    print("\n=== Comparison with the first model (%) ===")
+    print(f"| {'Model':30s} | {'MSE':>10s} | {'PSNR':>10s} | {'LPIPS':>10s} |")
+    print(f"|{'-'*32}|{'-'*12}|{'-'*12}|{'-'*12}|")
+
+    baseline = names[0]
+    for name in names:
+        mse_pct = (results[baseline]["mse"] / results[name]["mse"]) * 100
+        psnr_pct = (results[name]["psnr"] / results[baseline]["psnr"]) * 100
+        lpips_pct = (results[baseline]["lpips"] / results[name]["lpips"]) * 100
+
+        if name == baseline:
+            print(f"| {name:30s} | {'100%':>10s} | {'100%':>10s} | {'100%':>10s} |")
+        else:
+            print(f"| {name:30s} | {f'{mse_pct:.1f}%':>10s} | {f'{psnr_pct:.1f}%':>10s} | {f'{lpips_pct:.1f}%':>10s} |")
+
+    print("\n✅ Done!")
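One detail worth noting: the script accumulates PSNR per image, so the table reports a mean of per-image PSNRs. By Jensen's inequality that is at least the PSNR of the mean MSE, which is why applying 10·log10(1/MSE) to the averaged MSE gives a lower number than the PSNR column:

```python
# PSNR from the averaged MSE (peak value 1.0 for images in [0,1]); numbers from
# the README's ImageNet eval. The reported 25.64 dB is a mean of per-image PSNRs,
# which Jensen's inequality guarantees to be >= this value.
import math

mean_mse = 3.321e-3                               # AiArtLab/sdxl_vae
print(f"{10 * math.log10(1 / mean_mse):.2f} dB")  # ~24.79 dB
```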
eval_asym.py DELETED
@@ -1,159 +0,0 @@
-import warnings
-import logging
-import torch
-import torch.nn.functional as F
-import torch.utils.data as data
-import lpips
-from tqdm import tqdm
-from torchvision.transforms import (
-    Compose,
-    Resize,
-    ToTensor,
-    CenterCrop,
-)
-from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
-
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-warnings.filterwarnings(
-    "ignore",
-    ".*Found keys that are not in the model state dict but in the checkpoint.*",
-)
-
-DEVICE = "cuda"
-DTYPE = torch.float16
-SHORT_AXIS_SIZE = 256
-batch_size = 1
-
-NAMES = [
-    # "asymmetric_vae",
-    "asymmetric_vae_new",
-    # "madebyollin/sdxl-vae-fp16-fix",
-    # "KBlueLeaf/EQ-SDXL-VAE",
-    # "AiArtLab/simplevae",
-]
-BASE_MODELS = [
-    # "./asymmetric_vae",
-    "./asymmetric_vae_new",
-    # "madebyollin/sdxl-vae-fp16-fix",
-    # "KBlueLeaf/EQ-SDXL-VAE",
-    # "AiArtLab/simplevae",
-]
-SUB_FOLDERS = [
-    # "sdxs_vae",
-    None,
-    # None,
-    # "sdxl_vae"
-]
-
-def process(x):
-    return x * 2 - 1
-
-def deprocess(x):
-    return x * 0.5 + 0.5
-
-import torch.utils.data as data
-from datasets import load_dataset
-
-class ImageNetDataset(data.IterableDataset):
-    def __init__(self, split, transform=None, max_len=10, streaming=True):
-        self.split = split
-        self.transform = transform
-        self.dataset = load_dataset("evanarlian/imagenet_1k_resized_256", split=split, streaming=streaming)
-        self.max_len = max_len
-        self.iterator = iter(self.dataset)
-
-    def __iter__(self):
-        for i, entry in enumerate(self.iterator):
-            if self.max_len and i >= self.max_len:
-                break
-            img = entry["image"]
-            target = entry["label"]
-            if self.transform is not None:
-                img = self.transform(img)
-            yield img, target
-
-if __name__ == "__main__":
-    lpips_loss = torch.compile(
-        lpips.LPIPS(net="vgg").eval().to(DEVICE).requires_grad_(False)
-    )
-
-    @torch.compile
-    def metrics(inp, recon):
-        mse = F.mse_loss(inp, recon)
-        psnr = 10 * torch.log10(1 / mse)
-        return (
-            mse.cpu(),
-            psnr.cpu(),
-            lpips_loss(inp, recon, normalize=True).mean().cpu(),
-        )
-
-    transform = Compose(
-        [
-            Resize(SHORT_AXIS_SIZE),
-            CenterCrop(SHORT_AXIS_SIZE),
-            ToTensor(),
-        ]
-    )
-    valid_dataset = ImageNetDataset("val", transform=transform, max_len=50000, streaming=True)
-    valid_loader = data.DataLoader(
-        valid_dataset,
-        batch_size=batch_size,
-        shuffle=False,
-        num_workers=2,
-        pin_memory=True,
-        pin_memory_device=DEVICE,
-    )
-
-    # Check that the data actually loads
-    for batch in valid_loader:
-        print("Batch shape:", batch[0].shape)
-        break
-
-    logger.info("Loading models...")
-    vaes = []
-    for base_model, sub_folder in zip(
-        BASE_MODELS, SUB_FOLDERS
-    ):
-        vae = AsymmetricAutoencoderKL.from_pretrained(base_model, subfolder=sub_folder)
-        vae = vae.to(DTYPE).eval().requires_grad_(False).to(DEVICE)
-        vae.encoder = torch.compile(vae.encoder)
-        vae.decoder = torch.compile(vae.decoder)
-        vaes.append(torch.compile(vae))
-
-    logger.info("Running Validation")
-    total = 0
-    all_latents = [[] for _ in range(len(vaes))]
-    all_mse = [[] for _ in range(len(vaes))]
-    all_psnr = [[] for _ in range(len(vaes))]
-    all_lpips = [[] for _ in range(len(vaes))]
-
-    for idx, batch in enumerate(tqdm(valid_loader)):
-        image = batch[0].to(DEVICE)
-        test_inp = process(image).to(DTYPE)
-        batch_size = test_inp.size(0)
-
-        for i, vae in enumerate(vaes):
-            latent = vae.encode(test_inp).latent_dist.mode()
-            recon = deprocess(vae.decode(latent).sample.float())
-            all_latents[i].append(latent.cpu().float())
-            mse, psnr, lpips_ = metrics(image, recon)
-            all_mse[i].append(mse.cpu() * batch_size)
-            all_psnr[i].append(psnr.cpu() * batch_size)
-            all_lpips[i].append(lpips_.cpu() * batch_size)
-
-        total += batch_size
-
-    for i in range(len(vaes)):
-        all_latents[i] = torch.cat(all_latents[i], dim=0)
-        all_mse[i] = torch.stack(all_mse[i]).sum() / total
-        all_psnr[i] = torch.stack(all_psnr[i]).sum() / total
-        all_lpips[i] = torch.stack(all_lpips[i]).sum() / total
-
-        logger.info(
-            f" - {NAMES[i]}: MSE: {all_mse[i]:.3e}, PSNR: {all_psnr[i]:.4f}, "
-            f"LPIPS: {all_lpips[i]:.4f}"
-        )
-
-    logger.info("End")
eval.py → eval_imagenet.py RENAMED
File without changes
samples/sample_0.jpg ADDED

Git LFS Details

  • SHA256: d6d1cf55c86415afa68c4627f7349ff0c26a7a51f72587fc195228c710dd4e91
  • Pointer size: 130 Bytes
  • Size of remote file: 81.2 kB
samples/sample_1.jpg ADDED

Git LFS Details

  • SHA256: 69936edb0b610c7e688fe5806a30349f357f6fbca992d2ed53c1725e316c5b55
  • Pointer size: 130 Bytes
  • Size of remote file: 91.6 kB
samples/sample_2.jpg ADDED

Git LFS Details

  • SHA256: 4e1901885761cf14c8fc4bc42f2fccaebbfa16bbefaa07b19dc0809f386fc7da
  • Pointer size: 131 Bytes
  • Size of remote file: 102 kB
samples/sample_decoded.jpg ADDED

Git LFS Details

  • SHA256: d6d1cf55c86415afa68c4627f7349ff0c26a7a51f72587fc195228c710dd4e91
  • Pointer size: 130 Bytes
  • Size of remote file: 81.2 kB
samples/sample_real.jpg ADDED

Git LFS Details

  • SHA256: a0b76cb257b6b9d0b97fcfbcd20c2f02c0151f7a5cc7c23b40bd0025eaf4413a
  • Pointer size: 130 Bytes
  • Size of remote file: 93 kB
train_sdxl_vae_gpt5.py ADDED
@@ -0,0 +1,547 @@
+# -*- coding: utf-8 -*-
+import os
+import math
+import re
+import torch
+import numpy as np
+import random
+import gc
+from datetime import datetime
+from pathlib import Path
+
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, Dataset
+from torch.optim.lr_scheduler import LambdaLR
+from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
+from accelerate import Accelerator
+from PIL import Image, UnidentifiedImageError
+from tqdm import tqdm
+import bitsandbytes as bnb
+import wandb
+import lpips  # pip install lpips
+from collections import deque
+
+# --------------------------- Parameters ---------------------------
+ds_path = "/workspace/png"
+project = "asymmetric_vae"
+batch_size = 3
+base_learning_rate = 6e-6
+min_learning_rate = 1e-6
+num_epochs = 8
+sample_interval_share = 10
+use_wandb = True
+save_model = True
+use_decay = True
+asymmetric = True
+optimizer_type = "adam8bit"
+dtype = torch.float32
+# model_resolution: what gets fed into the VAE (the low resolution)
+model_resolution = 512  # formerly `resolution`
+# high_resolution: the true "high" crop on which metrics are computed and samples are saved
+high_resolution = 512
+limit = 0
+save_barrier = 1.03
+warmup_percent = 0.01
+percentile_clipping = 95
+beta2 = 0.97
+eps = 1e-6
+clip_grad_norm = 1.0
+mixed_precision = "no"  # or "fp16"/"bf16" where supported
+gradient_accumulation_steps = 5
+generated_folder = "samples"
+save_as = "asymmetric_vae_new"
+num_workers = 0
+device = None  # the accelerator will set the device
+
+# --- Loss proportions and the median-normalization window (COEFFICIENTS, not values) ---
+# Final shares in the total loss (they sum to 1.0)
+loss_ratios = {
+    "lpips": 0.85,
+    "edge": 0.05,
+    "mse": 0.05,
+    "mae": 0.05,
+}
+median_coeff_steps = 256  # how many steps the median coefficients are computed over
+
+# --------------------------- Preprocessing parameters ---------------------------
+resize_long_side = 1280  # None or 0 disables resizing; 1024 recommended
+
+Path(generated_folder).mkdir(parents=True, exist_ok=True)
+
+accelerator = Accelerator(
+    mixed_precision=mixed_precision,
+    gradient_accumulation_steps=gradient_accumulation_steps
+)
+device = accelerator.device
+
+# reproducibility
+seed = int(datetime.now().strftime("%Y%m%d"))
+torch.manual_seed(seed)
+np.random.seed(seed)
+random.seed(seed)
+
+torch.backends.cudnn.benchmark = True
+
+# --------------------------- WandB ---------------------------
+if use_wandb and accelerator.is_main_process:
+    wandb.init(project=project, config={
+        "batch_size": batch_size,
+        "base_learning_rate": base_learning_rate,
+        "num_epochs": num_epochs,
+        "optimizer_type": optimizer_type,
+        "model_resolution": model_resolution,
+        "high_resolution": high_resolution,
+        "gradient_accumulation_steps": gradient_accumulation_steps,
+    })
+
+# --------------------------- VAE ---------------------------
+if model_resolution == high_resolution and not asymmetric:
+    vae = AutoencoderKL.from_pretrained(project).to(dtype)
+else:
+    vae = AsymmetricAutoencoderKL.from_pretrained(project).to(dtype)
+
+# torch.compile (when available), plain and without extra logic
+if hasattr(torch, "compile"):
+    try:
+        vae = torch.compile(vae)
+    except Exception as e:
+        print(f"[WARN] torch.compile failed: {e}")
+
+# >>> Freeze all parameters, then selectively unfreeze
+for p in vae.parameters():
+    p.requires_grad = False
+
+decoder = getattr(vae, "decoder", None)
+if decoder is None:
+    raise RuntimeError("vae.decoder not found: cannot apply the unfreezing strategy. Check the model structure.")
+
+unfrozen_param_names = []
+
+if not hasattr(decoder, "up_blocks"):
+    raise RuntimeError("decoder.up_blocks not found: a list of decoder blocks is expected.")
+
+# >>> Unfreeze all up_blocks and the mid_block (as in the start_idx=0 variant)
+n_up = len(decoder.up_blocks)
+start_idx = 0
+for idx in range(start_idx, n_up):
+    block = decoder.up_blocks[idx]
+    for name, p in block.named_parameters():
+        p.requires_grad = True
+        unfrozen_param_names.append(f"decoder.up_blocks.{idx}.{name}")
+
+if hasattr(decoder, "mid_block"):
+    for name, p in decoder.mid_block.named_parameters():
+        p.requires_grad = True
+        unfrozen_param_names.append(f"decoder.mid_block.{name}")
+else:
+    print("[WARN] decoder.mid_block not found: mid_block left frozen.")
+
+print(f"[INFO] Unfrozen parameters: {len(unfrozen_param_names)}. First 200 names:")
+for nm in unfrozen_param_names[:200]:
+    print("  ", nm)
+
+# keep trainable_module (get_param_groups will respect p.requires_grad)
+trainable_module = vae.decoder
+
+# --------------------------- Custom PNG Dataset (only .png, skip corrupted) -----------
+class PngFolderDataset(Dataset):
+    def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
+        self.root_dir = root_dir
+        self.resolution = resolution
+        self.paths = []
+        # collect png files recursively
+        for root, _, files in os.walk(root_dir):
+            for fname in files:
+                if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
+                    self.paths.append(os.path.join(root, fname))
+        # optional limit
+        if limit:
+            self.paths = self.paths[:limit]
+        # verify images and keep only valid ones
+        valid = []
+        for p in self.paths:
+            try:
+                with Image.open(p) as im:
+                    im.verify()  # fast check for truncated/corrupted images
+                valid.append(p)
+            except (OSError, UnidentifiedImageError):
+                # skip corrupted image
+                continue
+        self.paths = valid
+        if len(self.paths) == 0:
+            raise RuntimeError(f"No valid PNG images found under {root_dir}")
+        # final shuffle for randomness
+        random.shuffle(self.paths)
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, idx):
+        p = self.paths[idx % len(self.paths)]
+        # open and convert to RGB; ensure file is closed promptly
+        with Image.open(p) as img:
+            img = img.convert("RGB")
+        # shrink the long side down to resize_long_side (Lanczos)
+        if not resize_long_side or resize_long_side <= 0:
+            return img
+        w, h = img.size
+        long = max(w, h)
+        if long <= resize_long_side:
+            return img
+        scale = resize_long_side / float(long)
+        new_w = int(round(w * scale))
+        new_h = int(round(h * scale))
+        return img.resize((new_w, new_h), Image.LANCZOS)
+
+# --------------------------- Dataset and transforms ---------------------------
+
+def random_crop(img, sz):
+    w, h = img.size
+    if w < sz or h < sz:
+        img = img.resize((max(sz, w), max(sz, h)), Image.LANCZOS)
+    x = random.randint(0, max(1, img.width - sz))
+    y = random.randint(0, max(1, img.height - sz))
+    return img.crop((x, y, x + sz, y + sz))
+
+tfm = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
+])
+
+# build dataset using high_resolution crops
+dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
+if len(dataset) < batch_size:
+    raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
+
+# collate_fn crops to high_resolution
+
+def collate_fn(batch):
+    imgs = []
+    for img in batch:  # img is PIL.Image
+        img = random_crop(img, high_resolution)  # crop at high-res
+        imgs.append(tfm(img))
+    return torch.stack(imgs)
+
+dataloader = DataLoader(
+    dataset,
+    batch_size=batch_size,
+    shuffle=True,
+    collate_fn=collate_fn,
+    num_workers=num_workers,
+    pin_memory=True,
+    drop_last=True
+)
+
+# --------------------------- Optimizer ---------------------------
+
+def get_param_groups(module, weight_decay=0.001):
+    no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "ln_1.weight", "ln_f.weight"]
+    decay_params = []
+    no_decay_params = []
+    for n, p in module.named_parameters():
+        if not p.requires_grad:
+            continue
+        if any(nd in n for nd in no_decay):
+            no_decay_params.append(p)
+        else:
+            decay_params.append(p)
+    return [
+        {"params": decay_params, "weight_decay": weight_decay},
+        {"params": no_decay_params, "weight_decay": 0.0},
+    ]
+
+def create_optimizer(name, param_groups):
+    if name == "adam8bit":
+        return bnb.optim.AdamW8bit(
+            param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps
+        )
+    raise ValueError(name)
+
+param_groups = get_param_groups(trainable_module, weight_decay=0.001)
+optimizer = create_optimizer(optimizer_type, param_groups)
+
+# --------------------------- Accelerate preparation (all together) ---------------------------
+
+batches_per_epoch = len(dataloader)  # number of micro-batches (dataloader steps)
+steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))  # optimizer.step() calls per epoch
+total_steps = steps_per_epoch * num_epochs
+
+
+def lr_lambda(step):
+    if not use_decay:
+        return 1.0
+    x = float(step) / float(max(1, total_steps))
+    warmup = float(warmup_percent)
+    min_ratio = float(min_learning_rate) / float(base_learning_rate)
+    if x < warmup:
+        return min_ratio + (1.0 - min_ratio) * (x / warmup)
+    decay_ratio = (x - warmup) / (1.0 - warmup)
+    return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))
+
+scheduler = LambdaLR(optimizer, lr_lambda)
+
+# Prepare
+dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
+
+trainable_params = [p for p in vae.decoder.parameters() if p.requires_grad]
+
+# --------------------------- LPIPS and helper functions ---------------------------
+_lpips_net = None
+
+def _get_lpips():
+    global _lpips_net
+    if _lpips_net is None:
+        _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device)
+    return _lpips_net
+
+# Sobel kernels for the edge loss
+_sobel_kx = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=torch.float32)
+_sobel_ky = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], dtype=torch.float32)
+
+def sobel_edges(x: torch.Tensor) -> torch.Tensor:
+    # x: [B,C,H,W] in [-1,1]
+    C = x.shape[1]
+    kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
+    ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
+    gx = F.conv2d(x, kx, padding=1, groups=C)
+    gy = F.conv2d(x, ky, padding=1, groups=C)
+    return torch.sqrt(gx * gx + gy * gy + 1e-12)
+
+# Median loss normalization: we compute COEFFICIENTS
+class MedianLossNormalizer:
+    def __init__(self, desired_ratios: dict, window_steps: int):
+        # normalize the shares in case they do not sum to 1
+        s = sum(desired_ratios.values())
+        self.ratios = {k: (v / s) for k, v in desired_ratios.items()}
+        self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
+        self.window = window_steps
+
+    def update_and_total(self, abs_losses: dict):
+        # Fill the buffers with the actual ABSOLUTE loss values
+        for k, v in abs_losses.items():
+            if k in self.buffers:
+                self.buffers[k].append(float(v.detach().cpu()))
+        # Medians (robust to outliers)
+        meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
+        # Compute COEFFICIENTS as ratio_k / median_k, i.e. actual coefficients, not values
+        coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
+        # Note: with these coefficients sum(coeff_k * median_k) = sum(ratio_k) = 1, so the scale stays stable
+        total = sum(coeffs[k] * abs_losses[k] for k in coeffs)
+        return total, coeffs, meds
+
+normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
+
+# --------------------------- Samples ---------------------------
+@torch.no_grad()
+def get_fixed_samples(n=3):
+    idx = random.sample(range(len(dataset)), min(n, len(dataset)))
+    pil_imgs = [dataset[i] for i in idx]  # dataset returns PIL.Image
+    tensors = []
+    for img in pil_imgs:
+        img = random_crop(img, high_resolution)  # high-res fixed samples
+        tensors.append(tfm(img))
+    return torch.stack(tensors).to(accelerator.device, dtype)
+
+fixed_samples = get_fixed_samples()
+
+@torch.no_grad()
+def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
+    # img_tensor: [C,H,W] in [-1,1]
+    arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
+    return Image.fromarray(arr)
+
+@torch.no_grad()
+def generate_and_save_samples(step=None):
+    try:
+        temp_vae = accelerator.unwrap_model(vae).eval()
+        lpips_net = _get_lpips()
+        with torch.no_grad():
+            # ALWAYS prepare the low-res encoder input at model_resolution
+            orig_high = fixed_samples  # [B,C,H,W] in [-1,1]
+            orig_low = F.interpolate(orig_high, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
+            # match the model dtype
+            model_dtype = next(temp_vae.parameters()).dtype
+            orig_low = orig_low.to(dtype=model_dtype)
+            # encode/decode
+            latents = temp_vae.encode(orig_low).latent_dist.mean
+            rec = temp_vae.decode(latents).sample
+
+            # Bring the reconstruction's spatial size to high-res (downsample for asymmetric VAEs)
+            if rec.shape[-2:] != orig_high.shape[-2:]:
+                rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)
+
+            # Save the FIRST sample: real and decoded, without the step number in the filename
+            first_real = _to_pil_uint8(orig_high[0])
+            first_dec = _to_pil_uint8(rec[0])
+            first_real.save(f"{generated_folder}/sample_real.jpg", quality=95)
+            first_dec.save(f"{generated_folder}/sample_decoded.jpg", quality=95)
+
+            # Also save the current reconstructions without step numbers (files get overwritten instead of piling up)
+            for i in range(rec.shape[0]):
+                _to_pil_uint8(rec[i]).save(f"{generated_folder}/sample_{i}.jpg", quality=95)
+
+            # LPIPS on the full (high-res) image, for the log
+            lpips_scores = []
+            for i in range(rec.shape[0]):
+                orig_full = orig_high[i:i+1].to(torch.float32)
+                rec_full = rec[i:i+1].to(torch.float32)
+                if rec_full.shape[-2:] != orig_full.shape[-2:]:
+                    rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
+                lpips_val = lpips_net(orig_full, rec_full).item()
+                lpips_scores.append(lpips_val)
+            avg_lpips = float(np.mean(lpips_scores))
+
+            if use_wandb and accelerator.is_main_process:
+                wandb.log({
+                    "lpips_mean": avg_lpips,
+                }, step=step)
+    finally:
+        gc.collect()
+        torch.cuda.empty_cache()
+
+if accelerator.is_main_process and save_model:
+    print("Generating samples before training starts...")
+    generate_and_save_samples(0)
+
+accelerator.wait_for_everyone()
+
+# --------------------------- Training ---------------------------
+
+progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
+global_step = 0
+min_loss = float("inf")
+sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
+
+for epoch in range(num_epochs):
+    vae.train()
+    batch_losses = []
+    batch_grads = []
+    # Extra tracking of the individual losses
+    track_losses = {k: [] for k in loss_ratios.keys()}
+    for imgs in dataloader:
+        with accelerator.accumulate(vae):
+            # imgs: high-res tensor from dataloader ([-1,1]), move to device
+            imgs = imgs.to(accelerator.device)
+
+            # ALWAYS downsample the encoder input to model_resolution
+            # The dumb hunk of metal keeps trying to do everything its own way
+            if high_resolution != model_resolution:
+                imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
+            else:
+                imgs_low = imgs
+
+            # ensure dtype matches model params to avoid float/half mismatch
+            model_dtype = next(vae.parameters()).dtype
+            if imgs_low.dtype != model_dtype:
+                imgs_low_model = imgs_low.to(dtype=model_dtype)
+            else:
+                imgs_low_model = imgs_low
+
+            # Encode/decode
+            latents = vae.encode(imgs_low_model).latent_dist.mean
+            rec = vae.decode(latents).sample  # rec may come back upscaled (asymmetric VAE)
+
+            # Bring the size to high-res
+            if rec.shape[-2:] != imgs.shape[-2:]:
+                rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
+
+            # Losses are computed at high-res
+            rec_f32 = rec.to(torch.float32)
+            imgs_f32 = imgs.to(torch.float32)
+
+            # Individual losses
+            abs_losses = {
+                "mae": F.l1_loss(rec_f32, imgs_f32),
+                "mse": F.mse_loss(rec_f32, imgs_f32),
+                "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
+                "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
+            }
+
+            # Total with the median COEFFICIENTS
+            # No need to shout like that once you've finally understood my idea
+            total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)
+
+            if torch.isnan(total_loss) or torch.isinf(total_loss):
+                print("NaN/Inf loss – stopping")
+                raise RuntimeError("NaN/Inf loss")
+
+            accelerator.backward(total_loss)
+
+            grad_norm = torch.tensor(0.0, device=accelerator.device)
+            if accelerator.sync_gradients:
+                grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
+            optimizer.step()
+            scheduler.step()
+            optimizer.zero_grad(set_to_none=True)
+
+            global_step += 1
+            progress.update(1)
+
+            # --- Logging ---
+            if accelerator.is_main_process:
+                try:
+                    current_lr = optimizer.param_groups[0]["lr"]
+                except Exception:
+                    current_lr = scheduler.get_last_lr()[0]
+
+                batch_losses.append(total_loss.detach().item())
+                batch_grads.append(float(grad_norm if isinstance(grad_norm, (float, int)) else grad_norm.cpu().item()))
+                for k, v in abs_losses.items():
+                    track_losses[k].append(float(v.detach().item()))
+
+                if use_wandb and accelerator.sync_gradients:
+                    log_dict = {
+                        "total_loss": float(total_loss.detach().item()),
+                        "learning_rate": current_lr,
+                        "epoch": epoch,
+                        "grad_norm": batch_grads[-1],
+                    }
+                    # add the individual losses
+                    for k, v in abs_losses.items():
+                        log_dict[f"loss_{k}"] = float(v.detach().item())
+                    # log the coefficients and medians
+                    for k in coeffs:
+                        log_dict[f"coeff_{k}"] = float(coeffs[k])
+                        log_dict[f"median_{k}"] = float(meds[k])
+                    wandb.log(log_dict, step=global_step)
+
+            # periodic samples and checkpoints
+            if global_step > 0 and global_step % sample_interval == 0:
+                if accelerator.is_main_process:
+                    generate_and_save_samples(global_step)
+                accelerator.wait_for_everyone()
+
+                # Averages over the most recent iterations
+                n_micro = sample_interval * gradient_accumulation_steps
+                if len(batch_losses) >= n_micro:
+                    avg_loss = float(np.mean(batch_losses[-n_micro:]))
+                else:
+                    avg_loss = float(np.mean(batch_losses)) if batch_losses else float("nan")
+
+                avg_grad = float(np.mean(batch_grads[-n_micro:])) if batch_grads else 0.0
+
+                if accelerator.is_main_process:
+                    print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
+                    if save_model and avg_loss < min_loss * save_barrier:
+                        min_loss = avg_loss
+                        accelerator.unwrap_model(vae).save_pretrained(save_as)
+                    if use_wandb:
+                        wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)
+
+    if accelerator.is_main_process:
+        epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
+        print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
+        if use_wandb:
+            wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)
+
+# --------------------------- Final save ---------------------------
+if accelerator.is_main_process:
+    print("Training finished – saving final model")
+    if save_model:
+        accelerator.unwrap_model(vae).save_pretrained(save_as)
+
+accelerator.free_memory()
+if torch.distributed.is_initialized():
+    torch.distributed.destroy_process_group()
+print("Done!")
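Why the median trick keeps training stable: with coeff_k = ratio_k / median_k, the losses evaluated at their medians always sum to exactly sum(ratio_k) = 1, whatever the raw loss scales are. A standalone toy check of that invariant (median values invented for illustration):

```python
# Toy check of the MedianLossNormalizer invariant: sum(coeff_k * median_k) == 1,
# so wildly different raw loss scales still mix in the requested proportions.
ratios = {"lpips": 0.85, "edge": 0.05, "mse": 0.05, "mae": 0.05}
medians = {"lpips": 0.30, "edge": 0.02, "mse": 0.004, "mae": 0.05}  # invented values

coeffs = {k: ratios[k] / medians[k] for k in ratios}
print(sum(coeffs[k] * medians[k] for k in ratios))  # 1.0 by construction
```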
train_sdxl_vae.py → train_sdxl_vae_my.py RENAMED
@@ -38,7 +38,7 @@ dtype = torch.float32
 # model_resolution: what gets fed into the VAE (the low resolution)
 model_resolution = 512  # formerly `resolution`
 # high_resolution: the true "high" crop on which metrics are computed and samples are saved
-high_resolution = 512  # >>> CHANGED: train on 1024px inputs -> downsample to 512 for the model
+high_resolution = 1024
 limit = 0
 save_barrier = 1.03
 warmup_percent = 0.01
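With high_resolution bumped to 1024 while model_resolution stays 512, the training loop feeds the encoder a 512px downsample and compares the decode against the original 1024px crop. A stripped-down sketch of that forward pass (the checkpoint path is an assumed local one; the real loop also adds LPIPS and edge losses):

```python
# Minimal sketch of the 1024 -> encode at 512 -> decode -> compare-at-1024 flow;
# "./asymmetric_vae_new" is an assumed local checkpoint path, as used above.
import torch
import torch.nn.functional as F
from diffusers import AsymmetricAutoencoderKL

vae = AsymmetricAutoencoderKL.from_pretrained("./asymmetric_vae_new").eval()

imgs_high = torch.rand(1, 3, 1024, 1024) * 2 - 1  # stand-in for a high-res crop in [-1, 1]
imgs_low = F.interpolate(imgs_high, size=(512, 512), mode="bilinear", align_corners=False)

with torch.no_grad():
    latents = vae.encode(imgs_low).latent_dist.mean
    rec = vae.decode(latents).sample              # may come back larger than the input
if rec.shape[-2:] != imgs_high.shape[-2:]:
    rec = F.interpolate(rec, size=imgs_high.shape[-2:], mode="bilinear", align_corners=False)
print(F.l1_loss(rec, imgs_high).item())           # losses are computed at high-res
```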