diffae / lmdb_writer.py
qninhdt's picture
Upload 68 files
1ab03a3 verified
raw
history blame
3.61 kB
from io import BytesIO
import lmdb
from PIL import Image
import torch
from contextlib import contextmanager
from torch.utils.data import Dataset
from multiprocessing import Process, Queue
import os
import shutil
def convert(x, format, quality=100):
# to prevent locking!
torch.set_num_threads(1)
buffer = BytesIO()
x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0)
x = x.to(torch.uint8)
x = x.numpy()
img = Image.fromarray(x)
img.save(buffer, format=format, quality=quality)
val = buffer.getvalue()
return val
@contextmanager
def nullcontext():
yield
class _WriterWroker(Process):
def __init__(self, path, format, quality, zfill, q):
super().__init__()
if os.path.exists(path):
shutil.rmtree(path)
self.path = path
self.format = format
self.quality = quality
self.zfill = zfill
self.q = q
self.i = 0
def run(self):
if not os.path.exists(self.path):
os.makedirs(self.path)
with lmdb.open(self.path, map_size=1024**4, readahead=False) as env:
while True:
job = self.q.get()
if job is None:
break
with env.begin(write=True) as txn:
for x in job:
key = f"{str(self.i).zfill(self.zfill)}".encode(
"utf-8")
x = convert(x, self.format, self.quality)
txn.put(key, x)
self.i += 1
with env.begin(write=True) as txn:
txn.put("length".encode("utf-8"), str(self.i).encode("utf-8"))
class LMDBImageWriter:
def __init__(self, path, format='webp', quality=100, zfill=7) -> None:
self.path = path
self.format = format
self.quality = quality
self.zfill = zfill
self.queue = None
self.worker = None
def __enter__(self):
self.queue = Queue(maxsize=3)
self.worker = _WriterWroker(self.path, self.format, self.quality,
self.zfill, self.queue)
self.worker.start()
def put_images(self, tensor):
"""
Args:
tensor: (n, c, h, w) [0-1] tensor
"""
self.queue.put(tensor.cpu())
# with self.env.begin(write=True) as txn:
# for x in tensor:
# key = f"{str(self.i).zfill(self.zfill)}".encode("utf-8")
# x = convert(x, self.format, self.quality)
# txn.put(key, x)
# self.i += 1
def __exit__(self, *args, **kwargs):
self.queue.put(None)
self.queue.close()
self.worker.join()
class LMDBImageReader(Dataset):
def __init__(self, path, zfill: int = 7):
self.zfill = zfill
self.env = lmdb.open(
path,
max_readers=32,
readonly=True,
lock=False,
readahead=False,
meminit=False,
)
if not self.env:
raise IOError('Cannot open lmdb dataset', path)
with self.env.begin(write=False) as txn:
self.length = int(
txn.get('length'.encode('utf-8')).decode('utf-8'))
def __len__(self):
return self.length
def __getitem__(self, index):
with self.env.begin(write=False) as txn:
key = f'{str(index).zfill(self.zfill)}'.encode('utf-8')
img_bytes = txn.get(key)
buffer = BytesIO(img_bytes)
img = Image.open(buffer)
return img