|  | import os | 
					
						
						|  | import torch | 
					
						
						|  | import numpy as np | 
					
						
						|  | from tqdm import trange | 
					
						
						|  | from PIL import Image | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_state(gpu): | 
					
						
						|  | import torch | 
					
						
						|  | midas = torch.hub.load("intel-isl/MiDaS", "MiDaS") | 
					
						
						|  | if gpu: | 
					
						
						|  | midas.cuda() | 
					
						
						|  | midas.eval() | 
					
						
						|  |  | 
					
						
						|  | midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms") | 
					
						
						|  | transform = midas_transforms.default_transform | 
					
						
						|  |  | 
					
						
						|  | state = {"model": midas, | 
					
						
						|  | "transform": transform} | 
					
						
						|  | return state | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def depth_to_rgba(x): | 
					
						
						|  | assert x.dtype == np.float32 | 
					
						
						|  | assert len(x.shape) == 2 | 
					
						
						|  | y = x.copy() | 
					
						
						|  | y.dtype = np.uint8 | 
					
						
						|  | y = y.reshape(x.shape+(4,)) | 
					
						
						|  | return np.ascontiguousarray(y) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def rgba_to_depth(x): | 
					
						
						|  | assert x.dtype == np.uint8 | 
					
						
						|  | assert len(x.shape) == 3 and x.shape[2] == 4 | 
					
						
						|  | y = x.copy() | 
					
						
						|  | y.dtype = np.float32 | 
					
						
						|  | y = y.reshape(x.shape[:2]) | 
					
						
						|  | return np.ascontiguousarray(y) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def run(x, state): | 
					
						
						|  | model = state["model"] | 
					
						
						|  | transform = state["transform"] | 
					
						
						|  | hw = x.shape[:2] | 
					
						
						|  | with torch.no_grad(): | 
					
						
						|  | prediction = model(transform((x + 1.0) * 127.5).cuda()) | 
					
						
						|  | prediction = torch.nn.functional.interpolate( | 
					
						
						|  | prediction.unsqueeze(1), | 
					
						
						|  | size=hw, | 
					
						
						|  | mode="bicubic", | 
					
						
						|  | align_corners=False, | 
					
						
						|  | ).squeeze() | 
					
						
						|  | output = prediction.cpu().numpy() | 
					
						
						|  | return output | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_filename(relpath, level=-2): | 
					
						
						|  |  | 
					
						
						|  | fn = relpath.split(os.sep)[level:] | 
					
						
						|  | folder = fn[-2] | 
					
						
						|  | file   = fn[-1].split('.')[0] | 
					
						
						|  | return folder, file | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def save_depth(dataset, path, debug=False): | 
					
						
						|  | os.makedirs(path) | 
					
						
						|  | N = len(dset) | 
					
						
						|  | if debug: | 
					
						
						|  | N = 10 | 
					
						
						|  | state = get_state(gpu=True) | 
					
						
						|  | for idx in trange(N, desc="Data"): | 
					
						
						|  | ex = dataset[idx] | 
					
						
						|  | image, relpath = ex["image"], ex["relpath"] | 
					
						
						|  | folder, filename = get_filename(relpath) | 
					
						
						|  |  | 
					
						
						|  | folderabspath = os.path.join(path, folder) | 
					
						
						|  | os.makedirs(folderabspath, exist_ok=True) | 
					
						
						|  | savepath = os.path.join(folderabspath, filename) | 
					
						
						|  |  | 
					
						
						|  | xout = run(image, state) | 
					
						
						|  | I = depth_to_rgba(xout) | 
					
						
						|  | Image.fromarray(I).save("{}.png".format(savepath)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if __name__ == "__main__": | 
					
						
						|  | from taming.data.imagenet import ImageNetTrain, ImageNetValidation | 
					
						
						|  | out = "data/imagenet_depth" | 
					
						
						|  | if not os.path.exists(out): | 
					
						
						|  | print("Please create a folder or symlink '{}' to extract depth data ".format(out) + | 
					
						
						|  | "(be prepared that the output size will be larger than ImageNet itself).") | 
					
						
						|  | exit(1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | dset = ImageNetValidation() | 
					
						
						|  | abspath = os.path.join(out, "val") | 
					
						
						|  | if os.path.exists(abspath): | 
					
						
						|  | print("{} exists - not doing anything.".format(abspath)) | 
					
						
						|  | else: | 
					
						
						|  | print("preparing {}".format(abspath)) | 
					
						
						|  | save_depth(dset, abspath) | 
					
						
						|  | print("done with validation split") | 
					
						
						|  |  | 
					
						
						|  | dset = ImageNetTrain() | 
					
						
						|  | abspath = os.path.join(out, "train") | 
					
						
						|  | if os.path.exists(abspath): | 
					
						
						|  | print("{} exists - not doing anything.".format(abspath)) | 
					
						
						|  | else: | 
					
						
						|  | print("preparing {}".format(abspath)) | 
					
						
						|  | save_depth(dset, abspath) | 
					
						
						|  | print("done with train split") | 
					
						
						|  |  | 
					
						
						|  | print("done done.") | 
					
						
						|  |  |