"""Compare reconstruction quality of several VAEs on ImageNet-1k validation.

Every model in ``BASE_MODELS`` encodes and decodes the same center-cropped
images; batch-weighted MSE, PSNR and LPIPS (VGG) are logged per model.
"""

import itertools
import logging
import warnings

import lpips
import torch
import torch.nn.functional as F
import torch.utils.data as data
from datasets import Dataset as HFDataset
from datasets import load_dataset
from diffusers import AutoencoderKL
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Resize,
    ToTensor,
)
from tqdm import tqdm

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hide the "extra keys in checkpoint" warning; checkpoints are loaded with
# strict=False in load_vaes(), where extra keys are expected.
warnings.filterwarnings(
    "ignore",
    ".*Found keys that are not in the model state dict but in the checkpoint.*",
)

DEVICE = "cuda"
DTYPE = torch.float16
SHORT_AXIS_SIZE = 256

# The lists below are parallel: index i in each one describes the same model.
# Display names used in the result log.
NAMES = [
    "madebyollin/sdxl-vae-fp16-fix",
    "KBlueLeaf/EQ-SDXL-VAE ",
    "AiArtLab/simplevae ",
]
# Repo ids passed to AutoencoderKL.from_pretrained.
BASE_MODELS = [
    "madebyollin/sdxl-vae-fp16-fix",
    "KBlueLeaf/EQ-SDXL-VAE",
    "AiArtLab/simplevae",
]
# ``subfolder`` argument for each repo (None loads from the repo root).
SUB_FOLDERS = [None, None, "sdxl_vae"]
# Optional training checkpoints loaded through LatentTrainer (None = skip).
CKPT_PATHS = [
    None,
    None,
    None,
]
# Whether to replace the model's decoder with LatentApproxDecoder.
USE_APPROXS = [False, False, False]


def process(x):
    """Map values from ``[0, 1]`` to ``[-1, 1]``."""
    return x * 2 - 1


def deprocess(x):
    """Map values from ``[-1, 1]`` back to ``[0, 1]`` (inverse of ``process``)."""
    return x * 0.5 + 0.5


class ImageNetDataset(data.IterableDataset):
    """Iterate ``(image, label)`` pairs from the resized ImageNet-1k HF dataset.

    Args:
        split: dataset split to load, e.g. ``"val"``.
        transform: optional callable applied to each ``entry["image"]``.
        max_len: maximum number of samples to yield; a falsy value means no
            limit. When a streaming dataset is read by several DataLoader
            workers, the limit applies per worker.
        streaming: forwarded to ``datasets.load_dataset``.

    Each call to ``iter()`` starts from the beginning of the split, so the
    dataset can be iterated any number of times.
    """

    def __init__(self, split, transform=None, max_len=10, streaming=True):
        self.split = split
        self.transform = transform
        self.dataset = load_dataset(
            "evanarlian/imagenet_1k_resized_256", split=split, streaming=streaming
        )
        self.max_len = max_len

    def _rows(self):
        """Yield the raw rows this process should handle, honouring ``max_len``."""
        worker = data.get_worker_info()
        if worker is not None and isinstance(self.dataset, HFDataset):
            # A map-style dataset is copied whole into every DataLoader worker;
            # stride by worker id so each row is produced exactly once.
            limit = len(self.dataset)
            if self.max_len:
                limit = min(limit, self.max_len)
            for idx in range(worker.id, limit, worker.num_workers):
                yield self.dataset[idx]
            return
        # Single-process loading, or a streaming HF dataset (which distributes
        # its shards across DataLoader workers by itself). A fresh iterator is
        # created on every call so the dataset is re-iterable and never holds
        # an unpicklable generator between iterations.
        rows = iter(self.dataset)
        if self.max_len:
            rows = itertools.islice(rows, self.max_len)
        yield from rows

    def __iter__(self):
        for entry in self._rows():
            img = entry["image"]
            target = entry["label"]
            if self.transform is not None:
                img = self.transform(img)
            yield img, target


def load_vaes():
    """Load every configured VAE, cast it to DTYPE on DEVICE, and compile it.

    Returns:
        A list of ``torch.compile``-wrapped ``AutoencoderKL`` models in the
        same order as ``BASE_MODELS``.
    """
    vaes = []
    for base_model, sub_folder, ckpt_path, use_approx in zip(
        BASE_MODELS, SUB_FOLDERS, CKPT_PATHS, USE_APPROXS
    ):
        vae = AutoencoderKL.from_pretrained(base_model, subfolder=sub_folder)
        if use_approx:
            # NOTE(review): LatentApproxDecoder is not imported in this file;
            # enabling USE_APPROXS raises NameError until it is. The evaluation
            # loop reads ``.sample`` from ``decode(...)`` — confirm this
            # decoder's output provides it.
            vae.decoder = LatentApproxDecoder(
                latent_dim=vae.config.latent_channels,
                out_channels=3,
                shuffle=2,
            )
            # Bind the model being patched as a default argument; a plain
            # closure would look ``vae`` up at call time and could hit a
            # different model.
            vae.decode = lambda x, _vae=vae: _vae.decoder(x)
            vae.get_last_layer = lambda _vae=vae: _vae.decoder.conv_out.weight
        if ckpt_path:
            # NOTE(review): LatentTrainer is not imported in this file either.
            # The returned object is discarded; this relies on the checkpoint
            # weights being loaded into the shared ``vae`` module in place.
            LatentTrainer.load_from_checkpoint(
                ckpt_path, vae=vae, map_location="cpu", strict=False
            )
        vae = vae.to(DTYPE).eval().requires_grad_(False).to(DEVICE)
        vae.encoder = torch.compile(vae.encoder)
        vae.decoder = torch.compile(vae.decoder)
        vaes.append(torch.compile(vae))
    return vaes


def main():
    """Run the reconstruction benchmark and log per-model metrics."""
    lpips_loss = torch.compile(
        lpips.LPIPS(net="vgg").eval().to(DEVICE).requires_grad_(False)
    )

    @torch.compile
    def metrics(inp, recon):
        """Return batch-mean MSE, the PSNR of that MSE, and mean LPIPS (on CPU).

        ``normalize=True`` tells LPIPS the inputs are in ``[0, 1]``.
        """
        mse = F.mse_loss(inp, recon)
        # Peak signal value is 1 because both images are in [0, 1].
        psnr = 10 * torch.log10(1 / mse)
        return (
            mse.cpu(),
            psnr.cpu(),
            lpips_loss(inp, recon, normalize=True).mean().cpu(),
        )

    transform = Compose(
        [
            Resize(SHORT_AXIS_SIZE),
            CenterCrop(SHORT_AXIS_SIZE),
            ToTensor(),
        ]
    )
    valid_dataset = ImageNetDataset(
        "val", transform=transform, max_len=50000, streaming=True
    )
    valid_loader = data.DataLoader(
        valid_dataset,
        batch_size=4,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        pin_memory_device=DEVICE,
    )

    # Check that the data loads before spending time on the models. The
    # dataset restarts on every iteration, so this batch is not lost to the
    # evaluation loop below.
    for batch in valid_loader:
        print("Batch shape:", batch[0].shape)
        break

    logger.info("Loading models...")
    vaes = load_vaes()

    logger.info("Running Validation")
    n_models = len(vaes)
    total = 0
    # Latents are kept on the CPU for every sample; this grows with the
    # dataset and is only concatenated (not otherwise used) in this script.
    all_latents = [[] for _ in range(n_models)]
    all_mse = [[] for _ in range(n_models)]
    all_psnr = [[] for _ in range(n_models)]
    all_lpips = [[] for _ in range(n_models)]
    for batch in tqdm(valid_loader):
        image = batch[0].to(DEVICE)
        test_inp = process(image).to(DTYPE)
        batch_size = test_inp.size(0)
        for i, vae in enumerate(vaes):
            latent = vae.encode(test_inp).latent_dist.mode()
            recon = deprocess(vae.decode(latent).sample.float())
            all_latents[i].append(latent.cpu().float())
            mse, psnr, lpips_ = metrics(image, recon)
            # Weight each batch mean by its size so a short final batch does
            # not skew the dataset-level average.
            all_mse[i].append(mse.cpu() * batch_size)
            all_psnr[i].append(psnr.cpu() * batch_size)
            all_lpips[i].append(lpips_.cpu() * batch_size)
        total += batch_size

    if total == 0:
        # torch.stack on the empty lists below would raise; report instead.
        logger.error("Validation loader produced no samples; nothing to report")
        return

    for i in range(n_models):
        all_latents[i] = torch.cat(all_latents[i], dim=0)
        all_mse[i] = torch.stack(all_mse[i]).sum() / total
        all_psnr[i] = torch.stack(all_psnr[i]).sum() / total
        all_lpips[i] = torch.stack(all_lpips[i]).sum() / total
        logger.info(
            f" - {NAMES[i]}: MSE: {all_mse[i]:.3e}, PSNR: {all_psnr[i]:.4f}, "
            f"LPIPS: {all_lpips[i]:.4f}"
        )
    logger.info("End")


if __name__ == "__main__":
    main()