|
"""Gaussian Kernel Density Estimation.""" |
|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
|
import torch |
|
|
|
from anomalib.models.components.base import DynamicBufferMixin |
|
|
|
|
|
class GaussianKDE(DynamicBufferMixin): |
|
"""Gaussian Kernel Density Estimation. |
|
|
|
Args: |
|
dataset (Tensor | None, optional): Dataset on which to fit the KDE model. Defaults to None. |
|
""" |
|
|
|
def __init__(self, dataset: torch.Tensor | None = None) -> None: |
|
super().__init__() |
|
|
|
if dataset is not None: |
|
self.fit(dataset) |
|
|
|
self.register_buffer("bw_transform", torch.Tensor()) |
|
self.register_buffer("dataset", torch.Tensor()) |
|
self.register_buffer("norm", torch.Tensor()) |
|
|
|
self.bw_transform = torch.Tensor() |
|
self.dataset = torch.Tensor() |
|
self.norm = torch.Tensor() |
|
|
|
def forward(self, features: torch.Tensor) -> torch.Tensor: |
|
"""Get the KDE estimates from the feature map. |
|
|
|
Args: |
|
features (torch.Tensor): Feature map extracted from the CNN |
|
|
|
Returns: KDE Estimates |
|
""" |
|
features = torch.matmul(features, self.bw_transform) |
|
|
|
estimate = torch.zeros(features.shape[0]).to(features.device) |
|
for i in range(features.shape[0]): |
|
embedding = ((self.dataset - features[i]) ** 2).sum(dim=1) |
|
embedding = torch.exp(-embedding / 2) * self.norm |
|
estimate[i] = torch.mean(embedding) |
|
|
|
return estimate |
|
|
|
def fit(self, dataset: torch.Tensor) -> None: |
|
"""Fit a KDE model to the input dataset. |
|
|
|
Args: |
|
dataset (torch.Tensor): Input dataset. |
|
|
|
Returns: |
|
None |
|
""" |
|
num_samples, dimension = dataset.shape |
|
|
|
|
|
factor = num_samples ** (-1 / (dimension + 4)) |
|
|
|
cov_mat = self.cov(dataset.T) |
|
inv_cov_mat = torch.linalg.inv(cov_mat) |
|
inv_cov = inv_cov_mat / factor**2 |
|
|
|
|
|
bw_transform = torch.linalg.cholesky(inv_cov) |
|
dataset = torch.matmul(dataset, bw_transform) |
|
|
|
|
|
norm = torch.prod(torch.diag(bw_transform)) |
|
norm *= math.pow((2 * math.pi), (-dimension / 2)) |
|
|
|
self.bw_transform = bw_transform |
|
self.dataset = dataset |
|
self.norm = norm |
|
|
|
@staticmethod |
|
def cov(tensor: torch.Tensor) -> torch.Tensor: |
|
"""Calculate the unbiased covariance matrix. |
|
|
|
Args: |
|
tensor (torch.Tensor): Input tensor from which covariance matrix is computed. |
|
|
|
Returns: |
|
Output covariance matrix. |
|
""" |
|
mean = torch.mean(tensor, dim=1) |
|
tensor -= mean[:, None] |
|
return torch.matmul(tensor, tensor.T) / (tensor.size(1) - 1) |
|
|