zhiqing0205
Add core libraries: anomalib, dinov2, open_clip_local
3de7bf6
"""KMeans clustering algorithm implementation using PyTorch."""
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import torch
class KMeans:
"""Initialize the KMeans object.
Args:
n_clusters (int): The number of clusters to create.
max_iter (int, optional)): The maximum number of iterations to run the algorithm. Defaults to 10.
"""
def __init__(self, n_clusters: int, max_iter: int = 10) -> None:
self.n_clusters = n_clusters
self.max_iter = max_iter
def fit(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Fit the K-means algorithm to the input data.
Args:
inputs (torch.Tensor): Input data of shape (batch_size, n_features).
Returns:
tuple: A tuple containing the labels of the input data with respect to the identified clusters
and the cluster centers themselves. The labels have a shape of (batch_size,) and the
cluster centers have a shape of (n_clusters, n_features).
Raises:
ValueError: If the number of clusters is less than or equal to 0.
"""
batch_size, _ = inputs.shape
# Initialize centroids randomly from the data points
centroid_indices = torch.randint(0, batch_size, (self.n_clusters,))
self.cluster_centers_ = inputs[centroid_indices]
# Run the k-means algorithm for max_iter iterations
for _ in range(self.max_iter):
# Compute the distance between each data point and each centroid
distances = torch.cdist(inputs, self.cluster_centers_)
# Assign each data point to the closest centroid
self.labels_ = torch.argmin(distances, dim=1)
# Update the centroids to be the mean of the data points assigned to them
for j in range(self.n_clusters):
mask = self.labels_ == j
if mask.any():
self.cluster_centers_[j] = inputs[mask].mean(dim=0)
# this line returns labels and centoids of the results
return self.labels_, self.cluster_centers_
def predict(self, inputs: torch.Tensor) -> torch.Tensor:
"""Predict the labels of input data based on the fitted model.
Args:
inputs (torch.Tensor): Input data of shape (batch_size, n_features).
Returns:
torch.Tensor: The predicted labels of the input data with respect to the identified clusters.
Raises:
AttributeError: If the KMeans object has not been fitted to input data.
"""
distances = torch.cdist(inputs, self.cluster_centers_)
return torch.argmin(distances, dim=1)