# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from logging import getLogger
from multiprocessing import Value
import torch

_GLOBAL_SEED = 0
logger = getLogger()


class MaskCollator(object):
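    """
    Builds one _MaskGenerator per (frames-per-clip, mask-config) pair; at
    collation time, samples are grouped by frames-per-clip so each group
    receives masks of the matching temporal size.
    """
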
def __init__(
self,
cfgs_mask,
dataset_fpcs,
crop_size=(224, 224),
patch_size=(16, 16),
tubelet_size=2,
):
super(MaskCollator, self).__init__()
self.mask_generators = dict()
for fpc in dataset_fpcs:
self.mask_generators[fpc] = []
for m in cfgs_mask:
mask_generator = _MaskGenerator(
crop_size=crop_size,
num_frames=fpc,
spatial_patch_size=patch_size,
temporal_patch_size=tubelet_size,
spatial_pred_mask_scale=m.get("spatial_scale"),
temporal_pred_mask_scale=m.get("temporal_scale"),
aspect_ratio=m.get("aspect_ratio"),
npred=m.get("num_blocks"),
max_context_frames_ratio=m.get("max_temporal_keep", 1.0),
max_keep=m.get("max_keep", None),
full_complement=m.get("full_complement", False),
pred_full_complement=m.get("pred_full_complement", False),
inv_block=m.get("inv_block", False),
)
self.mask_generators[fpc].append(mask_generator)

    def step(self):
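        """Advance the shared iteration counter of every underlying mask generator."""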
for fpc in self.mask_generators:
for mask_generator in self.mask_generators[fpc]:
mask_generator.step()

    def __call__(self, batch):
        # Each sample is (buffer, label, clip_indices); the length of the last
        # clip's index list gives that sample's frames-per-clip (fpc).
filtered_batches = {fpc: [] for fpc in self.mask_generators}
for sample in batch:
fpc = len(sample[-1][-1])
filtered_batches[fpc] += [sample]
fpc_collations = []
for fpc in filtered_batches:
fpc_batch = filtered_batches[fpc]
batch_size = len(fpc_batch)
if batch_size == 0:
continue
collated_batch = torch.utils.data.default_collate(fpc_batch)
collated_masks_pred, collated_masks_enc = [], []
            for mask_generator in self.mask_generators[fpc]:
masks_enc, masks_pred = mask_generator(batch_size)
collated_masks_enc.append(masks_enc)
collated_masks_pred.append(masks_pred)
fpc_collations += [(collated_batch, collated_masks_enc, collated_masks_pred)]
return fpc_collations


class _MaskGenerator(object):
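    """
    Samples 3D (time x height x width) block masks over a clip's tubelet/patch
    grid and returns encoder and predictor masks as flat token indices.
    """
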
def __init__(
self,
crop_size=(224, 224),
num_frames=16,
spatial_patch_size=(16, 16),
temporal_patch_size=2,
spatial_pred_mask_scale=(0.2, 0.8),
temporal_pred_mask_scale=(1.0, 1.0),
aspect_ratio=(0.3, 3.0),
npred=1,
max_context_frames_ratio=1.0,
max_keep=None,
inv_block=False,
full_complement=False,
pred_full_complement=False,
):
super(_MaskGenerator, self).__init__()
if not isinstance(crop_size, tuple):
crop_size = (crop_size,) * 2
if not isinstance(spatial_patch_size, tuple):
spatial_patch_size = (spatial_patch_size,) * 2
self.crop_size = crop_size
self.height, self.width = [crop_size[i] // spatial_patch_size[i] for i in (0, 1)]
self.duration = num_frames // temporal_patch_size
self.full_complement = full_complement
self.pred_full_complement = pred_full_complement
self.spatial_patch_size = spatial_patch_size
self.temporal_patch_size = temporal_patch_size
self.aspect_ratio = aspect_ratio
self.spatial_pred_mask_scale = spatial_pred_mask_scale
self.temporal_pred_mask_scale = temporal_pred_mask_scale
self.npred = npred
self.max_context_duration = max(
1, int(self.duration * max_context_frames_ratio)
) # maximum number of time-steps (frames) spanned by context mask
self.max_keep = max_keep # maximum number of patches to keep in context
self._itr_counter = Value("i", -1) # collator is shared across worker processes
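        # If inv_block is set, __call__ swaps the returned masks so the model
        # predicts the context from the block rather than the block from context.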
self.inv_block = inv_block

    def step(self):
i = self._itr_counter
with i.get_lock():
i.value += 1
v = i.value
return v

    def _sample_block_size(self, generator, temporal_scale, spatial_scale, aspect_ratio_scale):
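        # Returns (t, h, w): the block extent measured in temporal/spatial token units.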
# -- Sample temporal block mask scale
_rand = torch.rand(1, generator=generator).item()
min_t, max_t = temporal_scale
temporal_mask_scale = min_t + _rand * (max_t - min_t)
t = max(1, int(self.duration * temporal_mask_scale))
# -- Sample spatial block mask scale
_rand = torch.rand(1, generator=generator).item()
min_s, max_s = spatial_scale
spatial_mask_scale = min_s + _rand * (max_s - min_s)
spatial_num_keep = int(self.height * self.width * spatial_mask_scale)
# -- Sample block aspect-ratio
_rand = torch.rand(1, generator=generator).item()
min_ar, max_ar = aspect_ratio_scale
aspect_ratio = min_ar + _rand * (max_ar - min_ar)
# -- Compute block height and width (given scale and aspect-ratio)
h = int(round(math.sqrt(spatial_num_keep * aspect_ratio)))
w = int(round(math.sqrt(spatial_num_keep / aspect_ratio)))
h = min(h, self.height)
w = min(w, self.width)
return (t, h, w)

    def _sample_block_mask(self, b_size):
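        # Zero out one randomly positioned (t, h, w) block in an all-ones grid:
        # 1 marks context tokens, 0 marks tokens the predictor must reconstruct.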
t, h, w = b_size
top = torch.randint(0, self.height - h + 1, (1,))
left = torch.randint(0, self.width - w + 1, (1,))
start = torch.randint(0, self.duration - t + 1, (1,))
mask = torch.ones((self.duration, self.height, self.width), dtype=torch.int32)
mask[start : start + t, top : top + h, left : left + w] = 0
        # Context mask will only span the first `self.max_context_duration`
        # time-steps (tubelets); later time-steps are dropped from the context.
if self.max_context_duration < self.duration:
mask[self.max_context_duration :, :, :] = 0
# --
return mask

    def __call__(self, batch_size):
        """
        Create encoder and predictor masks when collating samples into a batch:
          1. sample the predictor block size using a shared seed (same across workers)
          2. sample several predictor block locations per sample (without the seed)
          3. return the predictor masks and their complement (the encoder masks)
        """
seed = self.step()
g = torch.Generator()
g.manual_seed(seed)
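        # The counter-derived seed is shared across dataloader workers, so the
        # sampled block *size* is identical for every sample at a given step;
        # the block *locations* sampled below are not seeded and vary per sample.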
p_size = self._sample_block_size(
generator=g,
temporal_scale=self.temporal_pred_mask_scale,
spatial_scale=self.spatial_pred_mask_scale,
aspect_ratio_scale=self.aspect_ratio,
)
collated_masks_pred, collated_masks_enc = [], []
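        # Track the smallest encoder/predictor mask lengths seen in the batch so
        # every mask can later be trimmed to a common length for collation.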
min_keep_enc = min_keep_pred = self.duration * self.height * self.width
for _ in range(batch_size):
empty_context = True
while empty_context:
mask_e = torch.ones((self.duration, self.height, self.width), dtype=torch.int32)
for _ in range(self.npred):
mask_e *= self._sample_block_mask(p_size)
mask_e = mask_e.flatten()
mask_p = torch.argwhere(mask_e == 0).squeeze()
mask_e = torch.nonzero(mask_e).squeeze()
empty_context = len(mask_e) == 0
if not empty_context:
min_keep_pred = min(min_keep_pred, len(mask_p))
min_keep_enc = min(min_keep_enc, len(mask_e))
collated_masks_pred.append(mask_p)
collated_masks_enc.append(mask_e)
if self.max_keep is not None:
min_keep_enc = min(min_keep_enc, self.max_keep)
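        # Trim every mask to the common minimum length so the index tensors stack.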
collated_masks_enc = [cm[:min_keep_enc] for cm in collated_masks_enc]
collated_masks_pred = [cm[:min_keep_pred] for cm in collated_masks_pred]
if self.full_complement: # predictor mask is just complement of encoder mask
collated_masks_pred = [
torch.tensor(
sorted(list(set(range(int(self.duration * self.height * self.width))) - set(cm.tolist()))),
dtype=cm.dtype,
)
for cm in collated_masks_enc
]
elif self.pred_full_complement:
collated_masks_enc = [
torch.tensor(
sorted(list(set(range(int(self.duration * self.height * self.width))) - set(cm.tolist()))),
dtype=cm.dtype,
)
for cm in collated_masks_pred
]
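        # Stack the per-sample index tensors into [batch_size, num_kept] tensors.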
collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc)
collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred)
if self.inv_block:
return collated_masks_pred, collated_masks_enc # predict context from block
else:
return collated_masks_enc, collated_masks_pred
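

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# mask config below is a plain dict, which satisfies the `m.get(...)` lookups
# in MaskCollator.__init__; the specific values are assumptions chosen for
# demonstration, not values shipped with this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfgs_mask = [
        {
            "spatial_scale": (0.15, 0.15),
            "temporal_scale": (1.0, 1.0),
            "aspect_ratio": (0.75, 1.5),
            "num_blocks": 8,
            "max_temporal_keep": 1.0,
        }
    ]
    collator = MaskCollator(cfgs_mask=cfgs_mask, dataset_fpcs=[16], crop_size=(224, 224))
    # Each generator call returns (encoder_mask, predictor_mask) index tensors
    # of shape [batch_size, num_kept].
    masks_enc, masks_pred = collator.mask_generators[16][0](2)
    print(masks_enc.shape, masks_pred.shape)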