File size: 2,725 Bytes
3de7bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""KMeans clustering algorithm implementation using PyTorch."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import torch


class KMeans:
    """K-means clustering over PyTorch tensors.

    Centroids are seeded from randomly drawn input rows and then refined for a
    fixed number of assignment/update rounds; there is no convergence check.

    Args:
        n_clusters (int): The number of clusters to create.
        max_iter (int, optional): The maximum number of iterations to run the algorithm. Defaults to 10.

    Attributes:
        cluster_centers_ (torch.Tensor): Cluster centers of shape (n_clusters, n_features).
            Only available after ``fit`` has been called.
        labels_ (torch.Tensor): Cluster index of every sample passed to ``fit``, shape (batch_size,).
            Only available after ``fit`` has been called.
    """

    def __init__(self, n_clusters: int, max_iter: int = 10) -> None:
        self.n_clusters = n_clusters
        self.max_iter = max_iter

    def _closest_centroid(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return, for each row of ``inputs``, the index of the nearest current centroid."""
        distances = torch.cdist(inputs, self.cluster_centers_)
        return torch.argmin(distances, dim=1)

    def fit(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Fit the K-means algorithm to the input data.

        Args:
            inputs (torch.Tensor): Input data of shape (batch_size, n_features).

        Returns:
            tuple: A tuple containing the labels of the input data with respect to the identified clusters
            and the cluster centers themselves. The labels have a shape of (batch_size,) and the
            cluster centers have a shape of (n_clusters, n_features).

        Raises:
            ValueError: If the number of clusters is less than or equal to 0.
        """
        if self.n_clusters <= 0:
            msg = f"n_clusters must be a positive integer, got {self.n_clusters}."
            raise ValueError(msg)

        batch_size, _ = inputs.shape

        # Seed the centroids with randomly drawn data points. Indices are sampled
        # with replacement, so two centroids may start on the same point.
        # Indexing with an index tensor yields a copy, so the in-place centroid
        # updates below never write into ``inputs``.
        centroid_indices = torch.randint(0, batch_size, (self.n_clusters,))
        self.cluster_centers_ = inputs[centroid_indices]

        if self.max_iter <= 0:
            # No refinement round will run; still expose labels for the seed
            # centroids so the return statement below has something to return.
            self.labels_ = self._closest_centroid(inputs)

        for _ in range(self.max_iter):
            # Assignment step: attach every data point to its closest centroid.
            self.labels_ = self._closest_centroid(inputs)

            # Update step: move each centroid to the mean of its assigned points.
            for j in range(self.n_clusters):
                mask = self.labels_ == j
                # A cluster that attracted no points keeps its previous centroid;
                # taking the mean of an empty selection would produce NaNs.
                if mask.any():
                    self.cluster_centers_[j] = inputs[mask].mean(dim=0)

        # NOTE: the labels stem from the last assignment step, i.e. they were
        # computed before the final centroid update of the loop.
        return self.labels_, self.cluster_centers_

    def predict(self, inputs: torch.Tensor) -> torch.Tensor:
        """Predict the labels of input data based on the fitted model.

        Args:
            inputs (torch.Tensor): Input data of shape (batch_size, n_features).

        Returns:
            torch.Tensor: The predicted labels of the input data with respect to the identified clusters.

        Raises:
            AttributeError: If the KMeans object has not been fitted to input data.
        """
        if not hasattr(self, "cluster_centers_"):
            msg = "KMeans has not been fitted yet; call fit() before predict()."
            raise AttributeError(msg)
        return self._closest_centroid(inputs)