File size: 1,082 Bytes
3de7bf6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
"""Utility functions to manipulate feature extractors."""
# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import torch
from torch.fx.graph_module import GraphModule
from .timm import TimmFeatureExtractor
def dryrun_find_featuremap_dims(
feature_extractor: TimmFeatureExtractor | GraphModule,
input_size: tuple[int, int],
layers: list[str],
) -> dict[str, dict[str, int | tuple[int, int]]]:
"""Dry run an empty image of `input_size` size to get the featuremap tensors' dimensions (num_features, resolution).
Returns:
tuple[int, int]: maping of `layer -> dimensions dict`
Each `dimension dict` has two keys: `num_features` (int) and `resolution`(tuple[int, int]).
"""
device = next(feature_extractor.parameters()).device
dryrun_input = torch.empty(1, 3, *input_size).to(device)
dryrun_features = feature_extractor(dryrun_input)
return {
layer: {"num_features": dryrun_features[layer].shape[1], "resolution": dryrun_features[layer].shape[2:]}
for layer in layers
}
|