import warnings
import logging

import torch
import torch.nn.functional as F
import torch.utils.data as data

import lpips
from tqdm import tqdm
from torchvision.transforms import (
    Compose,
    Resize,
    ToTensor,
    CenterCrop,
)
from datasets import load_dataset
from diffusers import AutoencoderKL
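
# LatentApproxDecoder and LatentTrainer are referenced below but only exercised
# when USE_APPROXS / CKPT_PATHS are set. They are assumed to come from the
# accompanying training codebase; the module paths here are hypothetical
# placeholders, and the guarded import keeps the default all-pretrained
# configuration runnable without them.
try:
    from hakulatent.models import LatentApproxDecoder  # hypothetical path
    from hakulatent.trainer import LatentTrainer  # hypothetical path
except ImportError:
    LatentApproxDecoder = None
    LatentTrainer = None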

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

warnings.filterwarnings(
    "ignore",
    ".*Found keys that are not in the model state dict but in the checkpoint.*",
)

DEVICE = "cuda"
DTYPE = torch.float16
SHORT_AXIS_SIZE = 256
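
# Fail fast on machines without a GPU: the script assumes a CUDA device
# throughout (fp16 weights, pinned-memory transfers to DEVICE).
assert torch.cuda.is_available(), "This evaluation script requires a CUDA device."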

NAMES = [
    "madebyollin/sdxl-vae-fp16-fix",
    "KBlueLeaf/EQ-SDXL-VAE",
    "AiArtLab/simplevae",
]
BASE_MODELS = [
    "madebyollin/sdxl-vae-fp16-fix",
    "KBlueLeaf/EQ-SDXL-VAE",
    "AiArtLab/simplevae",
]
SUB_FOLDERS = [None, None, "sdxl_vae"]
CKPT_PATHS = [None, None, None]
USE_APPROXS = [False, False, False]


def process(x):
    # Map images from [0, 1] to the [-1, 1] range the VAE expects.
    return x * 2 - 1


def deprocess(x):
    # Map VAE output from [-1, 1] back to [0, 1] for metric computation.
    return x * 0.5 + 0.5


class ImageNetDataset(data.IterableDataset):
    def __init__(self, split, transform=None, max_len=10, streaming=True):
        self.split = split
        self.transform = transform
        self.dataset = load_dataset(
            "evanarlian/imagenet_1k_resized_256", split=split, streaming=streaming
        )
        self.max_len = max_len

    def __iter__(self):
        # Iterate the underlying stream fresh on every call; caching an
        # iterator in __init__ would exhaust it after the first pass.
        worker_info = data.get_worker_info()
        for i, entry in enumerate(self.dataset):
            if self.max_len and i >= self.max_len:
                break
            # Shard the stream across DataLoader workers so each sample is
            # yielded once overall instead of once per worker.
            if worker_info is not None and i % worker_info.num_workers != worker_info.id:
                continue
            img = entry["image"]
            target = entry["label"]
            if self.transform is not None:
                img = self.transform(img)
            yield img, target


if __name__ == "__main__":
    lpips_loss = torch.compile(
        lpips.LPIPS(net="vgg").eval().to(DEVICE).requires_grad_(False)
    )

    @torch.compile
    def metrics(inp, recon):
        # Inputs are expected in [0, 1], so PSNR assumes a peak value of 1.
        mse = F.mse_loss(inp, recon)
        psnr = 10 * torch.log10(1 / mse)
        return (
            mse.cpu(),
            psnr.cpu(),
            lpips_loss(inp, recon, normalize=True).mean().cpu(),
        )

    transform = Compose(
        [
            Resize(SHORT_AXIS_SIZE),
            CenterCrop(SHORT_AXIS_SIZE),
            ToTensor(),
        ]
    )
    valid_dataset = ImageNetDataset(
        "val", transform=transform, max_len=50000, streaming=True
    )
    valid_loader = data.DataLoader(
        valid_dataset,
        batch_size=4,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        pin_memory_device=DEVICE,
    )

    # Quick sanity check that the streaming pipeline yields batches.
    for batch in valid_loader:
        print("Batch shape:", batch[0].shape)
        break

    logger.info("Loading models...")
    vaes = []
    for base_model, sub_folder, ckpt_path, use_approx in zip(
        BASE_MODELS, SUB_FOLDERS, CKPT_PATHS, USE_APPROXS
    ):
        vae = AutoencoderKL.from_pretrained(base_model, subfolder=sub_folder)
        if use_approx:
            vae.decoder = LatentApproxDecoder(
                latent_dim=vae.config.latent_channels,
                out_channels=3,
                shuffle=2,
            )
            # Bind this iteration's vae via default arguments; a plain
            # `lambda x: vae.decoder(x)` would late-bind to the last vae in the loop.
            vae.decode = lambda x, _vae=vae: _vae.decoder(x)
            vae.get_last_layer = lambda _vae=vae: _vae.decoder.conv_out.weight
        if ckpt_path:
            LatentTrainer.load_from_checkpoint(
                ckpt_path, vae=vae, map_location="cpu", strict=False
            )
        vae = vae.to(DTYPE).eval().requires_grad_(False).to(DEVICE)
        # encoder and decoder are compiled individually; wrapping the whole
        # module in torch.compile again would not affect encode()/decode() calls.
        vae.encoder = torch.compile(vae.encoder)
        vae.decoder = torch.compile(vae.decoder)
        vaes.append(vae)

    logger.info("Running Validation")
    total = 0
    all_latents = [[] for _ in range(len(vaes))]
    all_mse = [[] for _ in range(len(vaes))]
    all_psnr = [[] for _ in range(len(vaes))]
    all_lpips = [[] for _ in range(len(vaes))]

    for idx, batch in enumerate(tqdm(valid_loader)):
        image = batch[0].to(DEVICE)
        test_inp = process(image).to(DTYPE)
        batch_size = test_inp.size(0)

        for i, vae in enumerate(vaes):
            # Use the distribution mode for a deterministic latent.
            latent = vae.encode(test_inp).latent_dist.mode()
            recon = deprocess(vae.decode(latent).sample.float())
            all_latents[i].append(latent.cpu().float())
            # metrics() already returns CPU tensors; weight by batch size so
            # the final averages stay correct if the last batch is smaller.
            mse, psnr, lpips_ = metrics(image, recon)
            all_mse[i].append(mse * batch_size)
            all_psnr[i].append(psnr * batch_size)
            all_lpips[i].append(lpips_ * batch_size)

        total += batch_size

    for i in range(len(vaes)):
        all_latents[i] = torch.cat(all_latents[i], dim=0)
        all_mse[i] = torch.stack(all_mse[i]).sum() / total
        all_psnr[i] = torch.stack(all_psnr[i]).sum() / total
        all_lpips[i] = torch.stack(all_lpips[i]).sum() / total
        logger.info(
            f" - {NAMES[i]}: MSE: {all_mse[i]:.3e}, PSNR: {all_psnr[i]:.4f}, "
            f"LPIPS: {all_lpips[i]:.4f}"
        )
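
    # The latents are collected above but never summarized. As an assumed extra
    # diagnostic (not in the original script), log each model's latent mean and
    # std, which helps when comparing differently scaled latent spaces.
    for i in range(len(vaes)):
        logger.info(
            f" - {NAMES[i]}: latent mean: {all_latents[i].mean():.4f}, "
            f"latent std: {all_latents[i].std():.4f}"
        )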

    logger.info("End")