"""Utility functions to manipulate feature extractors.""" | |
# Copyright (C) 2022-2024 Intel Corporation | |
# SPDX-License-Identifier: Apache-2.0 | |
import torch | |
from torch.fx.graph_module import GraphModule | |
from .timm import TimmFeatureExtractor | |
def dryrun_find_featuremap_dims(
    feature_extractor: TimmFeatureExtractor | GraphModule,
    input_size: tuple[int, int],
    layers: list[str],
) -> dict[str, dict[str, int | tuple[int, int]]]:
    """Dry run a blank image of `input_size` size to get the featuremap tensors' dimensions (num_features, resolution).

    A single zero-filled image of shape ``(1, 3, *input_size)`` is passed through ``feature_extractor``.
    The channel count and trailing (spatial) dimensions of each requested layer's output are read back.

    The dry run happens under ``torch.no_grad()`` with every submodule temporarily in eval mode. So it
    builds no autograd graph and does not update normalization running statistics. Each submodule's
    original ``training`` flag is restored afterwards, even if the forward pass raises.

    Args:
        feature_extractor: Model that, called on a batch of 3-channel images, returns a mapping
            ``layer name -> feature tensor``.
        input_size: Spatial size ``(height, width)`` of the dummy input image.
        layers: Names of the layers to report. Each must be a key of the extractor's output mapping.

    Returns:
        dict[str, dict[str, int | tuple[int, int]]]: mapping of `layer -> dimensions dict`.
            Each `dimensions dict` has two keys: `num_features` (int, ``shape[1]`` of the layer output)
            and `resolution` (tuple of the remaining dimensions, ``shape[2:]``).

    Raises:
        Whatever the output mapping raises for a missing layer name (e.g. ``KeyError`` for a ``dict``).
    """
    # Build the input on the model's device and, for floating-point weights, in the model's dtype.
    # This lets half/bfloat16 models work. Fall back to CPU/float32 for a model without parameters,
    # where ``next(parameters())`` would otherwise raise StopIteration.
    first_param = next(iter(feature_extractor.parameters()), None)
    if first_param is None:
        device, dtype = torch.device("cpu"), torch.float32
    else:
        device = first_param.device
        dtype = first_param.dtype if first_param.is_floating_point() else torch.float32
    # Zeros rather than ``torch.empty``: uninitialised memory may contain NaN/Inf.
    dryrun_input = torch.zeros(1, 3, *input_size, device=device, dtype=dtype)

    # Remember every submodule's mode so mixed train/eval setups (e.g. a frozen backbone) are restored exactly.
    training_modes = {module: module.training for module in feature_extractor.modules()}
    try:
        # Eval mode keeps BatchNorm running stats untouched and avoids the batch-size-1 check in training mode.
        feature_extractor.eval()
        with torch.no_grad():
            dryrun_features = feature_extractor(dryrun_input)
    finally:
        for module, was_training in training_modes.items():
            # Set the flag directly: ``Module.train`` recurses into children and would clobber their own flags.
            module.training = was_training

    return {
        layer: {
            "num_features": dryrun_features[layer].shape[1],
            "resolution": tuple(dryrun_features[layer].shape[2:]),
        }
        for layer in layers
    }