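"""Benchmark SDXL VAE variants on the ImageNet-1k 256px validation set.

For each model in NAMES, encode/decode the images and report average
reconstruction MSE, PSNR, and VGG-LPIPS.
"""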
import warnings
import logging
from types import SimpleNamespace
import torch
import torch.nn.functional as F
import torch.utils.data as data
import lpips
from tqdm import tqdm
from torchvision.transforms import (
    Compose,
    Resize,
    ToTensor,
    CenterCrop,
)
from diffusers import AutoencoderKL
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
warnings.filterwarnings(
    "ignore",
    ".*Found keys that are not in the model state dict but in the checkpoint.*",
)
DEVICE = "cuda"
DTYPE = torch.float16
SHORT_AXIS_SIZE = 256
NAMES = [
    "madebyollin/sdxl-vae-fp16-fix",
    "KBlueLeaf/EQ-SDXL-VAE",
    "AiArtLab/simplevae",
]
BASE_MODELS = [
    "madebyollin/sdxl-vae-fp16-fix",
    "KBlueLeaf/EQ-SDXL-VAE",
    "AiArtLab/simplevae",
]
SUB_FOLDERS = [None, None, "sdxl_vae"]
CKPT_PATHS = [
    None,
    None,
    None,
]
USE_APPROXS = [False, False, False]
def process(x):
    # Map [0, 1] images to [-1, 1] for the VAE encoder.
    return x * 2 - 1


def deprocess(x):
    # Map VAE output from [-1, 1] back to [0, 1].
    return x * 0.5 + 0.5
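
# NOTE: LatentApproxDecoder is referenced below when USE_APPROXS is enabled,
# but it is not defined in this file (it comes from the training codebase).
# The class below is a hypothetical minimal sketch, assuming a single
# conv + PixelShuffle head that matches the constructor call and the
# `conv_out` attribute used later; swap in the real implementation before
# enabling the approx path.
import torch.nn as nn


class LatentApproxDecoder(nn.Module):
    def __init__(self, latent_dim, out_channels=3, shuffle=2):
        super().__init__()
        # Project latent channels to out_channels * shuffle^2, then trade
        # channel depth for spatial resolution with PixelShuffle.
        self.conv_out = nn.Conv2d(
            latent_dim, out_channels * shuffle**2, kernel_size=3, padding=1
        )
        self.shuffle = nn.PixelShuffle(shuffle)

    def forward(self, x):
        return self.shuffle(self.conv_out(x))
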
from datasets import load_dataset
class ImageNetDataset(data.IterableDataset):
    def __init__(self, split, transform=None, max_len=10, streaming=True):
        self.split = split
        self.transform = transform
        self.dataset = load_dataset(
            "evanarlian/imagenet_1k_resized_256", split=split, streaming=streaming
        )
        self.max_len = max_len

    def __iter__(self):
        # Each DataLoader worker receives its own copy of this dataset, so
        # shard the stream by worker id to avoid yielding duplicate samples.
        worker_info = data.get_worker_info()
        num_workers = worker_info.num_workers if worker_info else 1
        worker_id = worker_info.id if worker_info else 0
        for i, entry in enumerate(iter(self.dataset)):
            if self.max_len and i >= self.max_len:
                break
            if i % num_workers != worker_id:
                continue
            img = entry["image"]
            target = entry["label"]
            if self.transform is not None:
                img = self.transform(img)
            yield img, target
if __name__ == "__main__":
lpips_loss = torch.compile(
lpips.LPIPS(net="vgg").eval().to(DEVICE).requires_grad_(False)
)
@torch.compile
def metrics(inp, recon):
mse = F.mse_loss(inp, recon)
psnr = 10 * torch.log10(1 / mse)
return (
mse.cpu(),
psnr.cpu(),
lpips_loss(inp, recon, normalize=True).mean().cpu(),
)
    transform = Compose(
        [
            Resize(SHORT_AXIS_SIZE),
            CenterCrop(SHORT_AXIS_SIZE),
            ToTensor(),
        ]
    )
    valid_dataset = ImageNetDataset(
        "val", transform=transform, max_len=50000, streaming=True
    )
    valid_loader = data.DataLoader(
        valid_dataset,
        batch_size=4,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        pin_memory_device=DEVICE,
    )
    # Sanity check that the data loads correctly.
    for batch in valid_loader:
        print("Batch shape:", batch[0].shape)
        break
logger.info("Loading models...")
vaes = []
for base_model, sub_folder, ckpt_path, use_approx in zip(
BASE_MODELS, SUB_FOLDERS, CKPT_PATHS, USE_APPROXS
):
vae = AutoencoderKL.from_pretrained(base_model, subfolder=sub_folder)
if use_approx:
vae.decoder = LatentApproxDecoder(
latent_dim=vae.config.latent_channels,
out_channels=3,
shuffle=2,
)
vae.decode = lambda x: vae.decoder(x)
vae.get_last_layer = lambda: vae.decoder.conv_out.weight
if ckpt_path:
LatentTrainer.load_from_checkpoint(
ckpt_path, vae=vae, map_location="cpu", strict=False
)
vae = vae.to(DTYPE).eval().requires_grad_(False).to(DEVICE)
vae.encoder = torch.compile(vae.encoder)
vae.decoder = torch.compile(vae.decoder)
vaes.append(torch.compile(vae))
logger.info("Running Validation")
total = 0
all_latents = [[] for _ in range(len(vaes))]
all_mse = [[] for _ in range(len(vaes))]
all_psnr = [[] for _ in range(len(vaes))]
all_lpips = [[] for _ in range(len(vaes))]
for idx, batch in enumerate(tqdm(valid_loader)):
image = batch[0].to(DEVICE)
test_inp = process(image).to(DTYPE)
batch_size = test_inp.size(0)
for i, vae in enumerate(vaes):
latent = vae.encode(test_inp).latent_dist.mode()
recon = deprocess(vae.decode(latent).sample.float())
all_latents[i].append(latent.cpu().float())
mse, psnr, lpips_ = metrics(image, recon)
all_mse[i].append(mse.cpu() * batch_size)
all_psnr[i].append(psnr.cpu() * batch_size)
all_lpips[i].append(lpips_.cpu() * batch_size)
total += batch_size
for i in range(len(vaes)):
all_latents[i] = torch.cat(all_latents[i], dim=0)
all_mse[i] = torch.stack(all_mse[i]).sum() / total
all_psnr[i] = torch.stack(all_psnr[i]).sum() / total
all_lpips[i] = torch.stack(all_lpips[i]).sum() / total
logger.info(
f" - {NAMES[i]}: MSE: {all_mse[i]:.3e}, PSNR: {all_psnr[i]:.4f}, "
f"LPIPS: {all_lpips[i]:.4f}"
)
logger.info("End") |