# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numbers

import cv2
import numpy as np
import PIL
import torch
from torchvision.transforms import functional as tvf


def _is_tensor_clip(clip):
    # A tensor clip is a single 4-D tensor (e.g. (C, T, H, W)).
    return torch.is_tensor(clip) and clip.ndimension() == 4


def crop_clip(clip, min_h, min_w, h, w):
    """Crop every frame of a clip to the (h, w) window whose top-left corner
    is at (min_h, min_w)."""
    if isinstance(clip[0], np.ndarray) or isinstance(clip[0], torch.Tensor):
        if clip[0].shape[-1] == 3:
            # Channels-last frames: (H, W, C).
            cropped = [img[min_h : min_h + h, min_w : min_w + w, :] for img in clip]
        else:
            # Channels-first frames: (C, H, W).
            assert clip[0].shape[0] == 3
            cropped = [img[:, min_h : min_h + h, min_w : min_w + w] for img in clip]
    elif isinstance(clip[0], PIL.Image.Image):
        cropped = [img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip]
    else:
        raise TypeError(
            "Expected numpy.ndarray, PIL.Image or torch.Tensor "
            + "but got list of {0}".format(type(clip[0]))
        )
    return cropped


def resize_clip(clip, size, interpolation="bilinear"):
    """Resize every frame of a clip. If `size` is a single number, the shorter
    spatial side is resized to `size` while preserving the aspect ratio."""
    if isinstance(clip[0], np.ndarray) or isinstance(clip[0], torch.Tensor):
        if isinstance(size, numbers.Number):
            if clip[0].shape[-1] == 3:
                im_h, im_w, im_c = clip[0].shape
            else:
                assert clip[0].shape[0] == 3
                im_c, im_h, im_w = clip[0].shape
            # Min spatial dim already matches minimal size
            if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size):
                return clip
            new_h, new_w = get_resize_sizes(im_h, im_w, size)
            size = (new_w, new_h)
        else:
            size = size[0], size[1]
        if isinstance(clip[0], np.ndarray):
            if interpolation == "bilinear":
                np_inter = cv2.INTER_LINEAR
            else:
                np_inter = cv2.INTER_NEAREST
            scaled = [cv2.resize(img, size, interpolation=np_inter) for img in clip]
        else:  # isinstance(clip[0], torch.Tensor)
            if interpolation == "bilinear":
                np_inter = tvf.InterpolationMode.BILINEAR
            else:
                np_inter = tvf.InterpolationMode.NEAREST
            # torchvision transforms expect the size in (h, w) order.
            size = (size[1], size[0])
            scaled = [tvf.resize(img, size, interpolation=np_inter) for img in clip]
    elif isinstance(clip[0], PIL.Image.Image):
        if isinstance(size, numbers.Number):
            im_w, im_h = clip[0].size
            # Min spatial dim already matches minimal size
            if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size):
                return clip
            new_h, new_w = get_resize_sizes(im_h, im_w, size)
            size = (new_w, new_h)
        else:
            size = size[1], size[0]
        if interpolation == "bilinear":
            pil_inter = PIL.Image.BILINEAR
        else:
            pil_inter = PIL.Image.NEAREST
        scaled = [img.resize(size, pil_inter) for img in clip]
    else:
        raise TypeError(
            "Expected numpy.ndarray, PIL.Image or torch.Tensor "
            + "but got list of {0}".format(type(clip[0]))
        )
    return scaled


def get_resize_sizes(im_h, im_w, size):
    # Scale the shorter side to `size`, keeping the aspect ratio.
    if im_w < im_h:
        ow = size
        oh = int(size * im_h / im_w)
    else:
        oh = size
        ow = int(size * im_w / im_h)
    return oh, ow


def normalize(clip, mean, std, inplace=False):
    """Normalize a (C, T, H, W) tensor clip channel-wise with `mean` and `std`."""
    if not _is_tensor_clip(clip):
        raise TypeError("tensor is not a torch clip.")
    if not inplace:
        clip = clip.clone()
    dtype = clip.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=clip.device)
    std = torch.as_tensor(std, dtype=dtype, device=clip.device)
    clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
    return clip
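

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library API): it assumes a clip
# given as a list of HWC uint8 numpy frames, and the sizes and mean/std values
# below are arbitrary examples. `normalize` expects a single 4-D (C, T, H, W)
# tensor, so the frames are stacked before normalization.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Dummy clip: 8 random frames of shape (H, W, C) = (240, 320, 3).
    clip = [np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8) for _ in range(8)]

    # Resize the shorter side to 256, then crop a 224x224 window at the top-left.
    clip = resize_clip(clip, 256, interpolation="bilinear")
    clip = crop_clip(clip, min_h=0, min_w=0, h=224, w=224)

    # Stack into a float (C, T, H, W) tensor and normalize per channel
    # (the mean/std here are simply the common ImageNet statistics).
    tensor_clip = torch.stack(
        [torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0 for frame in clip],
        dim=1,
    )
    tensor_clip = normalize(
        tensor_clip, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    print(tensor_clip.shape)  # torch.Size([3, 8, 224, 224])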