|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
|
|
|
|
def convert_img(img):
    """Convert an image array from (H, W, C) to channel-first (C, H, W).

    Args:
        img (numpy.ndarray): a 3-D (H, W, C) array, or a 2-D (H, W)
            single-channel array.

    Returns:
        numpy.ndarray: a (C, H, W) transposed view of a 3-D input, or a
        (1, H, W) view of a 2-D input. Arrays of any other rank are
        returned unchanged.
    """
    if img.ndim == 3:
        # (H, W, C) -> (C, H, W)
        return img.transpose(2, 0, 1)
    if img.ndim == 2:
        # Grayscale frame: add a leading channel axis.
        return np.expand_dims(img, 0)
    return img
|
|
|
|
|
class ClipToTensor(object):
    """Convert a clip (list of m frames) to a (C x m x H x W) tensor or array.

    Accepted frame types (the first frame decides which path is taken):

    * ``numpy.ndarray`` of shape (H x W x C), typically in [0, 255];
    * ``PIL.Image.Image``;
    * ``torch.Tensor`` frames, which are stacked directly (each frame must
      be 3-D; presumably already (C x H x W), e.g. ToTensor-ed).

    By default the result is a ``torch.FloatTensor`` of shape
    (C x m x H x W) divided by 255, so [0, 255] input maps to [0, 1.0].

    Args:
        channel_nb (int): expected channel count C. It is checked against
            numpy frames and sets the channel size of the output buffer.
        div_255 (bool): if True, divide the output by 255.
        numpy (bool): if True, return a float64 ``numpy.ndarray`` instead of
            a tensor. This is ignored when the frames are ``torch.Tensor``.
    """

    def __init__(self, channel_nb=3, div_255=True, numpy=False):
        self.channel_nb = channel_nb
        self.div_255 = div_255
        self.numpy = numpy

    def __call__(self, clip):
        """
        Args:
            clip (list of numpy.ndarray | PIL.Image.Image | torch.Tensor):
                non-empty list of frames to convert. The type and size of
                ``clip[0]`` determine how the whole clip is handled.

        Returns:
            torch.FloatTensor (or float64 numpy.ndarray when ``numpy=True``
            and the frames are not tensors) of shape (C x m x H x W).

        Raises:
            AssertionError: a numpy first frame has a channel count other
                than ``channel_nb``.
            TypeError: a frame is of an unsupported type.
        """
        first = clip[0]
        if isinstance(first, np.ndarray):
            h, w, ch = first.shape
            assert ch == self.channel_nb, "Got {0} instead of {1} channels".format(
                ch, self.channel_nb
            )
        elif isinstance(first, Image.Image):
            w, h = first.size
        elif isinstance(first, torch.Tensor):
            # Stack to (m, ...) and swap the first two axes; for (C, H, W)
            # frames this yields (C, m, H, W).
            tensor_clip = torch.stack(clip).permute(1, 0, 2, 3)
            return self._finish_tensor(tensor_clip)
        else:
            raise TypeError(
                "Expected numpy.ndarray or PIL.Image or torch.Tensor "
                "but got list of {0}".format(type(first))
            )

        np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
        for img_idx, img in enumerate(clip):
            if not isinstance(img, np.ndarray):
                if not isinstance(img, Image.Image):
                    # Report the offending frame's type, not clip[0]'s.
                    raise TypeError(
                        "Expected numpy.ndarray or PIL.Image "
                        "but got list of {0}".format(type(img))
                    )
                # np.asarray keeps the NumPy 1.x "copy only if needed"
                # meaning of copy=False; under NumPy 2, copy=False means
                # "never copy" and may raise ValueError.
                img = np.asarray(img)
            np_clip[:, img_idx, :, :] = convert_img(img)

        if self.numpy:
            if self.div_255:
                np_clip = np_clip / 255.0
            return np_clip

        return self._finish_tensor(torch.from_numpy(np_clip))

    def _finish_tensor(self, tensor_clip):
        """Cast ``tensor_clip`` to float32 and optionally divide it by 255."""
        # .float() returns the tensor itself when it is already float32.
        tensor_clip = tensor_clip.float()
        if self.div_255:
            tensor_clip = torch.div(tensor_clip, 255)
        return tensor_clip
|
|
|
|
|
|
|
class ClipToTensor_K(object):
    """Convert a clip (list of m frames) to a (C x m x H x W) tensor or array
    whose values are optionally rescaled to [-1.0, 1.0].

    Accepted frame types (the first frame decides which path is taken):

    * ``numpy.ndarray`` of shape (H x W x C), typically in [0, 255];
    * ``PIL.Image.Image``.

    ``torch.Tensor`` frames are rejected with ``TypeError``.

    Args:
        channel_nb (int): expected channel count C. It is checked against
            numpy frames and sets the channel size of the output buffer.
        div_255 (bool): despite the name, if True the output is rescaled
            with ``(x - 127.5) / 127.5``, mapping [0, 255] to [-1.0, 1.0].
        numpy (bool): if True, return a float64 ``numpy.ndarray`` instead of
            a ``torch.FloatTensor``.
    """

    def __init__(self, channel_nb=3, div_255=True, numpy=False):
        self.channel_nb = channel_nb
        self.div_255 = div_255
        self.numpy = numpy

    def __call__(self, clip):
        """
        Args:
            clip (list of numpy.ndarray | PIL.Image.Image): non-empty list of
                frames to convert. The type and size of ``clip[0]`` determine
                the output shape.

        Returns:
            torch.FloatTensor (or float64 numpy.ndarray when ``numpy=True``)
            of shape (C x m x H x W).

        Raises:
            AssertionError: a numpy first frame has a channel count other
                than ``channel_nb``.
            TypeError: a frame is of an unsupported type.
        """
        first = clip[0]
        if isinstance(first, np.ndarray):
            h, w, ch = first.shape
            assert ch == self.channel_nb, "Got {0} instead of {1} channels".format(
                ch, self.channel_nb
            )
        elif isinstance(first, Image.Image):
            w, h = first.size
        else:
            raise TypeError(
                "Expected numpy.ndarray or PIL.Image "
                "but got list of {0}".format(type(first))
            )

        np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
        for img_idx, img in enumerate(clip):
            if not isinstance(img, np.ndarray):
                if not isinstance(img, Image.Image):
                    # Report the offending frame's type, not clip[0]'s.
                    raise TypeError(
                        "Expected numpy.ndarray or PIL.Image "
                        "but got list of {0}".format(type(img))
                    )
                # np.asarray keeps the NumPy 1.x "copy only if needed"
                # meaning of copy=False; under NumPy 2, copy=False means
                # "never copy" and may raise ValueError.
                img = np.asarray(img)
            np_clip[:, img_idx, :, :] = convert_img(img)

        if self.numpy:
            if self.div_255:
                np_clip = (np_clip - 127.5) / 127.5
            return np_clip

        # np_clip is float64; cast to float32 before rescaling.
        tensor_clip = torch.from_numpy(np_clip).float()
        if self.div_255:
            tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5)
        return tensor_clip
|
|
|
|
|
class ToTensor(object):
    """Wrap a ``numpy.ndarray`` as a ``torch.Tensor``.

    The conversion is delegated to ``torch.from_numpy``, so the result keeps
    the array's dtype and shape and shares its memory (no copy is made).
    """

    def __call__(self, array):
        return torch.from_numpy(array)
|
|