import os

import click
import numpy as np
import PIL.Image
import torch
import torchvision.transforms as T
from diffusers import AutoencoderKL
from itertools import batched  # requires Python >= 3.12
from tqdm import tqdm


@click.command()
@click.option("--model_name", type=str, default="stabilityai/stable-diffusion-2-1")
@click.option("--swim_dir", type=str, default="datasets/swim_data")
@click.option("--batch_size", type=int, default=1)
def compute_latent(model_name: str, swim_dir: str, batch_size: int):
    """Encode every image in the SWIM dataset to a VAE latent and save it as .npy."""
    model = AutoencoderKL.from_pretrained(model_name, subfolder="vae").cuda()
    model.eval()

    # Create output folders for the latent vectors.
    os.makedirs(os.path.join(swim_dir, "train/latents"), exist_ok=True)
    os.makedirs(os.path.join(swim_dir, "val/latents"), exist_ok=True)

    # Resize/crop to 512x512 and scale pixels to [-1, 1], the range the VAE expects.
    transforms = T.Compose(
        [
            T.Resize(512),
            T.CenterCrop(512),
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )

    for split in ["train", "val"]:
        output = os.path.join(swim_dir, split, "latents")
        for image_names in tqdm(
            list(
                batched(os.listdir(os.path.join(swim_dir, split, "images")), batch_size)
            )
        ):
            # Force RGB so grayscale/RGBA inputs don't break the 3-channel normalize.
            images = [
                transforms(
                    PIL.Image.open(
                        os.path.join(swim_dir, split, "images", name)
                    ).convert("RGB")
                )
                for name in image_names
            ]

            with torch.no_grad():
                images = torch.stack(images).cuda()
                # Take the mode of the posterior rather than a random sample,
                # so each image maps to a deterministic latent.
                latents = model.encode(images).latent_dist.mode()
                latents = latents.detach().cpu().numpy()

            for name, latent in zip(image_names, latents):
                # Swap the image extension for .npy (handles .jpg, .jpeg, .png, ...).
                np.save(
                    os.path.join(output, os.path.splitext(name)[0] + ".npy"),
                    latent,
                )


if __name__ == "__main__":
    compute_latent()
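
# Usage (the script filename is hypothetical):
#
#   python compute_latents.py --swim_dir datasets/swim_data --batch_size 8
#
# A minimal sketch (not part of the original script) of how a saved latent can
# be loaded and decoded back to pixels as a sanity check; the .npy file name
# below is hypothetical. With 512x512 inputs, the SD 2.1 VAE produces latents
# of shape (4, 64, 64):
#
#   latent = np.load("datasets/swim_data/val/latents/example.npy")
#   latent = torch.from_numpy(latent).unsqueeze(0).cuda()  # (1, 4, 64, 64)
#   vae = AutoencoderKL.from_pretrained(
#       "stabilityai/stable-diffusion-2-1", subfolder="vae"
#   ).cuda()
#   with torch.no_grad():
#       image = vae.decode(latent).sample  # pixels in [-1, 1], shape (1, 3, 512, 512)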