"""k-Center Greedy Method.

Returns points that minimize the maximum distance of any point to a center.

- https://arxiv.org/abs/1708.00489
"""

# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import torch
from rich.progress import track
from torch.nn import functional as F  # noqa: N812

from anomalib.models.components.dimensionality_reduction import SparseRandomProjection


class KCenterGreedy:
    """Implements k-center-greedy method.

    Args:
        embedding (torch.Tensor): Embedding vector extracted from a CNN
        sampling_ratio (float): Ratio to choose coreset size from the embedding size.

    Example:
        >>> embedding.shape
        torch.Size([219520, 1536])
        >>> sampler = KCenterGreedy(embedding=embedding, sampling_ratio=0.001)
        >>> sampled_idxs = sampler.select_coreset_idxs()
        >>> coreset = embedding[sampled_idxs]
        >>> coreset.shape
        torch.Size([219, 1536])
    """

    def __init__(self, embedding: torch.Tensor, sampling_ratio: float) -> None:
        self.embedding = embedding
        # Number of indices to select: a fraction of the number of observations, truncated towards zero.
        self.coreset_size = int(embedding.shape[0] * sampling_ratio)
        self.model = SparseRandomProjection(eps=0.9)

        # ``features`` is only assigned inside ``select_coreset_idxs``.
        self.features: torch.Tensor
        # ``None`` means "no center has been taken into account yet".
        self.min_distances: torch.Tensor | None = None
        self.n_observations = self.embedding.shape[0]

    def reset_distances(self) -> None:
        """Reset minimum distances."""
        self.min_distances = None

    def update_distances(self, cluster_centers: list[int]) -> None:
        """Update min distances given cluster centers.

        Every center is folded into ``self.min_distances`` one at a time, so that after the
        call each row holds the distance from that observation to its closest center seen so
        far. An empty ``cluster_centers`` leaves the state untouched.

        Args:
            cluster_centers (list[int]): indices of cluster centers
        """
        if not cluster_centers:
            return

        centers = self.features[cluster_centers]
        # Compare against one center at a time: ``pairwise_distance`` works row-wise, so an
        # (n, d) vs. (k, d) call with k > 1 would either fail to broadcast or silently pair up
        # unrelated rows. A single (1, d) center broadcasts against every observation.
        for center in centers.split(1, dim=0):
            distance = F.pairwise_distance(self.features, center, p=2).reshape(-1, 1)
            if self.min_distances is None:
                self.min_distances = distance
            else:
                self.min_distances = torch.minimum(self.min_distances, distance)

    def get_new_idx(self) -> int:
        """Get index value of a sample.

        The sample returned is the one with the largest entry in ``self.min_distances``,
        i.e. the observation that is currently farthest from its closest center.

        Returns:
            int: Sample index

        Raises:
            TypeError: If ``self.min_distances`` is not a tensor (for example, still ``None``).
        """
        if not isinstance(self.min_distances, torch.Tensor):
            msg = f"self.min_distances must be of type Tensor. Got {type(self.min_distances)}"
            raise TypeError(msg)
        return int(torch.argmax(self.min_distances).item())

    def select_coreset_idxs(self, selected_idxs: list[int] | None = None) -> list[int]:
        """Greedily form a coreset to minimize the maximum distance of a cluster.

        A random observation seeds the search: it is used as the first center, but it is not
        itself appended to the result. ``self.coreset_size`` indices are then picked, each
        being the observation farthest from all centers considered so far.

        Note:
            For a 2-D embedding the features are obtained by fitting ``self.model`` and
            transforming the embedding, and the distances are reset first. For any other
            rank the embedding is flattened to 2-D and existing ``min_distances`` are kept,
            so state from an earlier call carries over.

        Args:
            selected_idxs: index of samples already selected. Defaults to an empty set.

        Returns:
            indices of samples selected to minimize distance to cluster centers

        Raises:
            ValueError: If a newly chosen index is already in ``selected_idxs``.
        """
        if selected_idxs is None:
            selected_idxs = []

        if self.embedding.ndim == 2:
            self.model.fit(self.embedding)
            self.features = self.model.transform(self.embedding)
            self.reset_distances()
        else:
            self.features = self.embedding.reshape(self.embedding.shape[0], -1)
        # Seed the distances with the samples that were already selected, whichever way the
        # features were built. This is a no-op when nothing was preselected.
        self.update_distances(cluster_centers=selected_idxs)

        # Built once so the membership test inside the loop is O(1).
        preselected = {int(i) for i in selected_idxs}

        selected_coreset_idxs: list[int] = []
        idx = int(torch.randint(high=self.n_observations, size=(1,)).item())
        for _ in track(range(self.coreset_size), description="Selecting Coreset Indices."):
            self.update_distances(cluster_centers=[idx])
            idx = self.get_new_idx()
            if idx in preselected:
                msg = "New indices should not be in selected indices."
                raise ValueError(msg)
            # Zero the chosen point so it can only be picked again if every distance is 0.
            self.min_distances[idx] = 0
            selected_coreset_idxs.append(idx)

        return selected_coreset_idxs

    def sample_coreset(self, selected_idxs: list[int] | None = None) -> torch.Tensor:
        """Select coreset from the embedding.

        Args:
            selected_idxs: index of samples already selected. Defaults to an empty set.

        Returns:
            Tensor: Output coreset

        Example:
            >>> embedding.shape
            torch.Size([219520, 1536])
            >>> sampler = KCenterGreedy(...)
            >>> coreset = sampler.sample_coreset()
            >>> coreset.shape
            torch.Size([219, 1536])
        """
        idxs = self.select_coreset_idxs(selected_idxs)
        return self.embedding[idxs]