File size: 4,493 Bytes
3de7bf6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
"""k-Center Greedy Method.
Returns points that minimize the maximum distance of any point to a center.
- https://arxiv.org/abs/1708.00489
"""
# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import torch
from rich.progress import track
from torch.nn import functional as F # noqa: N812
from anomalib.models.components.dimensionality_reduction import SparseRandomProjection
class KCenterGreedy:
"""Implements k-center-greedy method.
Args:
embedding (torch.Tensor): Embedding vector extracted from a CNN
sampling_ratio (float): Ratio to choose coreset size from the embedding size.
Example:
>>> embedding.shape
torch.Size([219520, 1536])
>>> sampler = KCenterGreedy(embedding=embedding)
>>> sampled_idxs = sampler.select_coreset_idxs()
>>> coreset = embedding[sampled_idxs]
>>> coreset.shape
torch.Size([219, 1536])
"""
def __init__(self, embedding: torch.Tensor, sampling_ratio: float) -> None:
self.embedding = embedding
self.coreset_size = int(embedding.shape[0] * sampling_ratio)
self.model = SparseRandomProjection(eps=0.9)
self.features: torch.Tensor
self.min_distances: torch.Tensor = None
self.n_observations = self.embedding.shape[0]
def reset_distances(self) -> None:
"""Reset minimum distances."""
self.min_distances = None
def update_distances(self, cluster_centers: list[int]) -> None:
"""Update min distances given cluster centers.
Args:
cluster_centers (list[int]): indices of cluster centers
"""
if cluster_centers:
centers = self.features[cluster_centers]
distance = F.pairwise_distance(self.features, centers, p=2).reshape(-1, 1)
if self.min_distances is None:
self.min_distances = distance
else:
self.min_distances = torch.minimum(self.min_distances, distance)
def get_new_idx(self) -> int:
"""Get index value of a sample.
Based on minimum distance of the cluster
Returns:
int: Sample index
"""
if isinstance(self.min_distances, torch.Tensor):
idx = int(torch.argmax(self.min_distances).item())
else:
msg = f"self.min_distances must be of type Tensor. Got {type(self.min_distances)}"
raise TypeError(msg)
return idx
def select_coreset_idxs(self, selected_idxs: list[int] | None = None) -> list[int]:
"""Greedily form a coreset to minimize the maximum distance of a cluster.
Args:
selected_idxs: index of samples already selected. Defaults to an empty set.
Returns:
indices of samples selected to minimize distance to cluster centers
"""
if selected_idxs is None:
selected_idxs = []
if self.embedding.ndim == 2:
self.model.fit(self.embedding)
self.features = self.model.transform(self.embedding)
self.reset_distances()
else:
self.features = self.embedding.reshape(self.embedding.shape[0], -1)
self.update_distances(cluster_centers=selected_idxs)
selected_coreset_idxs: list[int] = []
idx = int(torch.randint(high=self.n_observations, size=(1,)).item())
for _ in track(range(self.coreset_size), description="Selecting Coreset Indices."):
self.update_distances(cluster_centers=[idx])
idx = self.get_new_idx()
if idx in selected_idxs:
msg = "New indices should not be in selected indices."
raise ValueError(msg)
self.min_distances[idx] = 0
selected_coreset_idxs.append(idx)
return selected_coreset_idxs
def sample_coreset(self, selected_idxs: list[int] | None = None) -> torch.Tensor:
"""Select coreset from the embedding.
Args:
selected_idxs: index of samples already selected. Defaults to an empty set.
Returns:
Tensor: Output coreset
Example:
>>> embedding.shape
torch.Size([219520, 1536])
>>> sampler = KCenterGreedy(...)
>>> coreset = sampler.sample_coreset()
>>> coreset.shape
torch.Size([219, 1536])
"""
idxs = self.select_coreset_idxs(selected_idxs)
return self.embedding[idxs]
|