File size: 3,613 Bytes
1ab03a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
from io import BytesIO
import lmdb
from PIL import Image
import torch
from contextlib import contextmanager
from torch.utils.data import Dataset
from multiprocessing import Process, Queue
import os
import shutil
def convert(x, format, quality=100):
    """Encode one image tensor as compressed image bytes with Pillow.

    Args:
        x: (c, h, w) tensor with values expected in [0, 1]. A single-channel
            tensor is encoded as a grayscale ('L') image. Tensors that require
            grad or live on another device are handled via detach()/cpu().
        format: Pillow format name passed to ``Image.save`` (e.g. 'webp').
        quality: encoder quality passed to ``Image.save``; formats that do not
            use it ignore it.

    Returns:
        bytes: the encoded image.
    """
    # to prevent locking! (keep torch single-threaded inside the worker process)
    torch.set_num_threads(1)
    # Scale to [0, 255]; adding 0.5 before the truncating uint8 cast rounds to nearest.
    x = x.detach().mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0)
    arr = x.to(torch.uint8).cpu().numpy()
    if arr.shape[2] == 1:
        # Pillow cannot build an image from an (h, w, 1) array; drop the channel
        # axis so single-channel tensors become mode 'L' images.
        arr = arr[:, :, 0]
    buffer = BytesIO()
    Image.fromarray(arr).save(buffer, format=format, quality=quality)
    return buffer.getvalue()
@contextmanager
def nullcontext(enter_result=None):
    """Context manager that does nothing, for use where a ``with`` target is optional.

    Args:
        enter_result: value bound by ``with nullcontext(...) as value``.
            Defaults to None, matching the previous zero-argument behaviour
            and the signature of :func:`contextlib.nullcontext`.
    """
    yield enter_result
class _WriterWroker(Process):
    """Background process that drains a queue of image batches into LMDB.

    Every batch taken from ``q`` is iterated image by image. Each image is
    encoded with ``convert`` and stored under a sequential, zero-padded
    decimal key, and a whole batch is written in a single write transaction.
    A ``None`` item ends the loop. The total number of images written is then
    stored under the ``length`` key.
    """

    def __init__(self, path, format, quality, zfill, q):
        super().__init__()
        # Start from an empty database. Note this runs in the process that
        # constructs the worker, before start() is called.
        if os.path.exists(path):
            shutil.rmtree(path)
        self.path = path
        self.format = format
        self.quality = quality
        self.zfill = zfill
        self.q = q
        self.i = 0  # images written so far; also the index used for the next key

    def _write_batch(self, env, batch):
        """Encode every image in ``batch`` and store them in one write transaction."""
        with env.begin(write=True) as txn:
            for image in batch:
                key = str(self.i).zfill(self.zfill).encode("utf-8")
                txn.put(key, convert(image, self.format, self.quality))
                self.i += 1

    def run(self):
        """Consume batches until the ``None`` sentinel, then record the count."""
        if not os.path.exists(self.path):
            os.makedirs(self.path)
        with lmdb.open(self.path, map_size=1 << 40, readahead=False) as env:  # 1 TiB
            while True:
                batch = self.q.get()
                if batch is None:
                    break
                self._write_batch(env, batch)
            with env.begin(write=True) as txn:
                txn.put("length".encode("utf-8"), str(self.i).encode("utf-8"))
class LMDBImageWriter:
    """Write batches of image tensors into an LMDB database from a background process.

    Use as a context manager::

        with LMDBImageWriter(path) as writer:
            writer.put_images(batch)  # (n, c, h, w) tensor in [0, 1]

    Entering starts a ``_WriterWroker`` process. ``put_images`` hands batches
    to it through a bounded queue. Exiting sends the end-of-input sentinel,
    waits for the worker, and raises if the worker failed.
    """

    def __init__(self, path, format='webp', quality=100, zfill=7) -> None:
        """
        Args:
            path: directory of the LMDB environment to create.
            format: Pillow image format used to encode each image.
            quality: encoder quality passed to Pillow.
            zfill: number of digits the integer keys are zero-padded to.
        """
        self.path = path
        self.format = format
        self.quality = quality
        self.zfill = zfill
        # Created in __enter__; None means "not inside a with block".
        self.queue = None
        self.worker = None

    def __enter__(self):
        # Small bound: the producer blocks instead of buffering many batches in memory.
        self.queue = Queue(maxsize=3)
        self.worker = _WriterWroker(self.path, self.format, self.quality,
                                    self.zfill, self.queue)
        self.worker.start()
        return self

    def _put(self, item, poll_interval=1.0):
        """Enqueue ``item`` for the worker, failing instead of hanging if it died.

        Raises:
            RuntimeError: the worker process exited while the queue was full.
        """
        from queue import Full

        while True:
            try:
                self.queue.put(item, timeout=poll_interval)
                return
            except Full:
                # A dead worker never drains the queue; don't block forever.
                if not self.worker.is_alive():
                    raise RuntimeError(
                        "LMDB writer process exited with code "
                        f"{self.worker.exitcode}") from None

    def put_images(self, tensor):
        """Queue a batch of images to be encoded and written.

        Args:
            tensor: (n, c, h, w) [0-1] tensor

        Raises:
            RuntimeError: called outside the ``with`` block, or the worker died.
        """
        if self.queue is None:
            raise RuntimeError(
                "LMDBImageWriter.put_images must be called inside a 'with' block")
        # detach() so tensors that require grad can be pickled to the worker process.
        self._put(tensor.detach().cpu())

    def __exit__(self, *args, **kwargs):
        """Send the end-of-input sentinel, wait for the worker, surface its failure.

        Returns False, so an exception raised inside the ``with`` block is never
        suppressed. A worker failure is raised only when no such exception is
        already propagating.
        """
        if self.worker is None:
            return False
        exc_type = args[0] if args else None
        try:
            self._put(None)  # sentinel: the worker records 'length' and stops
        except RuntimeError:
            pass  # worker already dead; reported through its exit code below
        self.queue.close()
        self.worker.join()
        exitcode = self.worker.exitcode
        if exitcode != 0:
            # Data the dead worker never read must not block interpreter exit.
            self.queue.cancel_join_thread()
        self.queue = None
        self.worker = None
        if exitcode != 0 and exc_type is None:
            raise RuntimeError(f"LMDB writer process exited with code {exitcode}")
        return False
class LMDBImageReader(Dataset):
    """Map-style dataset over an LMDB database of encoded images.

    Images are stored under zero-padded decimal keys (``zfill`` digits), and
    the special key ``length`` holds the number of images.
    """

    def __init__(self, path, zfill: int = 7):
        """
        Args:
            path: directory of the LMDB environment.
            zfill: number of digits the integer keys were zero-padded to.

        Raises:
            IOError: the database has no ``length`` entry.
        """
        self.zfill = zfill
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)
        with self.env.begin(write=False) as txn:
            raw_length = txn.get('length'.encode('utf-8'))
        if raw_length is None:
            # 'length' is written only after all images, so its absence usually
            # means writing did not finish.
            raise IOError(f"LMDB dataset at {path!r} has no 'length' entry")
        self.length = int(raw_length.decode('utf-8'))

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return image ``index`` as a PIL image opened from its stored bytes.

        Negative indices count from the end.

        Raises:
            IndexError: ``index`` is outside ``[-len(self), len(self))``. This
                also lets plain ``for`` iteration over the dataset stop.
        """
        if index < 0:
            index += self.length
        if not 0 <= index < self.length:
            raise IndexError(
                f"index out of range for dataset of length {self.length}")
        key = str(index).zfill(self.zfill).encode('utf-8')
        with self.env.begin(write=False) as txn:
            img_bytes = txn.get(key)
        return Image.open(BytesIO(img_bytes))
|