|
"""k-Center Greedy Method. |
|
|
|
Returns points that minimize the maximum distance of any point to its nearest center.
|
- https://arxiv.org/abs/1708.00489 |
|
""" |
|
|
|
|
|
|
|
|
|
import torch |
|
from rich.progress import track |
|
from torch.nn import functional as F |
|
|
|
from anomalib.models.components.dimensionality_reduction import SparseRandomProjection |
|
|
|
|
|
class KCenterGreedy: |
|
"""Implements k-center-greedy method. |
|
|
|
Args: |
|
embedding (torch.Tensor): Embedding vector extracted from a CNN |
|
sampling_ratio (float): Ratio to choose coreset size from the embedding size. |
|
|
|
Example: |
|
>>> embedding.shape |
|
torch.Size([219520, 1536]) |
|
>>> sampler = KCenterGreedy(embedding=embedding) |
|
>>> sampled_idxs = sampler.select_coreset_idxs() |
|
>>> coreset = embedding[sampled_idxs] |
|
>>> coreset.shape |
|
torch.Size([219, 1536]) |
|
""" |
|
|
|
def __init__(self, embedding: torch.Tensor, sampling_ratio: float) -> None: |
|
self.embedding = embedding |
|
self.coreset_size = int(embedding.shape[0] * sampling_ratio) |
|
self.model = SparseRandomProjection(eps=0.9) |
|
|
|
self.features: torch.Tensor |
|
self.min_distances: torch.Tensor = None |
|
self.n_observations = self.embedding.shape[0] |
|
|
|
def reset_distances(self) -> None: |
|
"""Reset minimum distances.""" |
|
self.min_distances = None |
|
|
|
def update_distances(self, cluster_centers: list[int]) -> None: |
|
"""Update min distances given cluster centers. |
|
|
|
Args: |
|
cluster_centers (list[int]): indices of cluster centers |
|
""" |
|
if cluster_centers: |
|
centers = self.features[cluster_centers] |
|
|
|
distance = F.pairwise_distance(self.features, centers, p=2).reshape(-1, 1) |
|
|
|
if self.min_distances is None: |
|
self.min_distances = distance |
|
else: |
|
self.min_distances = torch.minimum(self.min_distances, distance) |
|
|
|
def get_new_idx(self) -> int: |
|
"""Get index value of a sample. |
|
|
|
Based on minimum distance of the cluster |
|
|
|
Returns: |
|
int: Sample index |
|
""" |
|
if isinstance(self.min_distances, torch.Tensor): |
|
idx = int(torch.argmax(self.min_distances).item()) |
|
else: |
|
msg = f"self.min_distances must be of type Tensor. Got {type(self.min_distances)}" |
|
raise TypeError(msg) |
|
|
|
return idx |
|
|
|
def select_coreset_idxs(self, selected_idxs: list[int] | None = None) -> list[int]: |
|
"""Greedily form a coreset to minimize the maximum distance of a cluster. |
|
|
|
Args: |
|
selected_idxs: index of samples already selected. Defaults to an empty set. |
|
|
|
Returns: |
|
indices of samples selected to minimize distance to cluster centers |
|
""" |
|
if selected_idxs is None: |
|
selected_idxs = [] |
|
|
|
if self.embedding.ndim == 2: |
|
self.model.fit(self.embedding) |
|
self.features = self.model.transform(self.embedding) |
|
self.reset_distances() |
|
else: |
|
self.features = self.embedding.reshape(self.embedding.shape[0], -1) |
|
self.update_distances(cluster_centers=selected_idxs) |
|
|
|
selected_coreset_idxs: list[int] = [] |
|
idx = int(torch.randint(high=self.n_observations, size=(1,)).item()) |
|
for _ in track(range(self.coreset_size), description="Selecting Coreset Indices."): |
|
self.update_distances(cluster_centers=[idx]) |
|
idx = self.get_new_idx() |
|
if idx in selected_idxs: |
|
msg = "New indices should not be in selected indices." |
|
raise ValueError(msg) |
|
self.min_distances[idx] = 0 |
|
selected_coreset_idxs.append(idx) |
|
|
|
return selected_coreset_idxs |
|
|
|
def sample_coreset(self, selected_idxs: list[int] | None = None) -> torch.Tensor: |
|
"""Select coreset from the embedding. |
|
|
|
Args: |
|
selected_idxs: index of samples already selected. Defaults to an empty set. |
|
|
|
Returns: |
|
Tensor: Output coreset |
|
|
|
Example: |
|
>>> embedding.shape |
|
torch.Size([219520, 1536]) |
|
>>> sampler = KCenterGreedy(...) |
|
>>> coreset = sampler.sample_coreset() |
|
>>> coreset.shape |
|
torch.Size([219, 1536]) |
|
""" |
|
idxs = self.select_coreset_idxs(selected_idxs) |
|
return self.embedding[idxs] |
|
|