|
"""KMeans clustering algorithm implementation using PyTorch.""" |
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
class KMeans:
    """K-means clustering of 2-D PyTorch tensors (Lloyd's algorithm).

    Args:
        n_clusters (int): The number of clusters to create.
        max_iter (int, optional): The maximum number of centroid-update
            iterations to run. Defaults to 10.

    Attributes:
        cluster_centers_ (torch.Tensor): Set by ``fit``. Cluster centers of
            shape (n_clusters, n_features).
        labels_ (torch.Tensor): Set by ``fit``. Index of the nearest cluster
            center for every fitted sample, shape (batch_size,).
    """

    def __init__(self, n_clusters: int, max_iter: int = 10) -> None:
        """Store the hyper-parameters; no computation happens until ``fit``."""
        self.n_clusters = n_clusters
        self.max_iter = max_iter

    def _nearest_center(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return, for each row of ``inputs``, the index of the closest center."""
        distances = torch.cdist(inputs, self.cluster_centers_)
        return torch.argmin(distances, dim=1)

    def fit(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Fit the K-means algorithm to the input data.

        Initial centers are distinct randomly chosen samples (when there are
        at least ``n_clusters`` samples). The algorithm then alternates
        between moving each center to the mean of its assigned samples and
        re-assigning samples, for at most ``max_iter`` rounds or until the
        assignment stops changing. A cluster that receives no samples keeps
        its previous center.

        Args:
            inputs (torch.Tensor): Input data of shape (batch_size, n_features).

        Returns:
            tuple: A tuple containing the labels of the input data with respect to the
                identified clusters and the cluster centers themselves. The labels have
                a shape of (batch_size,) and the cluster centers have a shape of
                (n_clusters, n_features). The labels are always computed against the
                returned centers, so they agree with ``predict(inputs)``.

        Raises:
            ValueError: If the number of clusters is less than or equal to 0,
                or if ``inputs`` contains no samples.
        """
        if self.n_clusters <= 0:
            raise ValueError(
                f"n_clusters must be a positive integer, got {self.n_clusters}."
            )

        batch_size, _ = inputs.shape
        if batch_size == 0:
            raise ValueError("inputs must contain at least one sample.")

        if self.n_clusters <= batch_size:
            # Sample WITHOUT replacement so that no two clusters start on the
            # very same sample (which would leave one of them permanently empty).
            centroid_indices = torch.randperm(batch_size)[: self.n_clusters]
        else:
            # More clusters than samples: repeats are unavoidable.
            centroid_indices = torch.randint(0, batch_size, (self.n_clusters,))

        # Indexing with an index tensor copies the rows, so the in-place
        # center updates below never modify the caller's ``inputs``.
        self.cluster_centers_ = inputs[centroid_indices]

        self.labels_ = self._nearest_center(inputs)
        for _ in range(self.max_iter):
            for j in range(self.n_clusters):
                mask = self.labels_ == j
                if mask.any():
                    self.cluster_centers_[j] = inputs[mask].mean(dim=0)

            previous_labels = self.labels_
            self.labels_ = self._nearest_center(inputs)

            # Unchanged assignments mean the next update would reproduce the
            # same centers, so further iterations cannot change the result.
            if torch.equal(self.labels_, previous_labels):
                break

        return self.labels_, self.cluster_centers_

    def predict(self, inputs: torch.Tensor) -> torch.Tensor:
        """Predict the labels of input data based on the fitted model.

        Args:
            inputs (torch.Tensor): Input data of shape (batch_size, n_features).

        Returns:
            torch.Tensor: The predicted labels of the input data with respect to the
                identified clusters.

        Raises:
            AttributeError: If the KMeans object has not been fitted to input data.
        """
        if not hasattr(self, "cluster_centers_"):
            raise AttributeError(
                "This KMeans instance is not fitted yet; call fit() before predict()."
            )
        return self._nearest_center(inputs)
|
|