swim_new / scripts /compute_latent.py
qninhdt's picture
cc
82e5f44
import os
import click
import PIL
from itertools import batched
import numpy as np
import torch
import torchvision.transforms as T
from diffusers import AutoencoderKL
from tqdm import tqdm
@click.command()
@click.option("--model_name", type=str, default="stabilityai/stable-diffusion-2-1")
@click.option("--swim_dir", type=str, default="datasets/swim_data")
@click.option("--batch_size", type=int, default=1)
def compute_latent(model_name: str, swim_dir: str, batch_size: int):
    """Encode the train/val images with a pretrained VAE and save the latents.

    For each split, every file in ``<swim_dir>/<split>/images`` is resized and
    center-cropped to 512x512, scaled to [-1, 1], encoded by the VAE, and the
    mode of the latent distribution is written to
    ``<swim_dir>/<split>/latents/<image name>.npy``.

    Args:
        model_name: Model id or path passed to ``AutoencoderKL.from_pretrained``
            (the ``vae`` subfolder is loaded).
        swim_dir: Dataset root containing ``train/images`` and ``val/images``.
        batch_size: Number of images encoded per forward pass.
    """
    # Import the submodule explicitly: a bare ``import PIL`` does not load
    # ``PIL.Image``; it would only be available if another package happened
    # to import it first.
    from PIL import Image

    model = AutoencoderKL.from_pretrained(model_name, subfolder="vae").cuda()
    model.eval()

    # create folder for latent vectors
    os.makedirs(os.path.join(swim_dir, "train/latents"), exist_ok=True)
    os.makedirs(os.path.join(swim_dir, "val/latents"), exist_ok=True)

    transforms = T.Compose(
        [
            T.Resize(512),
            T.CenterCrop(512),
            T.ToTensor(),
            # Map [0, 1] pixel values to [-1, 1].
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )

    for split in ["train", "val"]:
        image_dir = os.path.join(swim_dir, split, "images")
        output = os.path.join(swim_dir, split, "latents")

        # Materialize the batches so tqdm knows the total and can show progress.
        batches = list(batched(os.listdir(image_dir), batch_size))

        for image_names in tqdm(batches):
            images = []
            for name in image_names:
                # The context manager closes the file handle once the pixels
                # are read. convert("RGB") guarantees three channels, so
                # grayscale, palette and RGBA files do not break the
                # 3-channel normalization.
                with Image.open(os.path.join(image_dir, name)) as image:
                    images.append(transforms(image.convert("RGB")))

            with torch.no_grad():
                pixels = torch.stack(images).cuda()
                # Use the distribution mode (no sampling) for deterministic latents.
                latents = model.encode(pixels).latent_dist.mode()

            latents = latents.cpu().numpy()

            for name, latent in zip(image_names, latents):
                # Strip only a trailing .jpg/.png extension. Any other name is
                # kept whole and gets ".npy" appended, which matches what
                # np.save did implicitly before.
                stem, ext = os.path.splitext(name)
                base = stem if ext in (".jpg", ".png") else name
                np.save(os.path.join(output, base + ".npy"), latent)
if __name__ == "__main__":
    # Script entry point: click parses the command-line options and
    # forwards them to the command.
    compute_latent()