|
import numpy as np |
|
import random |
|
import ast |
|
|
|
class TubeMaskingGenerator:
    """Generate "tube" masks: one random spatial mask repeated for every frame.

    A single per-frame mask is sampled and tiled along the temporal axis, so a
    given spatial patch is either masked in all frames or visible in all
    frames.
    """

    def __init__(self, input_size, mask_ratio):
        """Store the patch-grid geometry and derive the mask counts.

        Args:
            input_size: ``(frames, height, width)`` of the patch grid.
            mask_ratio: Fraction of the patches of one frame to mask. The
                per-frame count is ``int(mask_ratio * height * width)``, i.e.
                it is truncated towards zero.
        """
        self.frames, self.height, self.width = input_size
        self.num_patches_per_frame = self.height * self.width
        self.total_patches = self.frames * self.num_patches_per_frame
        self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
        self.total_masks = self.frames * self.num_masks_per_frame

    def __repr__(self):
        """Return a short summary of the total and masked patch counts."""
        repr_str = "Mask: total patches {}, mask patches {}".format(
            self.total_patches, self.total_masks
        )
        return repr_str

    def __call__(self):
        """Sample a new tube mask.

        Returns:
            A flat float array of length ``frames * height * width`` in
            frame-major order, where ``1`` marks a masked patch and ``0`` a
            visible one. Exactly ``total_masks`` entries are ``1``.
        """
        # Build one frame worth of zeros/ones with the exact mask count, then
        # shuffle it in place (uses the global NumPy random state).
        mask_per_frame = np.hstack([
            np.zeros(self.num_patches_per_frame - self.num_masks_per_frame),
            np.ones(self.num_masks_per_frame),
        ])
        np.random.shuffle(mask_per_frame)
        # Repeat the same spatial pattern for every frame and flatten.
        mask = np.tile(mask_per_frame, (self.frames, 1)).flatten()
        return mask
|
|
|
|
|
class TubeletMaskingGenerator:
    """Combine a random tube mask with masks derived from object trajectories.

    Trajectory boxes are rasterised onto the patch grid, merged with a tube
    mask from ``TubeMaskingGenerator`` and then re-balanced so that the number
    of masked patches is exactly ``total_masks``.

    Supported ``mask_type`` values:
        * ``"tube+picked_frame_visible"``: object patches in the picked
          (visible) frames are forced to be visible.
        * ``"tube+traj_mask"``: all trajectory patches are masked except a
          random ``traj_unmask_ratio`` fraction of them.
        * anything else (default ``"tube"``): object patches outside the
          picked frames are OR-ed into the tube mask.
    """

    def __init__(self, input_size, mask_ratio, visible_frames, mask_type="tube",
                 traj_unmask_ratio=0.1, patch_size=16):
        """Store the grid geometry and masking configuration.

        Args:
            input_size: ``(frames, height, width)`` of the patch grid.
            mask_ratio: Fraction of the patches of one frame to mask; the
                per-frame count is truncated with ``int()``.
            visible_frames: ``None`` to pick one random frame per call, or the
                frame indices whose object patches are treated as "unmasked
                object" patches. May be a Python-literal string such as
                ``"[0, 3]"`` (parsed with ``ast.literal_eval``), a
                list/tuple of ints, or a single int.
            mask_type: Masking strategy, see the class docstring.
            traj_unmask_ratio: Fraction of trajectory patches left visible in
                ``"tube+traj_mask"`` mode.
            patch_size: Divisor converting box coordinates into patch-grid
                indices (presumably the patch side in pixels - confirm against
                the caller).
        """
        self.tube_masking_generator = TubeMaskingGenerator(input_size, mask_ratio)
        self.frames, self.height, self.width = input_size
        self.num_patches_per_frame = self.height * self.width
        self.total_patches = self.frames * self.num_patches_per_frame
        self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
        self.total_masks = self.frames * self.num_masks_per_frame
        self.patch_size = patch_size
        self.traj_unmask_ratio = traj_unmask_ratio

        if visible_frames is None:
            self.visible_frames = None
        else:
            # Strings are parsed as Python literals; already-parsed values are
            # accepted as they are.
            if isinstance(visible_frames, str):
                visible_frames = ast.literal_eval(visible_frames)
            if isinstance(visible_frames, (int, np.integer)):
                visible_frames = [visible_frames]
            self.visible_frames = [int(element) for element in visible_frames]

        self.mask_type = mask_type

    @staticmethod
    def _flip_random(mask, primary, fallback, count, value):
        """Set ``count`` randomly chosen entries of ``mask`` to ``value`` in place.

        Indices are drawn without replacement from ``primary``. Only when
        ``primary`` holds fewer than ``count`` indices is all of it used and
        the remainder drawn from ``fallback``.
        """
        if count <= primary.size:
            chosen = primary[np.random.choice(primary.size, size=count, replace=False)]
            mask[chosen] = value
            return
        mask[primary] = value
        remaining = count - primary.size
        chosen = fallback[np.random.choice(fallback.size, size=remaining, replace=False)]
        mask[chosen] = value

    def _balance_num_masks(self, combined_mask,
                           unmasked_object_patches_index,
                           unmasked_non_object_patches_index,
                           masked_object_patches_index,
                           tube_masked_index=None,
                           tube_unmasked_index=None):
        """Flip random patches until exactly ``self.total_masks`` are masked.

        Args:
            combined_mask: Flat 0/1 mask; modified in place.
            unmasked_object_patches_index: Indices of object patches that
                should stay visible if possible.
            unmasked_non_object_patches_index: Indices of visible background
                patches (only used when no tube indices are given).
            masked_object_patches_index: Indices of masked object patches.
            tube_masked_index: Indices masked by the tube mask, or ``None``.
            tube_unmasked_index: Indices left visible by the tube mask, or
                ``None``.

        Returns:
            ``combined_mask`` (the same array object) after balancing.

        Raises:
            ValueError: Propagated from ``np.random.choice`` when the
                candidate pools are too small to reach the target count.
        """
        # Coerce to integer index arrays: an empty array built from a Python
        # list/set defaults to float64 and cannot be used for indexing.
        unmasked_object = np.asarray(unmasked_object_patches_index, dtype=np.intp)
        unmasked_non_object = np.asarray(unmasked_non_object_patches_index, dtype=np.intp)
        masked_object = np.asarray(masked_object_patches_index, dtype=np.intp)

        current_masks = np.sum(combined_mask)
        num_diff = int(np.abs(self.total_masks - current_masks))
        too_many = current_masks > self.total_masks
        too_few = current_masks < self.total_masks

        if tube_masked_index is None or tube_unmasked_index is None:
            if too_many:
                # Without tube information only masked object patches are
                # candidates for unmasking.
                chosen = masked_object[
                    np.random.choice(masked_object.size, size=num_diff, replace=False)
                ]
                combined_mask[chosen] = 0.
            elif too_few:
                # Prefer masking background; touch visible object patches
                # only when the background runs out.
                self._flip_random(combined_mask, unmasked_non_object,
                                  unmasked_object, num_diff, 1.)
        else:
            # NOTE: union1d/setdiff1d return sorted integer arrays (the
            # original used unordered Python sets), so the candidate pools are
            # the same but their order is now deterministic.
            object_index = np.union1d(masked_object, unmasked_object)
            if too_many:
                # Prefer unmasking tube-masked background patches, then
                # masked object patches.
                tube_masked_non_object = np.setdiff1d(
                    np.asarray(tube_masked_index, dtype=np.intp), object_index
                )
                self._flip_random(combined_mask, tube_masked_non_object,
                                  masked_object, num_diff, 0.)
            elif too_few:
                # Prefer masking tube-visible background patches, then
                # visible object patches.
                tube_unmasked_non_object = np.setdiff1d(
                    np.asarray(tube_unmasked_index, dtype=np.intp), object_index
                )
                self._flip_random(combined_mask, tube_unmasked_non_object,
                                  unmasked_object, num_diff, 1.)

        balanced_mask = combined_mask
        return balanced_mask

    def __repr__(self):
        """Return a short summary of the total and masked patch counts."""
        repr_str = "Mask: total patches {}, mask patches {}".format(
            self.total_patches, self.total_masks
        )
        return repr_str

    def __call__(self, traj_rois):
        """Sample a mask for one clip given its trajectory boxes.

        Args:
            traj_rois: Array of shape ``(num_tubelet, num_frame, box)`` where
                ``num_frame == 2 * self.frames`` and each box is read as
                ``(x_min, y_min, x_max, y_max)`` (presumably pixel
                coordinates - confirm against the caller).

        Returns:
            Flat float mask of length ``frames * height * width`` with exactly
            ``self.total_masks`` entries equal to ``1``.

        Raises:
            AssertionError: If the frame count does not match ``self.frames``
                or the balanced mask misses the target count.
        """
        tube_mask = self.tube_masking_generator()

        num_frame = traj_rois.shape[1]
        assert num_frame % 2 == 0 and self.frames == (num_frame // 2), \
            "traj_rois must contain exactly 2 * frames boxes per trajectory"
        num_slots = num_frame // 2
        combined_mask = np.zeros((num_slots, self.height, self.width))

        if self.visible_frames is None:
            picked_list = [np.random.randint(0, num_slots)]
        else:
            picked_list = self.visible_frames

        # Rasterise every trajectory onto the patch grid. Each temporal slot
        # covers two consecutive input frames, so their boxes are merged.
        for roi in traj_rois:
            for i in range(num_slots):
                first, second = roi[2 * i], roi[2 * i + 1]
                min_x = min((first[0], second[0]))
                max_x = max((first[2], second[2]))
                min_y = min((first[1], second[1]))
                max_y = max((first[3], second[3]))

                # Clamp to the actual grid size rather than a fixed 14.
                patch_index_x_min = max(int(np.floor(min_x / self.patch_size)), 0)
                patch_index_x_max = min(int(np.ceil(max_x / self.patch_size)) + 1, self.width)
                patch_index_y_min = max(int(np.floor(min_y / self.patch_size)), 0)
                patch_index_y_max = min(int(np.ceil(max_y / self.patch_size)) + 1, self.height)

                # 2 marks object patches in picked (visible) frames,
                # 1 marks object patches in all other frames.
                label = 2. if i in picked_list else 1.
                combined_mask[i][patch_index_y_min:patch_index_y_max,
                                 patch_index_x_min:patch_index_x_max] = label

        combined_mask = combined_mask.flatten()
        masked_object_patches_index = np.where(combined_mask == 1.)[0]
        unmasked_non_object_patches_index = np.where(combined_mask == 0.)[0]
        unmasked_object_patches_index = np.where(combined_mask == 2.)[0]
        combined_mask[unmasked_object_patches_index] = 0.

        tube_masked_index = np.where(tube_mask == 1.)[0]
        tube_unmasked_index = np.where(tube_mask == 0.)[0]

        # A patch is masked if either the object mask or the tube mask says so.
        combined_mask = np.bitwise_or(
            combined_mask.astype(bool), tube_mask.astype(bool)
        ).astype(np.float32)

        if self.mask_type == "tube+picked_frame_visible":
            # Object patches of the picked frames stay visible even where the
            # tube mask covered them.
            combined_mask[unmasked_object_patches_index] = 0.
        elif self.mask_type == "tube+traj_mask":
            # Mask every trajectory patch except a random fraction of them.
            traj_patches_index = np.union1d(
                masked_object_patches_index, unmasked_object_patches_index
            )
            num_unmasked = int(self.traj_unmask_ratio * len(traj_patches_index))
            unmasked_traj_patches_index = traj_patches_index[
                np.random.choice(traj_patches_index.size, size=num_unmasked, replace=False)
            ]
            combined_mask[traj_patches_index] = 1.
            combined_mask[unmasked_traj_patches_index] = 0.

            unmasked_object_patches_index = unmasked_traj_patches_index
            masked_object_patches_index = np.setdiff1d(
                traj_patches_index, unmasked_traj_patches_index
            )

        mask = self._balance_num_masks(combined_mask,
                                       unmasked_object_patches_index,
                                       unmasked_non_object_patches_index,
                                       masked_object_patches_index,
                                       tube_masked_index,
                                       tube_unmasked_index)

        assert np.sum(mask) == self.total_masks, \
            "balanced mask does not contain the expected number of masked patches"
        return mask
|
|
|
|
|
|