# sdxl_vae / eval_alchemist.py
import os
import torch
import torch.nn.functional as F
import lpips
from PIL import Image, UnidentifiedImageError
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop, InterpolationMode
from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
import random
# --------------------------- Parameters ---------------------------
DEVICE = "cuda"
DTYPE = torch.float16
IMAGE_FOLDER = "/workspace/alchemist"
MIN_SIZE = 1280
CROP_SIZE = 512
BATCH_SIZE = 4  # can be increased to speed things up
MAX_IMAGES = None  # None = evaluate all found images
NUM_WORKERS = 4  # parallel data-loading workers
# VAEs to evaluate: (display name, loader class, HF repo id, optional subfolder)
VAE_LIST = [
("madebyollin/sdxl-vae-fp16", AutoencoderKL, "madebyollin/sdxl-vae-fp16-fix", None),
("KBlueLeaf/EQ-SDXL-VAE", AutoencoderKL, "KBlueLeaf/EQ-SDXL-VAE", None),
("AiArtLab/sdxl_vae", AutoencoderKL, "AiArtLab/sdxl_vae", None),
("AiArtLab/sdxl_vae_asym", AsymmetricAutoencoderKL, "AiArtLab/sdxl_vae", "asymmetric_vae"),
("FLUX.1-schnell-vae", AutoencoderKL, "black-forest-labs/FLUX.1-schnell", "vae"),
]
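# Hypothetical example of adding another model to the comparison (the repo id below
# is a placeholder, not a real checkpoint):
# ("my-custom-vae", AutoencoderKL, "your-username/your-vae-repo", None),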
# --------------------------- Dataset ---------------------------
class ImageFolderDataset(Dataset):
def __init__(self, root_dir, extensions=('.png',), min_size=1024, crop_size=512, limit=None):
self.root_dir = root_dir
self.min_size = min_size
self.crop_size = crop_size
self.paths = []
# Collect file paths
print("Scanning folder...")
for root, _, files in os.walk(root_dir):
for fname in files:
if fname.lower().endswith(extensions):
self.paths.append(os.path.join(root, fname))
# Optionally limit the number of images
if limit:
self.paths = self.paths[:limit]
# Quick validity check (optional; can be skipped for speed)
print("Validating images...")
valid = []
for p in tqdm(self.paths, desc="Validating"):
try:
with Image.open(p) as im:
im.verify()
valid.append(p)
except (UnidentifiedImageError, OSError):
continue
self.paths = valid
if len(self.paths) == 0:
raise RuntimeError(f"No valid images found in {root_dir}")
# Shuffle so iteration order is random
random.shuffle(self.paths)
print(f"Found {len(self.paths)} images")
# Transforms
self.transform = Compose([
Resize(min_size, interpolation=InterpolationMode.LANCZOS),
CenterCrop(crop_size),
ToTensor(),
])
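# The transform above resizes the shorter image side to min_size (keeping aspect
# ratio), takes a crop_size x crop_size center crop, and converts the result to a
# float tensor with values in [0, 1].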
def __len__(self):
return len(self.paths)
def __getitem__(self, idx):
path = self.paths[idx]
with Image.open(path) as img:
img = img.convert("RGB")
return self.transform(img)
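# Minimal standalone usage sketch (the path and limit below are placeholders):
#   ds = ImageFolderDataset("/path/to/pngs", min_size=1280, crop_size=512, limit=16)
#   x = ds[0]  # torch.FloatTensor of shape (3, 512, 512), values in [0, 1]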
# --------------------------- Functions ---------------------------
def process(x):
return x * 2 - 1
def deprocess(x):
return x * 0.5 + 0.5
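# process() maps images from [0, 1] to the [-1, 1] range the VAEs expect on input;
# deprocess() inverts it, e.g. process(0.0) == -1.0 and deprocess(-1.0) == 0.0.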
# --------------------------- Main code ---------------------------
if __name__ == "__main__":
# Build the dataset and dataloader
dataset = ImageFolderDataset(
IMAGE_FOLDER,
extensions=('.png',),
min_size=MIN_SIZE,
crop_size=CROP_SIZE,
limit=MAX_IMAGES
)
dataloader = DataLoader(
dataset,
batch_size=BATCH_SIZE,
shuffle=False,  # already shuffled inside the dataset
num_workers=NUM_WORKERS,
pin_memory=True,
drop_last=False
)
# Initialize LPIPS
lpips_net = lpips.LPIPS(net="vgg").eval().to(DEVICE).requires_grad_(False)
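# With normalize=True (used below), LPIPS accepts inputs in [0, 1] and rescales
# them to [-1, 1] internally; the VGG backbone is kept in float32.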
# Load VAE models
print("\nLoading VAE models...")
vaes = []
names = []
for name, vae_class, model_path, subfolder in VAE_LIST:
try:
print(f" Загружаю {name}...")
vae = vae_class.from_pretrained(model_path, subfolder=subfolder)
vae = vae.to(DEVICE, DTYPE).eval()
vaes.append(vae)
names.append(name)
except Exception as e:
print(f" ❌ Ошибка загрузки {name}: {e}")
# Evaluate metrics
print("\nEvaluating metrics...")
results = {name: {"mse": 0.0, "psnr": 0.0, "lpips": 0.0, "count": 0} for name in names}
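# results holds running sums per model; they are averaged over "count" after the loop.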
with torch.no_grad():
for batch in tqdm(dataloader, desc="Processing batches"):
batch = batch.to(DEVICE)
test_inp = process(batch).to(DTYPE)
for vae, name in zip(vaes, names):
# Encode to latents (posterior mode, i.e. the deterministic mean) and decode back
latent = vae.encode(test_inp).latent_dist.mode()
recon = deprocess(vae.decode(latent).sample.float())
# Per-image metrics for the batch
for i in range(batch.shape[0]):
img_orig = batch[i:i+1]
img_recon = recon[i:i+1]
mse = F.mse_loss(img_orig, img_recon).item()
psnr = 10 * torch.log10(1 / torch.tensor(mse)).item()
lpips_val = lpips_net(img_orig, img_recon, normalize=True).mean().item()
results[name]["mse"] += mse
results[name]["psnr"] += psnr
results[name]["lpips"] += lpips_val
results[name]["count"] += 1
# Average the accumulated results
for name in names:
count = results[name]["count"]
results[name]["mse"] /= count
results[name]["psnr"] /= count
results[name]["lpips"] /= count
# Print absolute values
print("\n=== Absolute values ===")
for name in names:
print(f"{name:30s}: MSE: {results[name]['mse']:.3e}, PSNR: {results[name]['psnr']:.4f}, LPIPS: {results[name]['lpips']:.4f}")
# Print comparison table as percentages of the first (baseline) model
print("\n=== Comparison with the first model (%) ===")
print(f"| {'Model':30s} | {'MSE':>10s} | {'PSNR':>10s} | {'LPIPS':>10s} |")
print(f"|{'-'*32}|{'-'*12}|{'-'*12}|{'-'*12}|")
baseline = names[0]
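# For MSE and LPIPS lower is better, so the ratio is baseline/model; for PSNR higher
# is better, so it is model/baseline. Either way, values above 100% beat the baseline.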
for name in names:
mse_pct = (results[baseline]["mse"] / results[name]["mse"]) * 100
psnr_pct = (results[name]["psnr"] / results[baseline]["psnr"]) * 100
lpips_pct = (results[baseline]["lpips"] / results[name]["lpips"]) * 100
if name == baseline:
print(f"| {name:30s} | {'100%':>10s} | {'100%':>10s} | {'100%':>10s} |")
else:
print(f"| {name:30s} | {f'{mse_pct:.1f}%':>10s} | {f'{psnr_pct:.1f}%':>10s} | {f'{lpips_pct:.1f}%':>10s} |")
print("\n✅ Готово!")