SMILE / masking_generator.py
fmthoker's picture
Upload 26 files
4940c8b verified
raw
history blame
9.96 kB
import numpy as np
import random
import ast
class TubeMaskingGenerator:
    """Random "tube" mask: one per-frame patch mask repeated for every frame.

    In the generated array 1 marks a masked patch and 0 a visible patch.
    """

    def __init__(self, input_size, mask_ratio):
        """Derive patch and mask counts.

        Args:
            input_size: ``(frames, height, width)`` of the patch grid.
            mask_ratio: fraction of the patches of one frame to mask; the
                per-frame count is truncated to an integer.
        """
        frames, height, width = input_size
        patches_per_frame = height * width
        masks_per_frame = int(mask_ratio * patches_per_frame)

        self.frames = frames
        self.height = height
        self.width = width
        self.num_patches_per_frame = patches_per_frame
        self.total_patches = frames * patches_per_frame
        self.num_masks_per_frame = masks_per_frame
        self.total_masks = frames * masks_per_frame

    def __repr__(self):
        """Summarise the total and masked patch counts."""
        return "Maks: total patches {}, mask patches {}".format(
            self.total_patches, self.total_masks
        )

    def __call__(self):
        """Return a flat mask of length ``frames * height * width``.

        The same shuffled per-frame pattern is used for every frame, so a
        masked spatial position is masked along the whole temporal axis.
        """
        visible_per_frame = self.num_patches_per_frame - self.num_masks_per_frame
        frame_pattern = np.concatenate(
            (np.zeros(visible_per_frame), np.ones(self.num_masks_per_frame))
        )
        # Shuffles in place using NumPy's global random state.
        np.random.shuffle(frame_pattern)
        # Tiling a 1-D array with a scalar repeat count already yields the
        # flattened frame-after-frame layout.
        return np.tile(frame_pattern, self.frames)
class TubeletMaskingGenerator:
    """Tube mask generator that takes object trajectories (tubelets) into account.

    A plain tube mask is drawn first and then merged with a mask built from
    the trajectory boxes. Afterwards the number of masked patches is
    re-balanced so that exactly ``total_masks`` patches are masked.

    In every mask array 1 means masked and 0 means unmasked.
    """

    def __init__(self, input_size, mask_ratio, visible_frames, mask_type="tube",
                 traj_unmask_ratio=0.1, patch_size=16):
        """Set up the generator.

        Args:
            input_size: ``(frames, height, width)`` of the patch grid.
            mask_ratio: fraction of the patches of one frame to mask.
            visible_frames: ``None``, a string holding a Python literal list
                of frame indices (e.g. ``"[0, 3]"``), or a list/tuple of
                indices. Object patches in these frames are "protected".
                With ``None`` one random frame is picked on every call.
            mask_type: ``"tube"``, ``"tube+picked_frame_visible"`` or
                ``"tube+traj_mask"``; see ``__call__``.
            traj_unmask_ratio: for ``"tube+traj_mask"``, the fraction of
                trajectory patches that is left unmasked.
            patch_size: spatial edge length of one patch, in the same unit
                as the box coordinates passed to ``__call__``.
        """
        self.tube_masking_generator = TubeMaskingGenerator(input_size, mask_ratio)
        self.frames, self.height, self.width = input_size
        self.num_patches_per_frame = self.height * self.width
        self.total_patches = self.frames * self.num_patches_per_frame
        self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
        self.total_masks = self.frames * self.num_masks_per_frame
        self.patch_size = patch_size
        self.traj_unmask_ratio = traj_unmask_ratio
        self.visible_frames = self._parse_visible_frames(visible_frames)
        self.mask_type = mask_type

    @staticmethod
    def _parse_visible_frames(visible_frames):
        """Normalise ``visible_frames`` to ``None`` or a list of ints."""
        if visible_frames is None:
            return None
        if isinstance(visible_frames, str):
            # literal_eval only evaluates Python literals, never arbitrary code.
            visible_frames = ast.literal_eval(visible_frames)
        return [int(element) for element in visible_frames]

    @staticmethod
    def _as_index_array(indices):
        """Return ``indices`` as an integer array usable for NumPy indexing.

        An explicit integer dtype matters for the empty case: NumPy would
        otherwise infer float64, which is rejected as an index array.
        """
        return np.asarray(indices, dtype=np.intp)

    @classmethod
    def _index_difference(cls, base, *excluded):
        """Return the entries of ``base`` that occur in none of ``excluded``."""
        remaining = set(base)
        for other in excluded:
            remaining = remaining - set(other)
        return cls._as_index_array(list(remaining))

    @staticmethod
    def _pick(candidates, count):
        """Randomly choose ``count`` distinct entries of ``candidates``."""
        chosen = np.random.choice(candidates.size, size=int(count), replace=False)
        return candidates[chosen]

    def _patch_span(self, low, high, limit):
        """Convert a coordinate interval into a ``[start, stop)`` patch range.

        The range is extended by one patch past ``ceil(high / patch_size)``
        and clamped to ``[0, limit]``.
        """
        start = max(int(np.floor(low / self.patch_size)), 0)
        stop = min(int(np.ceil(high / self.patch_size)) + 1, limit)
        # A negative stop would be read as "count from the end" by slicing.
        return start, max(stop, 0)

    def _balance_num_masks(self, combined_mask,
                           unmasked_object_patches_index,
                           unmasked_non_object_patches_index,
                           masked_object_patches_index,
                           tube_masked_index=None,
                           tube_unmasked_index=None):
        """Flip patches in ``combined_mask`` until ``total_masks`` are masked.

        ``combined_mask`` is modified in place and also returned.

        * Too many masked patches: unmask non-object patches masked by the
          tube mask first (only when tube indices are given), then masked
          object patches.
        * Too few masked patches: mask non-object patches first (restricted
          to tube-unmasked ones when tube indices are given), then the
          protected (unmasked) object patches.

        Raises:
            ValueError: raised by ``np.random.choice`` when there are not
                enough candidate patches to flip.
        """
        unmasked_object_patches_index = self._as_index_array(unmasked_object_patches_index)
        unmasked_non_object_patches_index = self._as_index_array(unmasked_non_object_patches_index)
        masked_object_patches_index = self._as_index_array(masked_object_patches_index)

        current_masks = np.sum(combined_mask)
        if current_masks == self.total_masks:
            return combined_mask
        num_diff = np.abs(self.total_masks - current_masks)
        has_tube = tube_masked_index is not None and tube_unmasked_index is not None

        if current_masks > self.total_masks:
            fill_value = 0.
            fallback = masked_object_patches_index
            if has_tube:
                preferred = self._index_difference(
                    tube_masked_index,
                    masked_object_patches_index,
                    unmasked_object_patches_index,
                )
            else:
                # Without a tube mask only masked object patches can be unmasked.
                preferred = self._as_index_array([])
        else:
            fill_value = 1.
            fallback = unmasked_object_patches_index
            if has_tube:
                preferred = self._index_difference(
                    tube_unmasked_index,
                    masked_object_patches_index,
                    unmasked_object_patches_index,
                )
            else:
                preferred = unmasked_non_object_patches_index

        if num_diff <= len(preferred):
            picked_index = self._pick(preferred, num_diff)
        else:
            # Use up every preferred patch, then take the rest from the fallback.
            combined_mask[preferred] = fill_value
            picked_index = self._pick(fallback, num_diff - len(preferred))
        combined_mask[picked_index] = fill_value
        return combined_mask

    def __repr__(self):
        repr_str = "Maks: total patches {}, mask patches {}".format(
            self.total_patches, self.total_masks
        )
        return repr_str

    # 1 in mask array means masked, 0 means unmasked
    def __call__(self, traj_rois):
        """Generate a flat mask of length ``frames * height * width``.

        Args:
            traj_rois: array of shape ``(num_tubelet, num_frame, box)`` where
                ``num_frame == 2 * frames`` and each box is indexed as
                ``[x_min, y_min, x_max, y_max]``. Coordinates are presumably
                pixels, since they are divided by ``patch_size`` (confirm
                against the caller).

        Returns:
            ``float32`` array in which exactly ``total_masks`` entries are 1.

        Mask types:
            ``"tube"``: union of the tube mask and the object patches of the
                non-visible frames.
            ``"tube+picked_frame_visible"``: as above, but object patches of
                the visible frames are forced to stay unmasked.
            ``"tube+traj_mask"``: all trajectory patches are masked except a
                random ``traj_unmask_ratio`` fraction of them.
        """
        # generate original VideoMAE tube mask
        tube_mask = self.tube_masking_generator()

        # The shape must be 3-D; two consecutive frames form one mask frame.
        _, num_frame, _ = traj_rois.shape
        assert num_frame % 2 == 0 and self.frames == (num_frame // 2)
        num_mask_frames = num_frame // 2
        combined_mask = np.zeros((num_mask_frames, self.height, self.width))

        # pick the frame(s) whose object patches are protected from masking
        if self.visible_frames is None:
            picked_list = [np.random.randint(0, num_mask_frames)]
        else:
            picked_list = self.visible_frames

        # combined mask: 1 = object patch that should be masked,
        # 2 = object patch that should not be masked, 0 = non-object patch
        for roi in traj_rois:
            for i in range(num_mask_frames):
                first_box, second_box = roi[2 * i], roi[2 * i + 1]
                # Merge the boxes of two consecutive frames into one box.
                x_start, x_stop = self._patch_span(
                    min(first_box[0], second_box[0]),
                    max(first_box[2], second_box[2]),
                    self.width,
                )
                y_start, y_stop = self._patch_span(
                    min(first_box[1], second_box[1]),
                    max(first_box[3], second_box[3]),
                    self.height,
                )
                label = 2. if i in picked_list else 1.
                combined_mask[i][y_start:y_stop, x_start:x_stop] = label

        combined_mask = combined_mask.flatten()
        masked_object_patches_index = np.where(combined_mask == 1.)[0]
        unmasked_non_object_patches_index = np.where(combined_mask == 0.)[0]
        unmasked_object_patches_index = np.where(combined_mask == 2.)[0]
        combined_mask[unmasked_object_patches_index] = 0.
        tube_masked_index = np.where(tube_mask == 1.)[0]
        tube_unmasked_index = np.where(tube_mask == 0.)[0]

        # combine tubelet mask and tube mask
        combined_mask = np.bitwise_or(
            combined_mask.astype(bool), tube_mask.astype(bool)
        ).astype(np.float32)

        if self.mask_type == "tube+picked_frame_visible":
            # unmask the protected patches
            combined_mask[unmasked_object_patches_index] = 0.
        elif self.mask_type == "tube+traj_mask":
            traj_patches_index = self._as_index_array(
                list(set(masked_object_patches_index) | set(unmasked_object_patches_index))
            )
            unmasked_traj_patches_index = self._pick(
                traj_patches_index,
                self.traj_unmask_ratio * len(traj_patches_index),
            )
            # mask the whole trajectory, then unmask the selected patches
            combined_mask[traj_patches_index] = 1.
            combined_mask[unmasked_traj_patches_index] = 0.
            # update indexes
            unmasked_object_patches_index = unmasked_traj_patches_index
            masked_object_patches_index = self._index_difference(
                traj_patches_index, unmasked_traj_patches_index
            )

        # balance masked patch number
        mask = self._balance_num_masks(combined_mask,
                                       unmasked_object_patches_index,
                                       unmasked_non_object_patches_index,
                                       masked_object_patches_index,
                                       tube_masked_index,
                                       tube_unmasked_index)
        assert np.sum(mask) == self.total_masks
        return mask