from contextlib import contextmanager
from io import BytesIO
from multiprocessing import Process, Queue
import os
import shutil

import lmdb
import torch
from PIL import Image
from torch.utils.data import Dataset


def convert(x, format, quality=100):
    """Encode a single (c, h, w) float tensor in [0, 1] to compressed image bytes."""
    # Limit torch to one thread so encoder worker processes do not oversubscribe CPUs.
    torch.set_num_threads(1)

    buffer = BytesIO()
    # Scale [0, 1] -> [0, 255] with rounding, then reorder to (h, w, c) for PIL.
    x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0)
    x = x.to(torch.uint8)
    img = Image.fromarray(x.numpy())
    img.save(buffer, format=format, quality=quality)
    return buffer.getvalue()
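

# Illustrative usage sketch for `convert` (the tensor shape and quality value
# below are assumptions for the example, not requirements of this module):
#
#     x = torch.rand(3, 64, 64)                    # (c, h, w) float tensor in [0, 1]
#     webp_bytes = convert(x, "webp", quality=90)  # compressed bytes ready for LMDB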


@contextmanager
def nullcontext():
    # No-op context manager; a stand-in for contextlib.nullcontext (Python 3.7+).
    yield


class _WriterWorker(Process):
    def __init__(self, path, format, quality, zfill, q):
        super().__init__()
        # Start from a clean database directory; any existing one is removed.
        if os.path.exists(path):
            shutil.rmtree(path)

        self.path = path
        self.format = format
        self.quality = quality
        self.zfill = zfill
        self.q = q
        self.i = 0

    def run(self):
        if not os.path.exists(self.path):
            os.makedirs(self.path)

        with lmdb.open(self.path, map_size=1024**4, readahead=False) as env:
            while True:
                job = self.q.get()
                # A None job is the sentinel that signals the end of the stream.
                if job is None:
                    break
                with env.begin(write=True) as txn:
                    for x in job:
                        # Keys are zero-padded decimal indices, e.g. b"0000042".
                        key = str(self.i).zfill(self.zfill).encode("utf-8")
                        txn.put(key, convert(x, self.format, self.quality))
                        self.i += 1

            # Record the total number of images under a dedicated "length" key.
            with env.begin(write=True) as txn:
                txn.put("length".encode("utf-8"), str(self.i).encode("utf-8"))


class LMDBImageWriter:
    def __init__(self, path, format='webp', quality=100, zfill=7) -> None:
        self.path = path
        self.format = format
        self.quality = quality
        self.zfill = zfill
        self.queue = None
        self.worker = None

    def __enter__(self):
        # A small queue bounds memory use while the writer process catches up.
        self.queue = Queue(maxsize=3)
        self.worker = _WriterWorker(self.path, self.format, self.quality,
                                    self.zfill, self.queue)
        self.worker.start()
        return self

    def put_images(self, tensor):
        """
        Args:
            tensor: (n, c, h, w) tensor with values in [0, 1]
        """
        self.queue.put(tensor.cpu())

    def __exit__(self, *args, **kwargs):
        # Send the sentinel, then wait for the worker to flush and finish.
        self.queue.put(None)
        self.queue.close()
        self.worker.join()
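

# Illustrative usage sketch for LMDBImageWriter (the path, loader, and batch
# shapes are assumptions for the example, not part of this module):
#
#     with LMDBImageWriter("data/images.lmdb", format="webp", quality=90) as writer:
#         for batch in loader:          # batch: (n, c, h, w) float tensor in [0, 1]
#             writer.put_images(batch)
#
# Encoding and LMDB writes happen in a background process, so the producing
# loop only blocks when the small queue backlog is full.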


class LMDBImageReader(Dataset):
    def __init__(self, path, zfill: int = 7):
        self.zfill = zfill
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        # The writer stores the dataset size under the "length" key.
        with self.env.begin(write=False) as txn:
            self.length = int(
                txn.get('length'.encode('utf-8')).decode('utf-8'))

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        with self.env.begin(write=False) as txn:
            key = str(index).zfill(self.zfill).encode('utf-8')
            img_bytes = txn.get(key)

        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        return img
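

# Illustrative usage sketch for LMDBImageReader (the path and the torchvision
# transform are assumptions for the example, not part of this module):
#
#     dataset = LMDBImageReader("data/images.lmdb")
#     print(len(dataset))           # number of stored images
#     img = dataset[0]              # decoded as a PIL.Image
#     tensor = torchvision.transforms.functional.to_tensor(img)
#
# __getitem__ returns PIL images, so pair the dataset with a transform (or a
# custom collate_fn) before handing it to a torch DataLoader.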