|
|
|
|
|
|
|
|
|
|
|
from typing import Optional |
|
|
|
import torch |
|
import torchvision.transforms as transforms |
|
|
|
import src.datasets.utils.video.transforms as video_transforms |
|
from src.datasets.utils.video.randerase import RandomErasing |
|
|
|
|
|
def make_transforms(
    random_horizontal_flip=True,
    random_resize_aspect_ratio=(3 / 4, 4 / 3),
    random_resize_scale=(0.3, 1.0),
    reprob=0.0,
    auto_augment=False,
    motion_shift=False,
    crop_size=224,
    normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    pad_frame_count: Optional[int] = None,
    pad_frame_method: str = "circulant",
):
    """Build a video augmentation pipeline.

    Every argument is forwarded unchanged to ``VideoTransform``; see that
    class for the meaning of each parameter.

    Returns:
        VideoTransform: a callable that augments a clip of video frames.
    """
    return VideoTransform(
        random_horizontal_flip=random_horizontal_flip,
        random_resize_aspect_ratio=random_resize_aspect_ratio,
        random_resize_scale=random_resize_scale,
        reprob=reprob,
        auto_augment=auto_augment,
        motion_shift=motion_shift,
        crop_size=crop_size,
        normalize=normalize,
        pad_frame_count=pad_frame_count,
        pad_frame_method=pad_frame_method,
    )
|
|
|
|
|
class VideoTransform(object):
    """Spatio-temporal augmentation pipeline for a clip of video frames.

    The pipeline applies, in order: optional AutoAugment (via PIL), a random
    resized crop (optionally with motion shift), optional random horizontal
    flip, per-channel normalization, optional random erasing, and optional
    frame padding.
    """

    def __init__(
        self,
        random_horizontal_flip=True,
        random_resize_aspect_ratio=(3 / 4, 4 / 3),
        random_resize_scale=(0.3, 1.0),
        reprob=0.0,
        auto_augment=False,
        motion_shift=False,
        crop_size=224,
        normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        pad_frame_count: Optional[int] = None,
        pad_frame_method: str = "circulant",
    ):
        """Configure the augmentation pipeline.

        Args:
            random_horizontal_flip (bool): flip each clip with probability 0.5.
            random_resize_aspect_ratio (tuple): (min, max) aspect-ratio range
                for the random resized crop.
            random_resize_scale (tuple): (min, max) area-scale range for the
                random resized crop.
            reprob (float): random-erasing probability; 0 disables erasing.
            auto_augment (bool): apply AutoAugment ("rand-m7-n4-mstd0.5-inc1")
                to each frame before the spatial transforms.
            motion_shift (bool): use the motion-shift variant of the random
                resized crop.
            crop_size (int): output spatial size (square crop).
            normalize (tuple): (mean, std) per-channel statistics in [0, 1].
            pad_frame_count (Optional[int]): if set, pad the clip to this many
                frames after all other transforms.
            pad_frame_method (str): padding strategy passed to ``frame_pad``.
        """
        self.random_horizontal_flip = random_horizontal_flip
        self.random_resize_aspect_ratio = random_resize_aspect_ratio
        self.random_resize_scale = random_resize_scale
        self.auto_augment = auto_augment
        self.motion_shift = motion_shift
        self.crop_size = crop_size
        self.pad_frame_count = pad_frame_count
        self.pad_frame_method = pad_frame_method
        self.reprob = reprob

        mean, std = normalize
        # AutoAugment routes frames through ToTensor, which yields values in
        # [0, 1]; otherwise frames stay in the raw [0, 255] range, so the
        # normalization statistics must be scaled up to match.
        stat_scale = 1.0 if auto_augment else 255.0
        self.mean = torch.tensor(mean, dtype=torch.float32) * stat_scale
        self.std = torch.tensor(std, dtype=torch.float32) * stat_scale

        # Built unconditionally; only used when auto_augment is enabled.
        self.autoaug_transform = video_transforms.create_random_augment(
            input_size=(crop_size, crop_size),
            auto_augment="rand-m7-n4-mstd0.5-inc1",
            interpolation="bicubic",
        )

        if motion_shift:
            self.spatial_transform = video_transforms.random_resized_crop_with_shift
        else:
            self.spatial_transform = video_transforms.random_resized_crop

        self.erase_transform = RandomErasing(
            reprob,
            mode="pixel",
            max_count=1,
            num_splits=1,
            device="cpu",
        )

    def __call__(self, buffer):
        """Augment one clip.

        Args:
            buffer: clip of frames; a tensor or array-like indexed as
                (T, H, W, C) — implied by the permutes below (TODO confirm
                against the dataset loader).

        Returns:
            torch.Tensor: augmented clip with dimensions (C, T, H, W).
        """
        if self.auto_augment:
            # Round-trip through PIL: augment frame-by-frame, then restack
            # and move channels back to the trailing axis.
            frames = [transforms.ToPILImage()(frame) for frame in buffer]
            frames = self.autoaug_transform(frames)
            stacked = torch.stack([transforms.ToTensor()(img) for img in frames])
            buffer = stacked.permute(0, 2, 3, 1)
        else:
            buffer = buffer.to(torch.float32) if torch.is_tensor(buffer) else torch.tensor(buffer, dtype=torch.float32)

        # (T, H, W, C) -> (C, T, H, W)
        buffer = buffer.permute(3, 0, 1, 2)

        buffer = self.spatial_transform(
            images=buffer,
            target_height=self.crop_size,
            target_width=self.crop_size,
            scale=self.random_resize_scale,
            ratio=self.random_resize_aspect_ratio,
        )

        if self.random_horizontal_flip:
            buffer, _ = video_transforms.horizontal_flip(0.5, buffer)

        buffer = _tensor_normalize_inplace(buffer, self.mean, self.std)

        if self.reprob > 0:
            # Erasing presumably operates per-frame on (T, C, H, W) — the
            # original permutes around the call; verified only by symmetry.
            buffer = self.erase_transform(buffer.permute(1, 0, 2, 3))
            buffer = buffer.permute(1, 0, 2, 3)

        if self.pad_frame_count is not None:
            buffer = video_transforms.frame_pad(buffer, self.pad_frame_count, self.pad_frame_method)

        return buffer
|
|
|
|
|
def tensor_normalize(tensor, mean, std):
    """
    Normalize a given tensor by subtracting the mean and dividing the std.

    Args:
        tensor (torch.Tensor): tensor to normalize. A uint8 input is first
            converted to float and rescaled from [0, 255] to [0, 1].
        mean (torch.Tensor, list, or tuple): mean value to subtract.
        std (torch.Tensor, list, or tuple): std to divide.

    Returns:
        torch.Tensor: the normalized tensor (a new tensor; the input is not
        modified in place).
    """
    if tensor.dtype == torch.uint8:
        tensor = tensor.float()
        tensor = tensor / 255.0
    # Accept tuples as well as lists: the normalization defaults elsewhere in
    # this module are tuples, and `tensor - tuple` is not supported by torch.
    if isinstance(mean, (list, tuple)):
        mean = torch.tensor(mean)
    if isinstance(std, (list, tuple)):
        std = torch.tensor(std)
    tensor = tensor - mean
    tensor = tensor / std
    return tensor
|
|
|
|
|
def _tensor_normalize_inplace(tensor, mean, std):
    """
    Normalize a given tensor in place by subtracting the mean and dividing
    the std.

    Args:
        tensor (torch.Tensor): tensor to normalize (with dimensions C, T, H, W).
        mean (torch.Tensor): per-channel mean to subtract (in 0 to 255 floats).
        std (torch.Tensor): per-channel std to divide (in 0 to 255 floats).

    Returns:
        torch.Tensor: the normalized tensor. It shares storage with the input
        unless the input was uint8, in which case a float copy is normalized.
    """
    if tensor.dtype == torch.uint8:
        # .float() allocates a copy, so the caller's uint8 tensor is untouched.
        tensor = tensor.float()
    # Reshape the (C,) statistics to (C, 1, 1, 1) so they broadcast across
    # the temporal and spatial axes, then normalize channel-wise in place.
    tensor.sub_(mean[:, None, None, None])
    tensor.div_(std[:, None, None, None])
    return tensor
|
|