LogSAD / anomalib /utils /cv /connected_components.py
zhiqing0205
Add core libraries: anomalib, dinov2, open_clip_local
3de7bf6
raw
history blame
1.76 kB
"""Connected component labeling."""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import cv2
import numpy as np
import torch
from kornia.contrib import connected_components
def connected_components_gpu(image: torch.Tensor, num_iterations: int = 1000) -> torch.Tensor:
    """Perform connected component labeling on GPU and remap the labels from 0 to N.

    The labelling itself is delegated to ``kornia.contrib.connected_components``;
    the values it returns are then replaced by their rank among the distinct
    values present, so the output uses consecutive integers starting at 0.
    Because the remap is rank-based, the smallest value present in the kornia
    output always becomes 0.

    Args:
        image (torch.Tensor): Binary input image from which we want to extract connected components (Bx1xHxW)
        num_iterations (int): Number of iterations used in the connected component computation.

    Returns:
        Tensor: Components labeled from 0 to N (int32, same shape as the kornia output).
    """
    components = connected_components(image, num_iterations=num_iterations)

    # ``torch.unique`` returns the distinct labels in sorted order; the inverse
    # indices give, for every pixel, the position of its label in that sorted
    # list. That position is exactly the consecutive 0..N label we want, and it
    # is obtained in a single call instead of one masked write per label.
    _, remapped = torch.unique(components, return_inverse=True)
    return remapped.int()
def connected_components_cpu(image: torch.Tensor) -> torch.Tensor:
    """Perform connected component labeling on CPU.

    Each image in the batch is labelled independently with
    ``cv2.connectedComponents``. Labels are then shifted so that every
    component in the whole batch gets its own value: background stays 0 and
    foreground components are numbered consecutively from 1 across the batch.

    Args:
        image (torch.Tensor): Binary input data from which we want to extract connected components (Bx1xHxW)

    Returns:
        Tensor: Components labeled from 0 to N (int32, same shape and device as ``image``).
    """
    # Allocate the output as int32 up front rather than with the input dtype, so
    # that labels are not saturated or wrapped when ``image`` is bool or uint8.
    components = torch.zeros(image.shape, dtype=torch.int32, device=image.device)

    label_idx = 1  # next unused label value across the whole batch
    for i, msk in enumerate(image):
        # Drop only the channel axis; a plain ``squeeze()`` would also collapse
        # a height or width of 1 and break the shape match with ``components``.
        mask = msk.reshape(msk.shape[-2:]).cpu().numpy().astype(np.uint8)
        num_labels, comps = cv2.connectedComponents(mask)

        # OpenCV labels are 0 (background) and 1..num_labels-1 (foreground).
        # Shift every foreground label by the number of components already
        # emitted so values stay unique when the batch is viewed as a whole.
        # Selecting by value (> 0) instead of skipping the first unique label
        # keeps the component of a mask that contains no background pixel.
        comps[comps > 0] += label_idx - 1
        components[i, 0] = torch.from_numpy(comps).to(components.device)
        label_idx += num_labels - 1

    return components