import os
import random

import torch
import torch.nn.functional as F
import lpips
from PIL import Image, UnidentifiedImageError
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop, ToPILImage
from diffusers import AutoencoderKL, AsymmetricAutoencoderKL


DEVICE = "cuda"
DTYPE = torch.float16  # precision used for the VAEs
IMAGE_FOLDER = "/workspace/alchemist"
MIN_SIZE = 1280  # shorter image side is resized to this before cropping
CROP_SIZE = 512  # center-crop size fed to the VAEs
BATCH_SIZE = 5
MAX_IMAGES = 100  # cap on the number of image paths taken from the folder
NUM_WORKERS = 4
NUM_SAMPLES_TO_SAVE = 10
SAMPLES_FOLDER = "vaetest"
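# When NUM_SAMPLES_TO_SAVE > 0, the first few originals, the per-VAE decodes and a
# side-by-side collage are saved as PNGs into SAMPLES_FOLDER.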


# Each entry: (display name, VAE class, Hugging Face repo id or local path, optional subfolder).
VAE_LIST = [
    ("madebyollin/sdxl-vae-fp16", AutoencoderKL, "madebyollin/sdxl-vae-fp16-fix", None),
    ("AiArtLab/sdxl_vae", AutoencoderKL, "AiArtLab/sdxl_vae", None),
    ("AiArtLab/sdxl_vae_asym_new", AsymmetricAutoencoderKL, "AiArtLab/sdxl_vae", "asymmetric_vae_new"),
]


# 3x3 Sobel kernels for horizontal and vertical gradients.
_sobel_kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
_sobel_ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)


def sobel_edges(x: torch.Tensor) -> torch.Tensor:
    """
    Compute an edge map with the Sobel operator.
    x: [B, C, H, W] in the range [-1, 1]
    Returns: [B, C, H, W] gradient magnitude
    """
    C = x.shape[1]
    kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
    ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
    # Depthwise convolution: each channel is filtered independently.
    gx = F.conv2d(x, kx, padding=1, groups=C)
    gy = F.conv2d(x, ky, padding=1, groups=C)
    return torch.sqrt(gx * gx + gy * gy + 1e-12)


def compute_edge_loss(real: torch.Tensor, fake: torch.Tensor) -> float:
    """
    Compute the edge loss between a real and a reconstructed image.
    real, fake: [B, C, H, W] in the range [0, 1]
    Returns: scalar loss value
    """
    real_norm = real * 2 - 1
    fake_norm = fake * 2 - 1

    edges_real = sobel_edges(real_norm)
    edges_fake = sobel_edges(fake_norm)

    return F.l1_loss(edges_fake, edges_real).item()
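
# Example use (hypothetical tensors, shapes only): for RGB batches in [0, 1],
#   real = torch.rand(2, 3, 512, 512); fake = torch.rand(2, 3, 512, 512)
#   compute_edge_loss(real, fake)  # -> float, L1 distance between the Sobel edge maps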


class ImageFolderDataset(Dataset):
    def __init__(self, root_dir, extensions=('.png',), min_size=1024, crop_size=512, limit=None):
        self.root_dir = root_dir
        self.min_size = min_size
        self.crop_size = crop_size
        self.paths = []

        print("Scanning folder...")
        for root, _, files in os.walk(root_dir):
            for fname in files:
                if fname.lower().endswith(extensions):
                    self.paths.append(os.path.join(root, fname))

        if limit:
            self.paths = self.paths[:limit]

        print("Validating images...")
        valid = []
        for p in tqdm(self.paths, desc="Validating"):
            try:
                with Image.open(p) as im:
                    im.verify()
                valid.append(p)
            except Exception:  # skip unreadable or corrupt files
                continue
        self.paths = valid

        if len(self.paths) == 0:
            raise RuntimeError(f"No valid images found in {root_dir}")

        random.shuffle(self.paths)
        print(f"Found {len(self.paths)} images")

        self.transform = Compose([
            Resize(min_size, interpolation=Image.LANCZOS),
            CenterCrop(crop_size),
            ToTensor(),
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        with Image.open(path) as img:
            img = img.convert("RGB")
        return self.transform(img)


def process(x):
    # Map images from [0, 1] to the [-1, 1] range expected by the VAE.
    return x * 2 - 1


def deprocess(x):
    # Map VAE output from [-1, 1] back to [0, 1].
    return x * 0.5 + 0.5


def _sanitize_name(name: str) -> str:
    # Make a model name safe to use in file names.
    return name.replace('/', '_').replace('-', '_')
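

# Main evaluation pipeline: build the dataset and dataloader, load every VAE from
# VAE_LIST, encode/decode each crop, accumulate per-image MSE / PSNR / LPIPS / edge
# loss, optionally save sample collages, then print averages and a comparison table.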
if __name__ == "__main__":
    if NUM_SAMPLES_TO_SAVE > 0:
        os.makedirs(SAMPLES_FOLDER, exist_ok=True)

    dataset = ImageFolderDataset(
        IMAGE_FOLDER,
        extensions=('.png',),
        min_size=MIN_SIZE,
        crop_size=CROP_SIZE,
        limit=MAX_IMAGES
    )

    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        drop_last=False
    )

    lpips_net = lpips.LPIPS(net="vgg").eval().to(DEVICE).requires_grad_(False)

    print("\nLoading VAE models...")
    vaes = []
    names = []

    for name, vae_class, model_path, subfolder in VAE_LIST:
        try:
            print(f"  Loading {name}...")
            if "sdxs" in model_path:
                # Checkpoints whose repo id contains "sdxs" are stored as an fp16 variant.
                vae = vae_class.from_pretrained(model_path, subfolder=subfolder, variant="fp16")
            else:
                vae = vae_class.from_pretrained(model_path, subfolder=subfolder)
            vae = vae.to(DEVICE, DTYPE).eval()
            vaes.append(vae)
            names.append(name)
        except Exception as e:
            print(f"  ❌ Failed to load {name}: {e}")

    print("\nEvaluating metrics...")
    results = {name: {"mse": 0.0, "psnr": 0.0, "lpips": 0.0, "edge": 0.0, "count": 0} for name in names}

    to_pil = ToPILImage()

    with torch.no_grad():
        images_saved = 0
        for batch in tqdm(dataloader, desc="Processing batches"):
            batch = batch.to(DEVICE)
            test_inp = process(batch).to(DTYPE)

            # Deterministic reconstruction: take the mode of the latent distribution.
            recon_list = []
            for vae in vaes:
                latent = vae.encode(test_inp).latent_dist.mode()
                dec = vae.decode(latent).sample.float()
                recon = deprocess(dec).clamp(0.0, 1.0)
                recon_list.append(recon)

            # Accumulate per-image metrics for each VAE.
            for recon, name in zip(recon_list, names):
                for i in range(batch.shape[0]):
                    img_orig = batch[i:i+1]
                    img_recon = recon[i:i+1]
                    mse = F.mse_loss(img_orig, img_recon).item()
                    psnr = 10 * torch.log10(1 / torch.tensor(mse)).item()  # peak value is 1.0
                    lpips_val = lpips_net(img_orig, img_recon, normalize=True).mean().item()
                    edge_loss = compute_edge_loss(img_orig, img_recon)
                    results[name]["mse"] += mse
                    results[name]["psnr"] += psnr
                    results[name]["lpips"] += lpips_val
                    results[name]["edge"] += edge_loss
                    results[name]["count"] += 1

            # Save the first NUM_SAMPLES_TO_SAVE originals, decodes and collages.
            if NUM_SAMPLES_TO_SAVE > 0:
                for i in range(batch.shape[0]):
                    if images_saved >= NUM_SAMPLES_TO_SAVE:
                        break
                    idx_str = f"{images_saved + 1:03d}"

                    orig_pil = to_pil(batch[i].detach().float().cpu())
                    orig_pil.save(os.path.join(SAMPLES_FOLDER, f"{idx_str}_orig.png"))

                    tiles = [orig_pil]
                    for recon, name in zip(recon_list, names):
                        recon_pil = to_pil(recon[i].detach().cpu())
                        recon_pil.save(os.path.join(
                            SAMPLES_FOLDER, f"{idx_str}_decoded_{_sanitize_name(name)}.png"
                        ))
                        tiles.append(recon_pil)

                    # Horizontal collage: original followed by each reconstruction.
                    collage_w = CROP_SIZE * len(tiles)
                    collage_h = CROP_SIZE
                    collage = Image.new("RGB", (collage_w, collage_h))
                    x = 0
                    for tile in tiles:
                        collage.paste(tile, (x, 0))
                        x += CROP_SIZE
                    collage.save(os.path.join(SAMPLES_FOLDER, f"{idx_str}_all.png"))

                    images_saved += 1

    # Average the accumulated metrics.
    for name in names:
        count = results[name]["count"]
        results[name]["mse"] /= count
        results[name]["psnr"] /= count
        results[name]["lpips"] /= count
        results[name]["edge"] /= count

    print("\n=== Absolute values ===")
    for name in names:
        print(f"{name:30s}: MSE: {results[name]['mse']:.3e}, PSNR: {results[name]['psnr']:.4f}, "
              f"LPIPS: {results[name]['lpips']:.4f}, Edge: {results[name]['edge']:.4f}")

    print("\n=== Comparison with the first model (%) ===")
    print(f"| {'Model':30s} | {'MSE':>10s} | {'PSNR':>10s} | {'LPIPS':>10s} | {'Edge':>10s} |")
    print(f"|{'-'*32}|{'-'*12}|{'-'*12}|{'-'*12}|{'-'*12}|")

    # The first successfully loaded VAE is the baseline. All ratios are oriented so
    # that values above 100% mean better than the baseline (lower MSE/LPIPS/edge loss,
    # higher PSNR).
    baseline = names[0]
    for name in names:
        mse_pct = (results[baseline]["mse"] / results[name]["mse"]) * 100
        psnr_pct = (results[name]["psnr"] / results[baseline]["psnr"]) * 100
        lpips_pct = (results[baseline]["lpips"] / results[name]["lpips"]) * 100
        edge_pct = (results[baseline]["edge"] / results[name]["edge"]) * 100

        if name == baseline:
            print(f"| {name:30s} | {'100%':>10s} | {'100%':>10s} | {'100%':>10s} | {'100%':>10s} |")
        else:
            print(f"| {name:30s} | {f'{mse_pct:.1f}%':>10s} | {f'{psnr_pct:.1f}%':>10s} | "
                  f"{f'{lpips_pct:.1f}%':>10s} | {f'{edge_pct:.1f}%':>10s} |")

    print("\n✅ Done!")