# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
import torchvision.transforms as transforms

import src.datasets.utils.video.transforms as video_transforms
from src.datasets.utils.video.randerase import RandomErasing


def make_transforms(
random_horizontal_flip=True,
random_resize_aspect_ratio=(3 / 4, 4 / 3),
random_resize_scale=(0.3, 1.0),
reprob=0.0,
auto_augment=False,
motion_shift=False,
crop_size=224,
normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
pad_frame_count: Optional[int] = None,
pad_frame_method: str = "circulant",
):
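    """Build a `VideoTransform` augmentation pipeline for video clips.

    Thin factory around `VideoTransform`; see that class for the meaning of the
    individual arguments. Returns a callable that maps a (T, H, W, C) clip to a
    normalized (C, T, H, W) float32 tensor.
    """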
_frames_augmentation = VideoTransform(
random_horizontal_flip=random_horizontal_flip,
random_resize_aspect_ratio=random_resize_aspect_ratio,
random_resize_scale=random_resize_scale,
reprob=reprob,
auto_augment=auto_augment,
motion_shift=motion_shift,
crop_size=crop_size,
normalize=normalize,
pad_frame_count=pad_frame_count,
pad_frame_method=pad_frame_method,
)
return _frames_augmentation


class VideoTransform(object):
def __init__(
self,
random_horizontal_flip=True,
random_resize_aspect_ratio=(3 / 4, 4 / 3),
random_resize_scale=(0.3, 1.0),
reprob=0.0,
auto_augment=False,
motion_shift=False,
crop_size=224,
normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
pad_frame_count: Optional[int] = None,
pad_frame_method: str = "circulant",
):
self.random_horizontal_flip = random_horizontal_flip
self.random_resize_aspect_ratio = random_resize_aspect_ratio
self.random_resize_scale = random_resize_scale
self.auto_augment = auto_augment
self.motion_shift = motion_shift
self.crop_size = crop_size
self.mean = torch.tensor(normalize[0], dtype=torch.float32)
self.std = torch.tensor(normalize[1], dtype=torch.float32)
self.pad_frame_count = pad_frame_count
self.pad_frame_method = pad_frame_method
        if not self.auto_augment:
            # Without auto-augment the frames stay in the uint8 [0, 255] range, so the
            # normalization stats are scaled to match. (With auto-augment, ToTensor()
            # rescales frames to [0, 1] and the stats are left as-is.)
            self.mean *= 255.0
            self.std *= 255.0
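        # timm-style RandAugment policy string: "rand-m7-n4-mstd0.5-inc1" selects
        # magnitude 7, 4 ops, magnitude std 0.5, and the increasing-severity op set.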
self.autoaug_transform = video_transforms.create_random_augment(
input_size=(crop_size, crop_size),
auto_augment="rand-m7-n4-mstd0.5-inc1",
interpolation="bicubic",
)
self.spatial_transform = (
video_transforms.random_resized_crop_with_shift if motion_shift else video_transforms.random_resized_crop
)
self.reprob = reprob
self.erase_transform = RandomErasing(
reprob,
mode="pixel",
max_count=1,
num_splits=1,
device="cpu",
)

    def __call__(self, buffer):
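        """Apply the augmentation pipeline to a clip in (T, H, W, C) layout.

        Returns a float32 tensor in (C, T, H, W) layout, optionally frame-padded
        along T when `pad_frame_count` is set.
        """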
if self.auto_augment:
buffer = [transforms.ToPILImage()(frame) for frame in buffer]
buffer = self.autoaug_transform(buffer)
buffer = [transforms.ToTensor()(img) for img in buffer]
buffer = torch.stack(buffer) # T C H W
buffer = buffer.permute(0, 2, 3, 1) # T H W C
elif torch.is_tensor(buffer):
# TODO: ensure input is always a tensor?
buffer = buffer.to(torch.float32)
else:
buffer = torch.tensor(buffer, dtype=torch.float32)
buffer = buffer.permute(3, 0, 1, 2) # T H W C -> C T H W
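        # Random resized crop to crop_size x crop_size with scale and aspect-ratio
        # jitter (the motion-shifted variant is used when motion_shift is enabled).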
buffer = self.spatial_transform(
images=buffer,
target_height=self.crop_size,
target_width=self.crop_size,
scale=self.random_resize_scale,
ratio=self.random_resize_aspect_ratio,
)
if self.random_horizontal_flip:
buffer, _ = video_transforms.horizontal_flip(0.5, buffer)
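        # Channel-wise normalization in place; mean/std were pre-scaled to the 0-255
        # range in __init__ when auto-augment is off, matching the unscaled frames.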
buffer = _tensor_normalize_inplace(buffer, self.mean, self.std)
if self.reprob > 0:
buffer = buffer.permute(1, 0, 2, 3)
buffer = self.erase_transform(buffer)
buffer = buffer.permute(1, 0, 2, 3)
if self.pad_frame_count is not None:
buffer = video_transforms.frame_pad(buffer, self.pad_frame_count, self.pad_frame_method)
return buffer


def tensor_normalize(tensor, mean, std):
"""
Normalize a given tensor by subtracting the mean and dividing the std.
Args:
tensor (tensor): tensor to normalize.
mean (tensor or list): mean value to subtract.
std (tensor or list): std to divide.
"""
if tensor.dtype == torch.uint8:
tensor = tensor.float()
tensor = tensor / 255.0
if isinstance(mean, list):
mean = torch.tensor(mean)
if isinstance(std, list):
std = torch.tensor(std)
tensor = tensor - mean
tensor = tensor / std
return tensor


def _tensor_normalize_inplace(tensor, mean, std):
"""
Normalize a given tensor by subtracting the mean and dividing the std.
Args:
tensor (tensor): tensor to normalize (with dimensions C, T, H, W).
mean (tensor): mean value to subtract (in 0 to 255 floats).
std (tensor): std to divide (in 0 to 255 floats).
"""
if tensor.dtype == torch.uint8:
tensor = tensor.float()
C, T, H, W = tensor.shape
tensor = tensor.view(C, -1).permute(1, 0) # Make C the last dimension
tensor.sub_(mean).div_(std)
tensor = tensor.permute(1, 0).view(C, T, H, W) # Put C back in front
return tensor
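

# Minimal smoke-test sketch. Assumptions: the script is run from the repository root
# so the `src.` imports above resolve, and the random uint8 clip below stands in for
# a decoded video (16 frames at 256x320); the expected output shape follows from the
# (T, H, W, C) -> (C, T, H, W) pipeline above.
if __name__ == "__main__":
    import numpy as np

    dummy_clip = np.random.randint(0, 256, size=(16, 256, 320, 3), dtype=np.uint8)
    transform = make_transforms(crop_size=224, reprob=0.25)
    out = transform(dummy_clip)
    print(out.shape)  # expected: torch.Size([3, 16, 224, 224])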