File size: 3,189 Bytes
1ab03a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import argparse
import multiprocessing
from functools import partial
from io import BytesIO
from pathlib import Path

import lmdb
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import functional as trans_fn
from tqdm import tqdm
import os


def resize_and_convert(img, size, resample, quality=100):
    """Resize and center-crop *img* to *size*, then return it as JPEG bytes.

    Args:
        img: the image to transform (``resize_worker`` passes an RGB PIL
            image).
        size: target size forwarded to torchvision's ``resize`` and
            ``center_crop``.
        resample: resampling filter forwarded to torchvision's ``resize``.
        quality: JPEG quality used when encoding.

    Returns:
        The encoded JPEG image as ``bytes``.
    """
    cropped = trans_fn.center_crop(trans_fn.resize(img, size, resample), size)

    with BytesIO() as sink:
        cropped.save(sink, format="jpeg", quality=quality)
        return sink.getvalue()


def resize_multiple(img,
                    sizes=(128, 256, 512, 1024),
                    resample=Image.LANCZOS,
                    quality=100):
    """Encode *img* once per entry of *sizes*.

    Args:
        img: the source image.
        sizes: iterable of target sizes, processed in order.
        resample: resampling filter passed through to ``resize_and_convert``.
        quality: JPEG quality passed through to ``resize_and_convert``.

    Returns:
        A list of JPEG ``bytes``, one element per size, in the order of
        *sizes*.
    """
    return [
        resize_and_convert(img, target, resample, quality)
        for target in sizes
    ]


def resize_worker(img_file, sizes, resample):
    """Load one image file and encode it at every requested size.

    Intended to run inside a ``multiprocessing`` pool (see ``prepare``).

    Args:
        img_file: ``(i, (file, idx))`` -- the job position and the
            ``(path, integer index)`` pair built by ``prepare``.
        sizes: target sizes passed to ``resize_multiple``.
        resample: resampling filter passed to ``resize_multiple``.

    Returns:
        ``(i, idx, out)`` where *out* is the list of JPEG ``bytes`` returned by
        ``resize_multiple``, one per size.
    """
    i, (file, idx) = img_file
    # Use a context manager so the file handle is released deterministically
    # instead of whenever Pillow/GC decides; ``convert`` returns a new image
    # that does not depend on the (now closed) source file.
    with Image.open(file) as src:
        img = src.convert("RGB")
    out = resize_multiple(img, sizes=sizes, resample=resample)

    return i, idx, out


def prepare(env,
            paths,
            n_worker,
            sizes=(128, 256, 512, 1024),
            resample=Image.LANCZOS):
    """Resize every image in *paths* and write the JPEG bytes into *env*.

    The integer value of each file's stem (``"00042.jpg"`` -> ``42``) is the
    record index. Each image is stored once per size under the key
    ``"{size}-{index zero-padded to 5 digits}"``. Finally, ``b"length"`` is
    set to the number of images written.

    Args:
        env: an open, writable LMDB environment.
        paths: image file paths; every stem must parse as an integer.
        n_worker: number of worker processes for the resize pool.
        sizes: target sizes passed to ``resize_worker``.
        resample: resampling filter passed to ``resize_worker``.

    Raises:
        ValueError: if a file stem is not an integer, or if two files map to
            the same index (their keys would silently overwrite each other
            while ``length`` still counted both).
    """
    resize_fn = partial(resize_worker, sizes=sizes, resample=resample)

    # index = integer file stem; collect (path, index) jobs.
    jobs = []
    seen = set()
    for path in paths:
        idx = int(Path(path).stem)
        if idx in seen:
            raise ValueError(f"duplicate image index {idx} for {path}")
        seen.add(idx)
        jobs.append((path, idx))

    # Sort by file index (stable, so equal keys keep input order).
    jobs.sort(key=lambda job: job[1])
    files = list(enumerate(jobs))
    total = 0

    with multiprocessing.Pool(n_worker) as pool:
        results = pool.imap_unordered(resize_fn, files)
        for _, idx, imgs in tqdm(results, total=len(files)):
            # One write transaction per image (covering all sizes) rather
            # than one per (image, size); by default every LMDB commit syncs
            # to disk, so fewer commits is much cheaper.
            with env.begin(write=True) as txn:
                for size, img in zip(sizes, imgs):
                    key = f"{size}-{str(idx).zfill(5)}".encode("utf-8")
                    txn.put(key, img)

            total += 1

    with env.begin(write=True) as txn:
        txn.put("length".encode("utf-8"), str(total).encode("utf-8"))


class ImageFolder(Dataset):
    """Dataset over every image file under *folder*, searched recursively.

    Args:
        folder: root directory to search.
        exts: file extensions (without the leading dot) to include.
    """

    def __init__(self, folder, exts=('jpg',)):
        super().__init__()
        # Previously never assigned, which made __getitem__ raise
        # AttributeError.
        self.folder = folder
        root = Path(f'{folder}')
        self.paths = [p for ext in exts for p in root.glob(f'**/*.{ext}')]

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # glob() already yields paths that include ``folder``, so open them
        # as-is; joining them onto ``folder`` again would duplicate the prefix
        # when ``folder`` is relative.
        return Image.open(self.paths[index])


if __name__ == "__main__":
    # Convert a folder of celebahq images to an lmdb database.
    # Every option defaults to the value that used to be hard-coded here,
    # so running the script without arguments behaves exactly as before.
    resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}

    parser = argparse.ArgumentParser(
        description="Convert a folder of celebahq images to lmdb")
    parser.add_argument("--in_path", type=str, default="datasets/celebahq",
                        help="folder searched recursively for input images")
    parser.add_argument("--out_path", type=str,
                        default="datasets/celebahq256.lmdb",
                        help="path of the lmdb database to write")
    parser.add_argument("--size", type=str, default="256",
                        help="comma-separated output sizes, e.g. 128,256")
    parser.add_argument("--n_worker", type=int, default=16,
                        help="number of worker processes")
    parser.add_argument("--resample", type=str, default="lanczos",
                        choices=sorted(resample_map),
                        help="resampling filter used when resizing")
    parser.add_argument("--exts", type=str, default="jpg",
                        help="comma-separated image extensions to include")
    args = parser.parse_args()

    resample = resample_map[args.resample]
    sizes = [int(s) for s in args.size.split(",")]
    exts = [e.strip() for e in args.exts.split(",")]

    print("Make dataset of image sizes:", ", ".join(str(s) for s in sizes))

    in_dir = Path(args.in_path)
    paths = [p for ext in exts for p in in_dir.glob(f"**/*.{ext}")]

    with lmdb.open(args.out_path, map_size=1024**4, readahead=False) as env:
        prepare(env, paths, args.n_worker, sizes=sizes, resample=resample)