# sdxl_vae / eval_alchemist.py
# VAE reconstruction-quality benchmark on the Alchemist image set:
# compares MSE / PSNR / LPIPS / Sobel-edge loss across several VAE checkpoints.
import os
import torch
import torch.nn.functional as F
import lpips
from PIL import Image, UnidentifiedImageError
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop,ToPILImage
from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
import random
# --------------------------- Parameters ---------------------------
DEVICE = "cuda"
DTYPE = torch.float16
IMAGE_FOLDER = "/workspace/alchemist" #wget https://huggingface.co/datasets/AiArtLab/alchemist/resolve/main/alchemist.zip
MIN_SIZE = 1280  # shorter-side resize target applied before cropping
CROP_SIZE = 512  # square center-crop side; also the tile size in saved collages
BATCH_SIZE = 5
MAX_IMAGES = 100  # cap on how many dataset images are evaluated
NUM_WORKERS = 4
NUM_SAMPLES_TO_SAVE = 10 # how many sample reconstructions to save (0 - do not save)
SAMPLES_FOLDER = "vaetest"
# VAEs to benchmark: (display name, model class, HF repo id, subfolder or None)
VAE_LIST = [
    # ("stable-diffusion-v1-5/stable-diffusion-v1-5", AutoencoderKL, "stable-diffusion-v1-5/stable-diffusion-v1-5", "vae"),
    # ("cross-attention/asymmetric-autoencoder-kl-x-1-5", AsymmetricAutoencoderKL, "cross-attention/asymmetric-autoencoder-kl-x-1-5", None),
    ("madebyollin/sdxl-vae-fp16", AutoencoderKL, "madebyollin/sdxl-vae-fp16-fix", None),
    # ("AiArtLab/sdxs", AutoencoderKL, "AiArtLab/sdxs", "vae"),
    ("AiArtLab/sdxl_vae", AutoencoderKL, "AiArtLab/sdxl_vae", None),
    # ("AiArtLab/sdxl_vae_asym", AsymmetricAutoencoderKL, "AiArtLab/sdxl_vae", "asymmetric_vae"),
    ("AiArtLab/sdxl_vae_asym_new", AsymmetricAutoencoderKL, "AiArtLab/sdxl_vae", "asymmetric_vae_new"),
    # ("KBlueLeaf/EQ-SDXL-VAE", AutoencoderKL, "KBlueLeaf/EQ-SDXL-VAE", None),
    # ("FLUX.1-schnell-vae", AutoencoderKL, "black-forest-labs/FLUX.1-schnell", "vae"),
]
# --------------------------- Sobel Edge Detection ---------------------------
# 3x3 Sobel kernels, built once at module load and shaped as conv2d weights.
# The y-kernel is the transpose of the x-kernel.
_SOBEL_ROWS = [[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]
_sobel_kx = torch.tensor(_SOBEL_ROWS, dtype=torch.float32).reshape(1, 1, 3, 3)
_sobel_ky = _sobel_kx.reshape(3, 3).t().reshape(1, 1, 3, 3)
def sobel_edges(x: torch.Tensor) -> torch.Tensor:
    """Gradient-magnitude edge map computed with the Sobel operator.

    Args:
        x: [B, C, H, W] tensor (callers pass images scaled to [-1, 1]).

    Returns:
        [B, C, H, W] tensor of per-channel gradient magnitudes.
    """
    channels = x.shape[1]
    # One depthwise kernel per channel, matched to the input's device/dtype.
    weight_x = _sobel_kx.to(x.device, x.dtype).repeat(channels, 1, 1, 1)
    weight_y = _sobel_ky.to(x.device, x.dtype).repeat(channels, 1, 1, 1)
    grad_x = F.conv2d(x, weight_x, padding=1, groups=channels)
    grad_y = F.conv2d(x, weight_y, padding=1, groups=channels)
    # Epsilon keeps sqrt differentiable/stable at exactly-zero gradients.
    return (grad_x * grad_x + grad_y * grad_y + 1e-12).sqrt()
def compute_edge_loss(real: torch.Tensor, fake: torch.Tensor) -> float:
    """L1 distance between Sobel edge maps of real and reconstructed images.

    Args:
        real: [B, C, H, W] tensor in [0, 1].
        fake: [B, C, H, W] tensor in [0, 1].

    Returns:
        Scalar edge loss as a Python float.
    """
    # sobel_edges expects [-1, 1]-scaled input, so rescale first.
    edges_real = sobel_edges(real * 2 - 1)
    edges_fake = sobel_edges(fake * 2 - 1)
    return F.l1_loss(edges_fake, edges_real).item()
# --------------------------- Dataset ---------------------------
class ImageFolderDataset(Dataset):
    """Recursively collects images under ``root_dir``, validates them, and
    yields center-cropped float tensors in [0, 1].

    Args:
        root_dir: directory scanned recursively for image files.
        extensions: lowercase file suffixes to accept.
        min_size: shorter-side size the image is resized to before cropping.
        crop_size: square center-crop side length.
        limit: if set, keep at most this many paths. NOTE: the limit is
            applied before validation, so the final count may be smaller.

    Raises:
        RuntimeError: if no valid image is found under ``root_dir``.
    """

    def __init__(self, root_dir, extensions=('.png',), min_size=1024, crop_size=512, limit=None):
        self.root_dir = root_dir
        self.min_size = min_size
        self.crop_size = crop_size
        self.paths = []
        print("Сканирование папки...")
        for root, _, files in os.walk(root_dir):
            for fname in files:
                if fname.lower().endswith(extensions):
                    self.paths.append(os.path.join(root, fname))
        if limit:
            self.paths = self.paths[:limit]
        print("Проверка изображений...")
        valid = []
        for p in tqdm(self.paths, desc="Проверка"):
            try:
                with Image.open(p) as im:
                    im.verify()  # cheap integrity check without a full decode
                valid.append(p)
            # Catch only decode/IO failures; a bare `except` would also
            # swallow KeyboardInterrupt and hide real bugs.
            except (UnidentifiedImageError, OSError):
                continue
        self.paths = valid
        if len(self.paths) == 0:
            raise RuntimeError(f"Не найдено валидных изображений в {root_dir}")
        random.shuffle(self.paths)
        print(f"Найдено {len(self.paths)} изображений")
        self.transform = Compose([
            Resize(min_size, interpolation=Image.LANCZOS),
            CenterCrop(crop_size),
            ToTensor(),
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Decode lazily per item; convert to RGB so grayscale/RGBA inputs work.
        path = self.paths[idx]
        with Image.open(path) as img:
            img = img.convert("RGB")
        return self.transform(img)
# --------------------------- Helpers ---------------------------
def process(x):
    """Map image values from [0, 1] to [-1, 1] (the VAE encoder's range)."""
    doubled = x * 2
    return doubled - 1
def deprocess(x):
    """Map decoder output from [-1, 1] back to [0, 1]."""
    # Keep the exact `* 0.5 + 0.5` arithmetic for bitwise float parity.
    halved = x * 0.5
    return halved + 0.5
def _sanitize_name(name: str) -> str:
return name.replace('/', '_').replace('-', '_')
# --------------------------- Main ---------------------------
if __name__ == "__main__":
    if NUM_SAMPLES_TO_SAVE > 0:
        os.makedirs(SAMPLES_FOLDER, exist_ok=True)
    dataset = ImageFolderDataset(
        IMAGE_FOLDER,
        extensions=('.png',),
        min_size=MIN_SIZE,
        crop_size=CROP_SIZE,
        limit=MAX_IMAGES
    )
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        drop_last=False
    )
    # LPIPS perceptual metric (VGG backbone), frozen for inference.
    lpips_net = lpips.LPIPS(net="vgg").eval().to(DEVICE).requires_grad_(False)
    print("\nЗагрузка VAE моделей...")
    vaes = []
    names = []
    for name, vae_class, model_path, subfolder in VAE_LIST:
        try:
            print(f" Загружаю {name}...")
            # variant="fp16" is only needed for the sdxs checkpoint
            if "sdxs" in model_path:
                vae = vae_class.from_pretrained(model_path, subfolder=subfolder, variant="fp16")
            else:
                vae = vae_class.from_pretrained(model_path, subfolder=subfolder)
            vae = vae.to(DEVICE, DTYPE).eval()
            vaes.append(vae)
            names.append(name)
        # A failed download/load skips that model instead of aborting the run.
        except Exception as e:
            print(f" ❌ Ошибка загрузки {name}: {e}")
    print("\nОценка метрик...")
    # Running per-model sums; averaged by "count" after the loop.
    results = {name: {"mse": 0.0, "psnr": 0.0, "lpips": 0.0, "edge": 0.0, "count": 0} for name in names}
    to_pil = ToPILImage()
    # >>>>>>>> MAIN EVALUATION LOOP (KISS) <<<<<<<<
    with torch.no_grad():
        images_saved = 0  # counts saved IMAGES, not individual output files
        for batch in tqdm(dataloader, desc="Обработка батчей"):
            batch = batch.to(DEVICE)  # [B,3,H,W] in [0,1]
            test_inp = process(batch).to(DTYPE)  # [-1,1] for the encoder
            # 1) reconstruct the whole batch with every VAE
            recon_list = []
            for vae in vaes:
                # Deterministic latent: distribution mode instead of sampling.
                latent = vae.encode(test_inp).latent_dist.mode()
                dec = vae.decode(latent).sample.float()  # [-1,1] (typically)
                recon = deprocess(dec).clamp(0.0, 1.0)  # -> [0,1]; clamp removes out-of-range artifacts
                recon_list.append(recon)
            # 2) accumulate metrics (per VAE, per image)
            for recon, name in zip(recon_list, names):
                for i in range(batch.shape[0]):
                    img_orig = batch[i:i+1]
                    img_recon = recon[i:i+1]
                    mse = F.mse_loss(img_orig, img_recon).item()
                    # PSNR with peak value 1.0 (images are in [0,1]).
                    psnr = 10 * torch.log10(1 / torch.tensor(mse)).item()
                    # normalize=True tells LPIPS the inputs are in [0,1].
                    lpips_val = lpips_net(img_orig, img_recon, normalize=True).mean().item()
                    edge_loss = compute_edge_loss(img_orig, img_recon)
                    results[name]["mse"] += mse
                    results[name]["psnr"] += psnr
                    results[name]["lpips"] += lpips_val
                    results[name]["edge"] += edge_loss
                    results[name]["count"] += 1
            # 3) save exactly NUM_SAMPLES_TO_SAVE images (orig + each VAE + combined collage)
            if NUM_SAMPLES_TO_SAVE > 0:
                for i in range(batch.shape[0]):
                    if images_saved >= NUM_SAMPLES_TO_SAVE:
                        break
                    idx_str = f"{images_saved + 1:03d}"
                    # original
                    orig_pil = to_pil(batch[i].detach().float().cpu())
                    orig_pil.save(os.path.join(SAMPLES_FOLDER, f"{idx_str}_orig.png"))
                    # per-VAE decodes
                    tiles = [orig_pil]
                    for recon, name in zip(recon_list, names):
                        recon_pil = to_pil(recon[i].detach().cpu())
                        recon_pil.save(os.path.join(
                            SAMPLES_FOLDER, f"{idx_str}_decoded_{_sanitize_name(name)}.png"
                        ))
                        tiles.append(recon_pil)
                    # combined collage: [orig | vae1 | vae2 | ...]
                    collage_w = CROP_SIZE * len(tiles)
                    collage_h = CROP_SIZE
                    collage = Image.new("RGB", (collage_w, collage_h))
                    x = 0
                    for tile in tiles:
                        collage.paste(tile, (x, 0))
                        x += CROP_SIZE
                    collage.save(os.path.join(SAMPLES_FOLDER, f"{idx_str}_all.png"))
                    images_saved += 1
    # Average the accumulated sums
    for name in names:
        count = results[name]["count"]
        results[name]["mse"] /= count
        results[name]["psnr"] /= count
        results[name]["lpips"] /= count
        results[name]["edge"] /= count
    # Report absolute metric values
    print("\n=== Абсолютные значения ===")
    for name in names:
        print(f"{name:30s}: MSE: {results[name]['mse']:.3e}, PSNR: {results[name]['psnr']:.4f}, "
              f"LPIPS: {results[name]['lpips']:.4f}, Edge: {results[name]['edge']:.4f}")
    # Report each model relative to the first model (= 100%)
    print("\n=== Сравнение с первой моделью (%) ===")
    print(f"| {'Модель':30s} | {'MSE':>10s} | {'PSNR':>10s} | {'LPIPS':>10s} | {'Edge':>10s} |")
    print(f"|{'-'*32}|{'-'*12}|{'-'*12}|{'-'*12}|{'-'*12}|")
    baseline = names[0]
    for name in names:
        # MSE, LPIPS and Edge: lower is better, so the ratio is inverted
        mse_pct = (results[baseline]["mse"] / results[name]["mse"]) * 100
        # PSNR: higher is better
        psnr_pct = (results[name]["psnr"] / results[baseline]["psnr"]) * 100
        # LPIPS and Edge: lower is better
        lpips_pct = (results[baseline]["lpips"] / results[name]["lpips"]) * 100
        edge_pct = (results[baseline]["edge"] / results[name]["edge"]) * 100
        if name == baseline:
            print(f"| {name:30s} | {'100%':>10s} | {'100%':>10s} | {'100%':>10s} | {'100%':>10s} |")
        else:
            print(f"| {name:30s} | {f'{mse_pct:.1f}%':>10s} | {f'{psnr_pct:.1f}%':>10s} | "
                  f"{f'{lpips_pct:.1f}%':>10s} | {f'{edge_pct:.1f}%':>10s} |")
    print("\n✅ Готово!")