|
"""Pytorch implementation of Gaussian Mixture Model.""" |
|
|
|
|
|
|
|
|
|
import logging |
|
|
|
import torch |
|
from torch.distributions.multivariate_normal import MultivariateNormal |
|
from torch.nn.functional import one_hot |
|
|
|
from anomalib.models.components.base import DynamicBufferMixin |
|
from anomalib.models.components.cluster.kmeans import KMeans |
|
|
|
# Module-level logger following the stdlib convention of naming it after the module.
logger = logging.getLogger(__name__)
|
|
|
|
|
class GaussianMixture(DynamicBufferMixin):
    """Gaussian Mixture Model fitted with the EM algorithm.

    Args:
        n_components (int): Number of components.
        n_iter (int): Maximum number of iterations to perform.
            Defaults to ``100``.
        tol (float): Convergence threshold.
            Defaults to ``1e-3``.

    Example:
        The following examples shows how to fit a Gaussian Mixture Model to some data and get the cluster means and
        predicted labels and log-likelihood scores of the data.

        .. code-block:: python

            >>> import torch
            >>> from anomalib.models.components.cluster import GaussianMixture
            >>> model = GaussianMixture(n_components=2)
            >>> data = torch.tensor(
            ...     [
            ...             [2, 1], [2, 2], [2, 3],
            ...             [7, 5], [8, 5], [9, 5],
            ...     ]
            ... ).float()
            >>> model.fit(data)
            >>> model.means  # get the means of the gaussians
            tensor([[8., 5.],
                    [2., 2.]])
            >>> model.predict(data)  # get the predicted cluster label of each sample
            tensor([1, 1, 1, 0, 0, 0])
            >>> model.score_samples(data)  # get the log-likelihood score of each sample
            tensor([3.8295, 4.5795, 3.8295, 3.8295, 4.5795, 3.8295])
    """

    def __init__(self, n_components: int, n_iter: int = 100, tol: float = 1e-3) -> None:
        super().__init__()
        self.n_components = n_components
        self.tol = tol
        self.n_iter = n_iter

        # Fitted parameters are registered as buffers so they follow the module
        # across devices and are captured in the state dict.
        self.register_buffer("means", torch.Tensor())
        self.register_buffer("covariances", torch.Tensor())
        self.register_buffer("weights", torch.Tensor())

        # Type annotations for the dynamically registered buffers.
        self.means: torch.Tensor
        self.covariances: torch.Tensor
        self.weights: torch.Tensor

    def fit(self, data: torch.Tensor) -> None:
        """Fit the model to the data.

        Args:
            data (Tensor): Data to fit the model to. Tensor of shape (n_samples, n_features).
        """
        self._initialize_parameters_kmeans(data)

        # Start from -inf so the first iteration can never trigger a spurious
        # convergence check (a zero initializer converges immediately whenever
        # the initial mean log-likelihood happens to lie within ``tol`` of 0).
        log_likelihood_old = float("-inf")
        converged = False
        for _ in range(self.n_iter):
            log_likelihood_new, resp = self._e_step(data)
            self._m_step(data, resp)

            if torch.abs(log_likelihood_new - log_likelihood_old) < self.tol:
                converged = True
                break
            log_likelihood_old = log_likelihood_new

        if not converged:
            # Lazy %-args: the message is only formatted if the record is emitted.
            logger.warning(
                "GMM did not converge after %d iterations. Consider increasing the number of iterations.",
                self.n_iter,
            )

    def _initialize_parameters_kmeans(self, data: torch.Tensor) -> None:
        """Initialize parameters with K-means.

        K-means cluster assignments are converted to hard (one-hot)
        responsibilities and fed through one M-step.

        Args:
            data (Tensor): Data to fit the model to. Tensor of shape (n_samples, n_features).
        """
        labels, _ = KMeans(n_clusters=self.n_components).fit(data)
        resp = one_hot(labels, num_classes=self.n_components).float()
        self._m_step(data, resp)

    def _e_step(self, data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Perform the E-step to estimate the responsibilities of the gaussians.

        Args:
            data (Tensor): Data to fit the model to. Tensor of shape (n_samples, n_features).

        Returns:
            Tensor: Mean log-likelihood of the data given the gaussians (scalar).
            Tensor: Tensor of shape (n_samples, n_components) containing the responsibilities.
        """
        weighted_log_prob = self._estimate_weighted_log_prob(data)
        # Compute the per-sample normalizer once and reuse it for both outputs.
        log_prob_norm = torch.logsumexp(weighted_log_prob, dim=1, keepdim=True)
        log_resp = weighted_log_prob - log_prob_norm
        return torch.mean(log_prob_norm), torch.exp(log_resp)

    def _m_step(self, data: torch.Tensor, resp: torch.Tensor) -> None:
        """Perform the M-step to update the parameters of the gaussians.

        Updates ``weights``, ``means`` and ``covariances`` in place from the
        responsibility-weighted sufficient statistics.

        Args:
            data (Tensor): Data to fit the model to. Tensor of shape (n_samples, n_features).
            resp (Tensor): Tensor of shape (n_samples, n_components) containing the responsibilities.
        """
        cluster_counts = resp.sum(dim=0)  # effective number of samples per component
        self.weights = resp.mean(dim=0)
        self.means = (resp.T @ data) / cluster_counts[:, None]

        # (n_components, n_samples, n_features) deviations from each component mean.
        diff = data.unsqueeze(0) - self.means.unsqueeze(1)
        weighted_diff = diff * resp.T.unsqueeze(-1)
        covariances = torch.bmm(weighted_diff.transpose(-2, -1), diff) / cluster_counts.view(-1, 1, 1)

        # Regularize the diagonal to keep the covariances positive-definite.
        self.covariances = covariances + torch.eye(data.shape[1], device=data.device) * 1e-6

    def _estimate_weighted_log_prob(self, data: torch.Tensor) -> torch.Tensor:
        """Estimate the log probability of the data given the gaussian parameters.

        Args:
            data (Tensor): Data to fit the model to. Tensor of shape (n_samples, n_features).

        Returns:
            Tensor: Tensor of shape (n_samples, n_components) containing the log-probabilities of each sample.
        """
        log_prob = torch.stack(
            [
                MultivariateNormal(self.means[comp], self.covariances[comp]).log_prob(data)
                for comp in range(self.n_components)
            ],
            dim=1,
        )
        # Add the log mixture weights to obtain the weighted log-probabilities.
        return log_prob + torch.log(self.weights)

    def score_samples(self, data: torch.Tensor) -> torch.Tensor:
        """Assign a likelihood score to each sample in the data.

        Args:
            data (Tensor): Samples to assign scores to. Tensor of shape (n_samples, n_features).

        Returns:
            Tensor: Tensor of shape (n_samples,) containing the log-likelihood score of each sample.
        """
        return torch.logsumexp(self._estimate_weighted_log_prob(data), dim=1)

    def predict(self, data: torch.Tensor) -> torch.Tensor:
        """Predict the cluster labels of the data.

        Args:
            data (Tensor): Samples to assign to clusters. Tensor of shape (n_samples, n_features).

        Returns:
            Tensor: Tensor of shape (n_samples,) containing the predicted cluster label of each sample.
        """
        _, resp = self._e_step(data)
        return torch.argmax(resp, dim=1)
|
|