"""Compare reconstruction quality of several VAEs on a folder of PNG images.

Each image is resized (shorter side -> MIN_SIZE), center-cropped to CROP_SIZE,
encoded and decoded by every VAE in VAE_LIST, and scored with MSE, PSNR,
LPIPS (VGG) and a Sobel edge L1 loss. Averages are printed as absolute values
and as percentages relative to the first successfully loaded VAE.
"""
import os
import random

import lpips
import torch
import torch.nn.functional as F
from diffusers import AsymmetricAutoencoderKL, AutoencoderKL
from PIL import Image, UnidentifiedImageError
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Resize,
    ToPILImage,
    ToTensor,
)
from tqdm import tqdm

# --------------------------- Parameters ---------------------------
DEVICE = "cuda"
DTYPE = torch.float16
IMAGE_FOLDER = "/workspace/alchemist"
# wget https://huggingface.co/datasets/AiArtLab/alchemist/resolve/main/alchemist.zip
MIN_SIZE = 1280  # target size for Resize (shorter side) before center-cropping
CROP_SIZE = 512
BATCH_SIZE = 5
MAX_IMAGES = 100
NUM_WORKERS = 4
NUM_SAMPLES_TO_SAVE = 10  # How many examples to save (0 = do not save)
SAMPLES_FOLDER = "vaetest"

# VAEs to evaluate: (display name, class, repo id / path, subfolder or None).
# The first entry that loads successfully is the baseline for the % table.
VAE_LIST = [
    # ("stable-diffusion-v1-5/stable-diffusion-v1-5", AutoencoderKL, "stable-diffusion-v1-5/stable-diffusion-v1-5", "vae"),
    # ("cross-attention/asymmetric-autoencoder-kl-x-1-5", AsymmetricAutoencoderKL, "cross-attention/asymmetric-autoencoder-kl-x-1-5", None),
    ("madebyollin/sdxl-vae-fp16", AutoencoderKL, "madebyollin/sdxl-vae-fp16-fix", None),
    # ("AiArtLab/sdxs", AutoencoderKL, "AiArtLab/sdxs", "vae"),
    ("AiArtLab/sdxl_vae", AutoencoderKL, "AiArtLab/sdxl_vae", None),
    # ("AiArtLab/sdxl_vae_asym", AsymmetricAutoencoderKL, "AiArtLab/sdxl_vae", "asymmetric_vae"),
    ("AiArtLab/sdxl_vae_asym_new", AsymmetricAutoencoderKL, "AiArtLab/sdxl_vae", "asymmetric_vae_new"),
    # ("KBlueLeaf/EQ-SDXL-VAE", AutoencoderKL, "KBlueLeaf/EQ-SDXL-VAE", None),
    # ("FLUX.1-schnell-vae", AutoencoderKL, "black-forest-labs/FLUX.1-schnell", "vae"),
]

# Smallest MSE used for PSNR so a perfect reconstruction gives a finite value
# (100 dB) instead of inf, which would make the average inf as well.
_MIN_MSE = 1e-10

# --------------------------- Sobel Edge Detection ---------------------------
# Sobel kernels are defined once at module level and moved/cast per call.
_sobel_kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
_sobel_ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)


def sobel_edges(x: torch.Tensor) -> torch.Tensor:
    """Return the per-channel Sobel gradient magnitude of ``x``.

    x: [B, C, H, W] tensor (callers in this file pass values in [-1, 1]).
    Returns: [B, C, H, W] gradient magnitude, same device/dtype as ``x``.
    """
    channels = x.shape[1]
    # Depthwise conv: one copy of each kernel per channel (groups=C).
    kx = _sobel_kx.to(x.device, x.dtype).repeat(channels, 1, 1, 1)
    ky = _sobel_ky.to(x.device, x.dtype).repeat(channels, 1, 1, 1)
    gx = F.conv2d(x, kx, padding=1, groups=channels)
    gy = F.conv2d(x, ky, padding=1, groups=channels)
    # Epsilon keeps sqrt away from exactly zero.
    return torch.sqrt(gx * gx + gy * gy + 1e-12)


def _edge_l1_per_image(real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    """Return the mean |sobel(fake) - sobel(real)| of every image, shape [B].

    real, fake: [B, C, H, W] in [0, 1]; both are mapped to [-1, 1] first.
    """
    edges_real = sobel_edges(real * 2 - 1)
    edges_fake = sobel_edges(fake * 2 - 1)
    return (edges_fake - edges_real).abs().flatten(1).mean(dim=1)


def compute_edge_loss(real: torch.Tensor, fake: torch.Tensor) -> float:
    """Return the L1 distance between Sobel edge maps of ``real`` and ``fake``.

    real, fake: [B, C, H, W] in [0, 1].
    Returns: scalar mean over all elements. All images in a batch have the same
    size, so this equals the mean of the per-image losses.
    """
    return _edge_l1_per_image(real, fake).mean().item()


# --------------------------- Dataset ---------------------------
class ImageFolderDataset(Dataset):
    """Center-cropped RGB images found recursively under ``root_dir``.

    Files whose lowercased name ends with one of ``extensions`` are collected
    in a deterministic (sorted) order, optionally truncated to ``limit``,
    checked with ``PIL.Image.verify()`` (files that fail are skipped) and
    shuffled. Items are float tensors [3, crop_size, crop_size] in [0, 1]
    (ToTensor output).

    Raises:
        RuntimeError: if no valid image remains after verification.
    """

    def __init__(self, root_dir, extensions=('.png',), min_size=1024, crop_size=512, limit=None):
        self.root_dir = root_dir
        self.min_size = min_size
        self.crop_size = crop_size

        # Accept a single string as well as any iterable of suffixes, and match
        # case-insensitively (the file name is lowercased below).
        if isinstance(extensions, str):
            extensions = (extensions,)
        extensions = tuple(ext.lower() for ext in extensions)

        print("Сканирование папки...")
        self.paths = []
        for root, dirs, files in os.walk(root_dir):
            # os.walk order depends on the filesystem; sort so `limit` picks the
            # same subset on every run/machine.
            dirs.sort()
            for fname in sorted(files):
                if fname.lower().endswith(extensions):
                    self.paths.append(os.path.join(root, fname))

        if limit:
            self.paths = self.paths[:limit]

        print("Проверка изображений...")
        valid = []
        for path in tqdm(self.paths, desc="Проверка"):
            try:
                with Image.open(path) as im:
                    im.verify()
            except Exception:
                # Best effort: anything PIL cannot open/verify is skipped.
                # (Unlike a bare `except:`, Ctrl+C still interrupts the scan.)
                continue
            valid.append(path)
        self.paths = valid

        if not self.paths:
            raise RuntimeError(f"Не найдено валидных изображений в {root_dir}")

        random.shuffle(self.paths)
        print(f"Найдено {len(self.paths)} изображений")

        self.transform = Compose([
            # torchvision expects an InterpolationMode enum, not a PIL constant.
            Resize(min_size, interpolation=InterpolationMode.LANCZOS),
            CenterCrop(crop_size),
            ToTensor(),
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # convert() returns a new image, so it stays usable after the file closes.
        with Image.open(self.paths[idx]) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb)


# --------------------------- Helpers ---------------------------
def process(x):
    """Map values from [0, 1] to [-1, 1]."""
    return x * 2 - 1


def deprocess(x):
    """Map values from [-1, 1] to [0, 1] (inverse of ``process``)."""
    return x * 0.5 + 0.5


def _sanitize_name(name: str) -> str:
    """Make a model name safe for file names ('/' and '-' become '_')."""
    return name.replace('/', '_').replace('-', '_')


def _load_vaes(vae_list, device, dtype):
    """Load every VAE described in ``vae_list``.

    Entries that fail to load are reported and skipped. Returns ``(vaes, names)``
    as parallel lists in the order of ``vae_list``.
    """
    vaes, names = [], []
    for name, vae_class, model_path, subfolder in vae_list:
        try:
            print(f" Загружаю {name}...")
            kwargs = {"subfolder": subfolder}
            # sdxs weights are requested as variant="fp16"
            # (presumably the variant published in that repo).
            if "sdxs" in model_path:
                kwargs["variant"] = "fp16"
            vae = vae_class.from_pretrained(model_path, **kwargs)
            vaes.append(vae.to(device, dtype).eval())
            names.append(name)
        except Exception as e:
            print(f" ❌ Ошибка загрузки {name}: {e}")
    return vaes, names


def _reconstruct(vae, inp):
    """Encode ``inp`` (in [-1, 1], model dtype) and decode the latent-distribution mode.

    Returns a float32 reconstruction mapped to [0, 1]. The decoder output is
    expected to be roughly in [-1, 1]; the clamp removes any overshoot.
    """
    latent = vae.encode(inp).latent_dist.mode()
    dec = vae.decode(latent).sample.float()
    return deprocess(dec).clamp(0.0, 1.0)


def _batch_metrics(orig, recon, lpips_net):
    """Return per-image metrics (each a tensor of shape [B]) for one batch.

    orig, recon: [B, 3, H, W] float tensors in [0, 1] on the same device.
    Keys: "mse", "psnr" (dB, peak 1.0), "lpips", "edge".
    """
    batch_size = orig.shape[0]
    mse = (orig - recon).pow(2).flatten(1).mean(dim=1)
    psnr = 10 * torch.log10(1.0 / mse.clamp_min(_MIN_MSE))
    # normalize=True tells LPIPS that inputs are in [0, 1].
    lpips_val = lpips_net(orig, recon, normalize=True).reshape(batch_size, -1).mean(dim=1)
    edge = _edge_l1_per_image(orig, recon)
    return {"mse": mse, "psnr": psnr, "lpips": lpips_val, "edge": edge}


def _save_samples(batch, recon_list, names, saved, max_samples, folder, to_pil):
    """Save originals, per-VAE decodes and a side-by-side collage.

    Images are taken from ``batch`` until ``max_samples`` images have been saved
    in total. ``saved`` is the number already saved; the updated count is returned.
    """
    for i in range(batch.shape[0]):
        if saved >= max_samples:
            break
        idx_str = f"{saved + 1:03d}"

        # Original image.
        orig_pil = to_pil(batch[i].detach().float().cpu())
        orig_pil.save(os.path.join(folder, f"{idx_str}_orig.png"))

        # Per-VAE decodes.
        tiles = [orig_pil]
        for recon, name in zip(recon_list, names):
            recon_pil = to_pil(recon[i].detach().cpu())
            recon_pil.save(os.path.join(folder, f"{idx_str}_decoded_{_sanitize_name(name)}.png"))
            tiles.append(recon_pil)

        # Horizontal collage: [orig | vae1 | vae2 | ...]
        collage = Image.new("RGB", (sum(t.width for t in tiles), max(t.height for t in tiles)))
        x = 0
        for tile in tiles:
            collage.paste(tile, (x, 0))
            x += tile.width
        collage.save(os.path.join(folder, f"{idx_str}_all.png"))
        saved += 1
    return saved


def _pct(numerator, denominator):
    """Return ``numerator / denominator`` in percent, or nan when the denominator is 0."""
    return numerator / denominator * 100 if denominator else float("nan")


def _print_results(results, names):
    """Print averaged metrics, then a table relative to the first model."""
    print("\n=== Абсолютные значения ===")
    for name in names:
        print(f"{name:30s}: MSE: {results[name]['mse']:.3e}, PSNR: {results[name]['psnr']:.4f}, "
              f"LPIPS: {results[name]['lpips']:.4f}, Edge: {results[name]['edge']:.4f}")

    print("\n=== Сравнение с первой моделью (%) ===")
    print(f"| {'Модель':30s} | {'MSE':>10s} | {'PSNR':>10s} | {'LPIPS':>10s} | {'Edge':>10s} |")
    print(f"|{'-'*32}|{'-'*12}|{'-'*12}|{'-'*12}|{'-'*12}|")

    baseline = names[0]
    base = results[baseline]
    for name in names:
        if name == baseline:
            print(f"| {name:30s} | {'100%':>10s} | {'100%':>10s} | {'100%':>10s} | {'100%':>10s} |")
            continue
        res = results[name]
        # Lower is better for MSE, LPIPS and Edge, so the ratio is inverted
        # (baseline / model). Higher is better for PSNR (model / baseline).
        # In every column, >100% means better than the baseline.
        mse_pct = _pct(base["mse"], res["mse"])
        psnr_pct = _pct(res["psnr"], base["psnr"])
        lpips_pct = _pct(base["lpips"], res["lpips"])
        edge_pct = _pct(base["edge"], res["edge"])
        print(f"| {name:30s} | {f'{mse_pct:.1f}%':>10s} | {f'{psnr_pct:.1f}%':>10s} | "
              f"{f'{lpips_pct:.1f}%':>10s} | {f'{edge_pct:.1f}%':>10s} |")


# --------------------------- Main ---------------------------
def main():
    """Run the benchmark: load data and VAEs, score reconstructions, print tables."""
    if NUM_SAMPLES_TO_SAVE > 0:
        os.makedirs(SAMPLES_FOLDER, exist_ok=True)

    dataset = ImageFolderDataset(
        IMAGE_FOLDER,
        extensions=('.png',),
        min_size=MIN_SIZE,
        crop_size=CROP_SIZE,
        limit=MAX_IMAGES,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        drop_last=False,
    )

    lpips_net = lpips.LPIPS(net="vgg").eval().to(DEVICE).requires_grad_(False)

    print("\nЗагрузка VAE моделей...")
    vaes, names = _load_vaes(VAE_LIST, DEVICE, DTYPE)
    if not vaes:
        # Without this check the baseline lookup (names[0]) fails with IndexError.
        raise RuntimeError("Не удалось загрузить ни одной VAE")

    print("\nОценка метрик...")
    metric_keys = ("mse", "psnr", "lpips", "edge")
    results = {name: {**dict.fromkeys(metric_keys, 0.0), "count": 0} for name in names}
    to_pil = ToPILImage()
    images_saved = 0  # counts saved IMAGES, not files (each image yields several files)

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Обработка батчей"):
            batch = batch.to(DEVICE)  # [B,3,H,W] in [0,1]
            test_inp = process(batch).to(DTYPE)  # [-1,1] for the encoder

            # 1) Reconstructions of the whole batch by every VAE.
            recon_list = [_reconstruct(vae, test_inp) for vae in vaes]

            # 2) Accumulate per-image metric sums for every VAE.
            for recon, name in zip(recon_list, names):
                for key, values in _batch_metrics(batch, recon, lpips_net).items():
                    results[name][key] += values.sum().item()
                results[name]["count"] += batch.shape[0]

            # 3) Save exactly NUM_SAMPLES_TO_SAVE images (orig + every VAE + collage).
            if images_saved < NUM_SAMPLES_TO_SAVE:
                images_saved = _save_samples(
                    batch, recon_list, names, images_saved,
                    NUM_SAMPLES_TO_SAVE, SAMPLES_FOLDER, to_pil,
                )

    # Average the sums; count > 0 because the dataset is non-empty.
    for name in names:
        count = results[name]["count"]
        for key in metric_keys:
            results[name][key] /= count

    _print_results(results, names)
    print("\n✅ Готово!")


if __name__ == "__main__":
    main()