import argparse
import multiprocessing
import os
import shutil
from functools import partial
from io import BytesIO
from multiprocessing import Process, Queue
from os.path import exists, join

import lmdb
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import LSUNClass
from torchvision.transforms import functional as trans_fn
from tqdm import tqdm


def resize_and_convert(img, size, resample, quality=100):
    """Resize and center-crop ``img``, then return it encoded as WebP bytes.

    Args:
        img: Image accepted by ``torchvision.transforms.functional.resize``
            (a PIL image in this script).
        size: Target size passed to both ``resize`` and ``center_crop``.
        resample: Interpolation filter forwarded to ``resize``.
        quality: WebP quality setting forwarded to ``img.save``.

    Returns:
        The encoded image as a ``bytes`` object.
    """
    img = trans_fn.resize(img, size, resample)
    img = trans_fn.center_crop(img, size)
    buffer = BytesIO()
    img.save(buffer, format="webp", quality=quality)
    return buffer.getvalue()


def resize_multiple(img,
                    sizes=(128, 256, 512, 1024),
                    resample=Image.LANCZOS,
                    quality=100):
    """Encode ``img`` once per entry of ``sizes``.

    Returns:
        A list of WebP byte strings, in the same order as ``sizes``.
    """
    return [resize_and_convert(img, size, resample, quality) for size in sizes]


def resize_worker(idx, img, sizes, resample):
    """Convert ``img`` to RGB and encode it at every size in ``sizes``.

    Returns:
        A ``(idx, encoded_images)`` tuple so that callers can match results
        back to their inputs.
    """
    img = img.convert("RGB")
    out = resize_multiple(img, sizes=sizes, resample=resample)
    return idx, out


class ConvertDataset(Dataset):
    """Dataset wrapper that yields each source image as resized WebP bytes.

    Args:
        data: Indexable dataset whose items are ``(image, target)`` pairs;
            only the image is used.
        size: Output resolution passed to ``resize_and_convert``.
        quality: WebP quality passed to ``resize_and_convert``.
    """

    def __init__(self, data, size=256, quality=90) -> None:
        self.data = data
        self.size = size
        self.quality = quality

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img, _ = self.data[index]
        # Named ``encoded`` rather than ``bytes`` to avoid shadowing the builtin.
        encoded = resize_and_convert(img,
                                     self.size,
                                     Image.LANCZOS,
                                     quality=self.quality)
        return encoded


def _identity_collate(batch):
    """Return the batch untouched (a plain list of encoded images).

    Defined at module level instead of as a lambda so that it can be pickled
    when DataLoader workers are started with the ``spawn`` method.
    """
    return batch


def main(src_path='datasets/horse_train_lmdb',
         out_path='datasets/horse256.lmdb',
         size=256,
         batch_size=50,
         num_workers=16):
    """Convert LSUN's original LMDB into an LMDB of resized WebP images.

    Every image is stored under the key ``"{size}-{index:07d}"`` and the total
    number of images is stored under the key ``"length"``.

    Args:
        src_path: Path to the original LSUN LMDB directory.
        out_path: Path of the LMDB to create. An existing directory at this
            path is deleted first.
        size: Output resolution, also used as the key prefix.
        batch_size: Number of images written per LMDB transaction.
        num_workers: Number of DataLoader worker processes doing the resizing.
    """
    dataset = LSUNClass(root=os.path.expanduser(src_path))
    dataset = ConvertDataset(dataset, size=size)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        num_workers=num_workers,
                        collate_fn=_identity_collate)

    target = os.path.expanduser(out_path)
    if os.path.exists(target):
        shutil.rmtree(target)

    with lmdb.open(target, map_size=1024**4, readahead=False) as env:
        with tqdm(total=len(dataset)) as progress:
            i = 0
            for batch in loader:
                # One write transaction per batch keeps commits coarse.
                with env.begin(write=True) as txn:
                    for img in batch:
                        key = f"{size}-{str(i).zfill(7)}".encode("utf-8")
                        txn.put(key, img)
                        i += 1
                        progress.update()

        # Record the image count so that readers know how many keys exist.
        with env.begin(write=True) as txn:
            txn.put("length".encode("utf-8"), str(i).encode("utf-8"))


if __name__ == "__main__":
    main()