|
import os |
|
import torch |
|
import torch.nn.functional as F |
|
import lpips |
|
from PIL import Image, UnidentifiedImageError |
|
from tqdm import tqdm |
|
from torch.utils.data import Dataset, DataLoader |
|
from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop |
|
from diffusers import AutoencoderKL, AsymmetricAutoencoderKL |
|
import random |
|
|
|
|
|
# Runtime configuration for the VAE reconstruction benchmark.
DEVICE = "cuda"                        # all models and batches are moved here
DTYPE = torch.float16                  # inference dtype for the VAEs
IMAGE_FOLDER = "/workspace/alchemist"  # root dir scanned recursively for .png files
MIN_SIZE = 1280                        # shorter image side is resized to this before cropping
CROP_SIZE = 512                        # edge of the square center crop fed to the VAEs
BATCH_SIZE = 4                         # images per evaluation batch
MAX_IMAGES = None                      # optional cap on dataset size (None = use all)
NUM_WORKERS = 4                        # DataLoader worker processes

# Candidate VAEs: (display name, loader class, HF repo id, optional subfolder).
# NOTE(review): the first entry's display name omits the "-fix" suffix of its
# repo id — presumably an intentional shorthand; verify before renaming.
VAE_LIST = [
    ("madebyollin/sdxl-vae-fp16", AutoencoderKL, "madebyollin/sdxl-vae-fp16-fix", None),
    ("KBlueLeaf/EQ-SDXL-VAE", AutoencoderKL, "KBlueLeaf/EQ-SDXL-VAE", None),
    ("AiArtLab/sdxl_vae", AutoencoderKL, "AiArtLab/sdxl_vae", None),
    ("AiArtLab/sdxl_vae_asym", AsymmetricAutoencoderKL, "AiArtLab/sdxl_vae", "asymmetric_vae"),
    ("FLUX.1-schnell-vae", AutoencoderKL, "black-forest-labs/FLUX.1-schnell", "vae"),
]
|
|
|
|
|
class ImageFolderDataset(Dataset):
    """Recursively gathers image files under ``root_dir``, drops unreadable
    ones, and serves resize-then-center-crop tensors in [0, 1].

    Parameters
    ----------
    root_dir : str
        Directory scanned recursively for image files.
    extensions : tuple[str, ...]
        Lower-cased filename suffixes to accept.
    min_size : int
        Target length for the shorter image side before cropping.
    crop_size : int
        Edge length of the square center crop.
    limit : int | None
        Optional cap on the number of files kept (applied before validation).

    Raises
    ------
    RuntimeError
        If no valid image survives the scan + integrity check.
    """

    def __init__(self, root_dir, extensions=('.png',), min_size=1024, crop_size=512, limit=None):
        self.root_dir = root_dir
        self.min_size = min_size
        self.crop_size = crop_size
        self.paths = []

        print("Сканирование папки...")
        for root, _, files in os.walk(root_dir):
            for fname in files:
                # endswith accepts a tuple, so one call covers all extensions.
                if fname.lower().endswith(extensions):
                    self.paths.append(os.path.join(root, fname))

        if limit:
            self.paths = self.paths[:limit]

        print("Проверка изображений...")
        valid = []
        for p in tqdm(self.paths, desc="Проверка"):
            try:
                # verify() checks file integrity without decoding pixel data.
                with Image.open(p) as im:
                    im.verify()
                valid.append(p)
            # FIX: was a bare ``except:`` that also swallowed KeyboardInterrupt
            # and SystemExit; catch only the errors Pillow raises for unreadable
            # or truncated files (UnidentifiedImageError was imported but unused).
            except (UnidentifiedImageError, OSError, SyntaxError):
                continue
        self.paths = valid

        if len(self.paths) == 0:
            raise RuntimeError(f"Не найдено валидных изображений в {root_dir}")

        # Shuffle once up front; the DataLoader itself runs with shuffle=False.
        random.shuffle(self.paths)
        print(f"Найдено {len(self.paths)} изображений")

        self.transform = Compose([
            Resize(min_size, interpolation=Image.LANCZOS),
            CenterCrop(crop_size),
            ToTensor(),
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Decode lazily per item; convert to RGB so grayscale/RGBA files
        # yield a consistent 3-channel tensor.
        path = self.paths[idx]
        with Image.open(path) as img:
            img = img.convert("RGB")
        return self.transform(img)
|
|
|
|
|
def process(x):
    """Map pixel values from [0, 1] into the [-1, 1] range VAEs expect."""
    doubled = x * 2
    return doubled - 1
|
|
|
def deprocess(x):
    """Map model outputs from [-1, 1] back to the [0, 1] image range."""
    halved = x * 0.5
    return halved + 0.5
|
|
|
|
|
if __name__ == "__main__":

    # Build the dataset of validated, center-cropped images.
    dataset = ImageFolderDataset(
        IMAGE_FOLDER,
        extensions=('.png',),
        min_size=MIN_SIZE,
        crop_size=CROP_SIZE,
        limit=MAX_IMAGES
    )

    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,  # dataset already shuffled its path list once
        num_workers=NUM_WORKERS,
        pin_memory=True,
        drop_last=False
    )

    # Frozen LPIPS (VGG backbone) for perceptual distance; kept in float32.
    lpips_net = lpips.LPIPS(net="vgg").eval().to(DEVICE).requires_grad_(False)

    print("\nЗагрузка VAE моделей...")
    vaes = []
    names = []

    for name, vae_class, model_path, subfolder in VAE_LIST:
        try:
            print(f" Загружаю {name}...")
            vae = vae_class.from_pretrained(model_path, subfolder=subfolder)
            vae = vae.to(DEVICE, DTYPE).eval()
            vaes.append(vae)
            names.append(name)
        except Exception as e:
            # Best-effort: one unavailable checkpoint should not abort the run.
            print(f" ❌ Ошибка загрузки {name}: {e}")

    # FIX: if every checkpoint failed to load, the old code ran an empty
    # evaluation loop and then crashed with IndexError on names[0].
    if not vaes:
        raise RuntimeError("No VAE models could be loaded")

    print("\nОценка метрик...")
    results = {name: {"mse": 0.0, "psnr": 0.0, "lpips": 0.0, "count": 0} for name in names}

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Обработка батчей"):
            batch = batch.to(DEVICE)
            test_inp = process(batch).to(DTYPE)

            for vae, name in zip(vaes, names):
                # Deterministic reconstruction: take the posterior mode, not a sample.
                latent = vae.encode(test_inp).latent_dist.mode()
                recon = deprocess(vae.decode(latent).sample.float())
                # FIX: the decoder can emit values outside [-1, 1], so the
                # deprocessed image can leave [0, 1]; clamp so PSNR (peak = 1.0)
                # and LPIPS with normalize=True (expects [0, 1]) are well-defined.
                recon = recon.clamp(0.0, 1.0)

                # Accumulate per-image metrics so averages are per image,
                # independent of a possibly smaller final batch.
                for i in range(batch.shape[0]):
                    img_orig = batch[i:i+1]
                    img_recon = recon[i:i+1]

                    mse = F.mse_loss(img_orig, img_recon).item()
                    psnr = 10 * torch.log10(1 / torch.tensor(mse)).item()
                    lpips_val = lpips_net(img_orig, img_recon, normalize=True).mean().item()

                    results[name]["mse"] += mse
                    results[name]["psnr"] += psnr
                    results[name]["lpips"] += lpips_val
                    results[name]["count"] += 1

    # Convert accumulated sums into per-image averages.
    for name in names:
        count = results[name]["count"]
        results[name]["mse"] /= count
        results[name]["psnr"] /= count
        results[name]["lpips"] /= count

    print("\n=== Абсолютные значения ===")
    for name in names:
        print(f"{name:30s}: MSE: {results[name]['mse']:.3e}, PSNR: {results[name]['psnr']:.4f}, LPIPS: {results[name]['lpips']:.4f}")

    print("\n=== Сравнение с первой моделью (%) ===")
    print(f"| {'Модель':30s} | {'MSE':>10s} | {'PSNR':>10s} | {'LPIPS':>10s} |")
    print(f"|{'-'*32}|{'-'*12}|{'-'*12}|{'-'*12}|")

    baseline = names[0]
    for name in names:
        # MSE and LPIPS: lower is better, so ratio is baseline/model;
        # PSNR: higher is better, so ratio is model/baseline.
        # Values above 100% mean the model beats the baseline on that metric.
        mse_pct = (results[baseline]["mse"] / results[name]["mse"]) * 100
        psnr_pct = (results[name]["psnr"] / results[baseline]["psnr"]) * 100
        lpips_pct = (results[baseline]["lpips"] / results[name]["lpips"]) * 100

        if name == baseline:
            print(f"| {name:30s} | {'100%':>10s} | {'100%':>10s} | {'100%':>10s} |")
        else:
            print(f"| {name:30s} | {f'{mse_pct:.1f}%':>10s} | {f'{psnr_pct:.1f}%':>10s} | {f'{lpips_pct:.1f}%':>10s} |")

    print("\n✅ Готово!")