"""SSPCAB: Self-Supervised Predictive Convolutional Attention Block for reconstruction-based models.
Paper https://arxiv.org/abs/2111.09099
"""
# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import torch
from torch import nn
from torch.nn import functional as F # noqa: N812


class AttentionModule(nn.Module):
    """Squeeze-and-excitation block that acts as the attention module in SSPCAB.

    Args:
        in_channels (int): Number of input channels.
        reduction_ratio (int): Reduction ratio of the attention module.
            Defaults to ``8``.
    """

    def __init__(self, in_channels: int, reduction_ratio: int = 8) -> None:
        super().__init__()

        out_channels = in_channels // reduction_ratio
        # squeeze to a reduced channel dimension, then excite back to the original
        self.fc1 = nn.Linear(in_channels, out_channels)
        self.fc2 = nn.Linear(out_channels, in_channels)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Forward pass through the attention module."""
        # reduce the feature map to a 1d vector per channel through global average pooling
        avg_pooled = inputs.mean(dim=(2, 3))

        # squeeze and excite
        act = self.fc1(avg_pooled)
        act = F.relu(act)
        act = self.fc2(act)
        act = torch.sigmoid(act)  # torch.sigmoid; F.sigmoid is deprecated

        # multiply the per-channel attention weights with the input
        return inputs * act.view(act.shape[0], act.shape[1], 1, 1)
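

# --- Example (not part of the original module) ------------------------------
# A minimal sketch showing that ``AttentionModule`` is a drop-in layer: it
# recalibrates channels and preserves the input shape. The channel count and
# spatial size below are illustrative assumptions, not values from the paper.
def _attention_module_example() -> None:
    """Run AttentionModule on a dummy feature map and verify the output shape."""
    module = AttentionModule(in_channels=64, reduction_ratio=8)
    features = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)
    recalibrated = module(features)
    assert recalibrated.shape == features.shape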


class SSPCAB(nn.Module):
    """SSPCAB block.

    Args:
        in_channels (int): Number of input channels.
        kernel_size (int): Size of the receptive fields of the masked convolution kernel.
            Defaults to ``1``.
        dilation (int): Dilation factor of the masked convolution kernel.
            Defaults to ``1``.
        reduction_ratio (int): Reduction ratio of the attention module.
            Defaults to ``8``.
    """

    def __init__(self, in_channels: int, kernel_size: int = 1, dilation: int = 1, reduction_ratio: int = 8) -> None:
        super().__init__()

        # padding and cropping offsets that place the four sub-kernels at the
        # corners of the dilated receptive field, leaving the centre masked
        self.pad = kernel_size + dilation
        self.crop = kernel_size + 2 * dilation + 1

        # the four sub-kernels of the masked convolution, one per corner
        self.masked_conv1 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
        self.masked_conv2 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
        self.masked_conv3 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
        self.masked_conv4 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)

        self.attention_module = AttentionModule(in_channels=in_channels, reduction_ratio=reduction_ratio)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Forward pass through the SSPCAB block."""
        # compute the masked convolution: pad the input, then apply each
        # sub-kernel to its corner of the receptive field and sum the results
        padded = F.pad(inputs, (self.pad,) * 4)
        masked_out = torch.zeros_like(inputs)
        masked_out += self.masked_conv1(padded[..., : -self.crop, : -self.crop])
        masked_out += self.masked_conv2(padded[..., : -self.crop, self.crop :])
        masked_out += self.masked_conv3(padded[..., self.crop :, : -self.crop])
        masked_out += self.masked_conv4(padded[..., self.crop :, self.crop :])

        # apply the channel attention module to the masked-convolution output
        return self.attention_module(masked_out)
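

# --- Usage sketch (not part of the original module) --------------------------
# The SSPCAB paper trains the block with a self-supervised reconstruction
# loss: the mean squared error between the block's output and its input, so
# the masked convolutions learn to predict the masked centre region. The
# tensor sizes below are illustrative assumptions.
if __name__ == "__main__":
    sspcab = SSPCAB(in_channels=128, kernel_size=1, dilation=1)
    feature_map = torch.randn(4, 128, 16, 16)
    reconstruction = sspcab(feature_map)
    assert reconstruction.shape == feature_map.shape  # drop-in, shape-preserving

    # self-supervised SSPCAB loss: MSE between the block's output and input
    sspcab_loss = F.mse_loss(reconstruction, feature_map)
    print(f"SSPCAB reconstruction loss: {sspcab_loss.item():.4f}")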