VarunKodathala's picture
Upload folder using huggingface_hub
0e37bb2 verified
raw
history blame contribute delete
660 Bytes
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
def apply_masks(x, masks, concat=True):
    """
    Gather the patches selected by each mask from ``x``.

    :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
    :param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep
    :param concat: if True, concatenate the per-mask results along dim 0
        (giving shape [len(masks) * B, K, D]); if False, return them as a list
    :return: a list of tensors of shape [B, K, D] (one per mask) when ``concat``
        is False, otherwise a single concatenated tensor. With ``concat=True`` all
        masks must share the same K, and an empty ``masks`` list makes
        ``torch.cat`` raise.
    """
    feat_dim = x.size(-1)
    all_x = [
        # expand() builds a zero-copy [B, K, D] view of the indices; gather does
        # not need a contiguous index, so this avoids the per-mask allocation
        # that repeat() would make. .long() is a no-op for int64 indices and
        # makes other integer dtypes usable as gather indices.
        torch.gather(x, dim=1, index=m.long().unsqueeze(-1).expand(-1, -1, feat_dim))
        for m in masks
    ]
    if not concat:
        return all_x

    return torch.cat(all_x, dim=0)