|
"""Connected component labeling.""" |
|
|
|
|
|
|
|
|
|
import cv2 |
|
import numpy as np |
|
import torch |
|
from kornia.contrib import connected_components |
|
|
|
|
|
def connected_components_gpu(image: torch.Tensor, num_iterations: int = 1000) -> torch.Tensor: |
|
"""Perform connected component labeling on GPU and remap the labels from 0 to N. |
|
|
|
Args: |
|
image (torch.Tensor): Binary input image from which we want to extract connected components (Bx1xHxW) |
|
num_iterations (int): Number of iterations used in the connected component computation. |
|
|
|
Returns: |
|
Tensor: Components labeled from 0 to N. |
|
""" |
|
components = connected_components(image, num_iterations=num_iterations) |
|
|
|
|
|
labels = components.unique() |
|
for new_label, old_label in enumerate(labels): |
|
components[components == old_label] = new_label |
|
|
|
return components.int() |
|
|
|
|
|
def connected_components_cpu(image: torch.Tensor) -> torch.Tensor: |
|
"""Perform connected component labeling on CPU. |
|
|
|
Args: |
|
image (torch.Tensor): Binary input data from which we want to extract connected components (Bx1xHxW) |
|
|
|
Returns: |
|
Tensor: Components labeled from 0 to N. |
|
""" |
|
components = torch.zeros_like(image) |
|
label_idx = 1 |
|
for i, msk in enumerate(image): |
|
mask = msk.squeeze().cpu().numpy().astype(np.uint8) |
|
_, comps = cv2.connectedComponents(mask) |
|
|
|
for label in np.unique(comps)[1:]: |
|
components[i, 0, ...][np.where(comps == label)] = label_idx |
|
label_idx += 1 |
|
return components.int() |
|
|