|
|
|
|
|
|
|
|
|
|
|
import math |
|
from logging import getLogger |
|
from multiprocessing import Value |
|
|
|
import torch |
|
|
|
# Module-level seed constant. NOTE(review): not referenced by any code visible
# in this file section -- confirm whether other modules import it.
_GLOBAL_SEED = 0

# getLogger() called without a name returns the root logger.
logger = getLogger()
|
|
|
|
|
class MaskCollator(object):
    """Collate samples and attach freshly sampled encoder/predictor masks.

    One list of ``_MaskGenerator`` objects is kept per entry of
    ``dataset_fpcs``; every list holds one generator per entry of
    ``cfgs_mask``, in the same order.
    """

    def __init__(
        self,
        cfgs_mask,
        dataset_fpcs,
        crop_size=(224, 224),
        patch_size=(16, 16),
        tubelet_size=2,
    ):
        """Build the per-``fpc`` mask generators.

        Args:
            cfgs_mask: iterable of dict-like mask configs. The keys read are
                ``spatial_scale``, ``temporal_scale``, ``aspect_ratio``,
                ``num_blocks``, ``max_temporal_keep`` (default 1.0),
                ``max_keep`` (default None), ``full_complement``,
                ``pred_full_complement`` and ``inv_block`` (default False).
            dataset_fpcs: iterable of frame counts; each one becomes a key of
                ``self.mask_generators`` and the ``num_frames`` of its
                generators.
            crop_size: forwarded to every generator as ``crop_size``.
            patch_size: forwarded as ``spatial_patch_size``.
            tubelet_size: forwarded as ``temporal_patch_size``.
        """
        super(MaskCollator, self).__init__()

        self.mask_generators = {
            frames_per_clip: [
                _MaskGenerator(
                    crop_size=crop_size,
                    num_frames=frames_per_clip,
                    spatial_patch_size=patch_size,
                    temporal_patch_size=tubelet_size,
                    spatial_pred_mask_scale=cfg.get("spatial_scale"),
                    temporal_pred_mask_scale=cfg.get("temporal_scale"),
                    aspect_ratio=cfg.get("aspect_ratio"),
                    npred=cfg.get("num_blocks"),
                    max_context_frames_ratio=cfg.get("max_temporal_keep", 1.0),
                    max_keep=cfg.get("max_keep", None),
                    full_complement=cfg.get("full_complement", False),
                    pred_full_complement=cfg.get("pred_full_complement", False),
                    inv_block=cfg.get("inv_block", False),
                )
                for cfg in cfgs_mask
            ]
            for frames_per_clip in dataset_fpcs
        }

    def step(self):
        """Advance the iteration counter of every mask generator by one."""
        for generators in self.mask_generators.values():
            for generator in generators:
                generator.step()

    def __call__(self, batch):
        """Group ``batch`` by frame count, collate each group and sample masks.

        Each sample is routed by ``len(sample[-1][-1])`` (presumably the
        per-clip frame indices -- confirm against the dataset). A sample whose
        length is not a key of ``self.mask_generators`` raises ``KeyError``.

        Returns:
            A list with one ``(collated_batch, masks_enc, masks_pred)`` tuple
            per non-empty group, where ``masks_enc`` / ``masks_pred`` are
            lists holding one entry per mask generator of that group.
        """
        grouped = {key: [] for key in self.mask_generators}
        for sample in batch:
            grouped[len(sample[-1][-1])].append(sample)

        collations = []
        for key, samples in grouped.items():
            group_size = len(samples)
            if not group_size:
                continue

            collated_samples = torch.utils.data.default_collate(samples)

            enc_masks = []
            pred_masks = []
            for generator in self.mask_generators[key]:
                enc, pred = generator(group_size)
                enc_masks.append(enc)
                pred_masks.append(pred)

            collations.append((collated_samples, enc_masks, pred_masks))

        return collations
|
|
|
|
|
class _MaskGenerator(object): |
|
|
|
def __init__( |
|
self, |
|
crop_size=(224, 224), |
|
num_frames=16, |
|
spatial_patch_size=(16, 16), |
|
temporal_patch_size=2, |
|
spatial_pred_mask_scale=(0.2, 0.8), |
|
temporal_pred_mask_scale=(1.0, 1.0), |
|
aspect_ratio=(0.3, 3.0), |
|
npred=1, |
|
max_context_frames_ratio=1.0, |
|
max_keep=None, |
|
inv_block=False, |
|
full_complement=False, |
|
pred_full_complement=False, |
|
): |
|
super(_MaskGenerator, self).__init__() |
|
if not isinstance(crop_size, tuple): |
|
crop_size = (crop_size,) * 2 |
|
if not isinstance(spatial_patch_size, tuple): |
|
spatial_patch_size = (spatial_patch_size,) * 2 |
|
self.crop_size = crop_size |
|
self.height, self.width = [crop_size[i] // spatial_patch_size[i] for i in (0, 1)] |
|
self.duration = num_frames // temporal_patch_size |
|
self.full_complement = full_complement |
|
self.pred_full_complement = pred_full_complement |
|
|
|
self.spatial_patch_size = spatial_patch_size |
|
self.temporal_patch_size = temporal_patch_size |
|
|
|
self.aspect_ratio = aspect_ratio |
|
self.spatial_pred_mask_scale = spatial_pred_mask_scale |
|
self.temporal_pred_mask_scale = temporal_pred_mask_scale |
|
self.npred = npred |
|
self.max_context_duration = max( |
|
1, int(self.duration * max_context_frames_ratio) |
|
) |
|
self.max_keep = max_keep |
|
self._itr_counter = Value("i", -1) |
|
self.inv_block = inv_block |
|
|
|
def step(self): |
|
i = self._itr_counter |
|
with i.get_lock(): |
|
i.value += 1 |
|
v = i.value |
|
return v |
|
|
|
def _sample_block_size(self, generator, temporal_scale, spatial_scale, aspect_ratio_scale): |
|
|
|
_rand = torch.rand(1, generator=generator).item() |
|
min_t, max_t = temporal_scale |
|
temporal_mask_scale = min_t + _rand * (max_t - min_t) |
|
t = max(1, int(self.duration * temporal_mask_scale)) |
|
|
|
|
|
_rand = torch.rand(1, generator=generator).item() |
|
min_s, max_s = spatial_scale |
|
spatial_mask_scale = min_s + _rand * (max_s - min_s) |
|
spatial_num_keep = int(self.height * self.width * spatial_mask_scale) |
|
|
|
|
|
_rand = torch.rand(1, generator=generator).item() |
|
min_ar, max_ar = aspect_ratio_scale |
|
aspect_ratio = min_ar + _rand * (max_ar - min_ar) |
|
|
|
|
|
h = int(round(math.sqrt(spatial_num_keep * aspect_ratio))) |
|
w = int(round(math.sqrt(spatial_num_keep / aspect_ratio))) |
|
h = min(h, self.height) |
|
w = min(w, self.width) |
|
|
|
return (t, h, w) |
|
|
|
def _sample_block_mask(self, b_size): |
|
t, h, w = b_size |
|
top = torch.randint(0, self.height - h + 1, (1,)) |
|
left = torch.randint(0, self.width - w + 1, (1,)) |
|
start = torch.randint(0, self.duration - t + 1, (1,)) |
|
|
|
mask = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) |
|
mask[start : start + t, top : top + h, left : left + w] = 0 |
|
|
|
|
|
|
|
if self.max_context_duration < self.duration: |
|
mask[self.max_context_duration :, :, :] = 0 |
|
|
|
|
|
return mask |
|
|
|
def __call__(self, batch_size): |
|
""" |
|
Create encoder and predictor masks when collating imgs into a batch |
|
# 1. sample pred block size using seed |
|
# 2. sample several pred block locations for each image (w/o seed) |
|
# 3. return pred masks and complement (enc mask) |
|
""" |
|
seed = self.step() |
|
g = torch.Generator() |
|
g.manual_seed(seed) |
|
p_size = self._sample_block_size( |
|
generator=g, |
|
temporal_scale=self.temporal_pred_mask_scale, |
|
spatial_scale=self.spatial_pred_mask_scale, |
|
aspect_ratio_scale=self.aspect_ratio, |
|
) |
|
|
|
collated_masks_pred, collated_masks_enc = [], [] |
|
min_keep_enc = min_keep_pred = self.duration * self.height * self.width |
|
for _ in range(batch_size): |
|
|
|
empty_context = True |
|
while empty_context: |
|
|
|
mask_e = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) |
|
for _ in range(self.npred): |
|
mask_e *= self._sample_block_mask(p_size) |
|
mask_e = mask_e.flatten() |
|
|
|
mask_p = torch.argwhere(mask_e == 0).squeeze() |
|
mask_e = torch.nonzero(mask_e).squeeze() |
|
|
|
empty_context = len(mask_e) == 0 |
|
if not empty_context: |
|
min_keep_pred = min(min_keep_pred, len(mask_p)) |
|
min_keep_enc = min(min_keep_enc, len(mask_e)) |
|
collated_masks_pred.append(mask_p) |
|
collated_masks_enc.append(mask_e) |
|
|
|
if self.max_keep is not None: |
|
min_keep_enc = min(min_keep_enc, self.max_keep) |
|
|
|
collated_masks_enc = [cm[:min_keep_enc] for cm in collated_masks_enc] |
|
collated_masks_pred = [cm[:min_keep_pred] for cm in collated_masks_pred] |
|
if self.full_complement: |
|
collated_masks_pred = [ |
|
torch.tensor( |
|
sorted(list(set(range(int(self.duration * self.height * self.width))) - set(cm.tolist()))), |
|
dtype=cm.dtype, |
|
) |
|
for cm in collated_masks_enc |
|
] |
|
elif self.pred_full_complement: |
|
collated_masks_enc = [ |
|
torch.tensor( |
|
sorted(list(set(range(int(self.duration * self.height * self.width))) - set(cm.tolist()))), |
|
dtype=cm.dtype, |
|
) |
|
for cm in collated_masks_pred |
|
] |
|
|
|
collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) |
|
collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) |
|
|
|
if self.inv_block: |
|
return collated_masks_pred, collated_masks_enc |
|
else: |
|
return collated_masks_enc, collated_masks_pred |
|
|