"""SSPCAB: Self-Supervised Predictive Convolutional Attention Block for reconstruction-based models.

Paper: https://arxiv.org/abs/2111.09099
"""

import torch
from torch import nn
from torch.nn import functional as F


class AttentionModule(nn.Module):
    """Squeeze-and-excitation block that acts as the attention module in SSPCAB.

    Args:
        in_channels (int): Number of input channels.
        reduction_ratio (int): Reduction ratio of the attention module.
    """

    def __init__(self, in_channels: int, reduction_ratio: int = 8) -> None:
        super().__init__()

        out_channels = in_channels // reduction_ratio
        self.fc1 = nn.Linear(in_channels, out_channels)
        self.fc2 = nn.Linear(out_channels, in_channels)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Forward pass through the attention module."""
        # Squeeze: global average pooling reduces each feature map to a scalar.
        avg_pooled = inputs.mean(dim=(2, 3))

        # Excite: a bottleneck MLP produces per-channel attention weights in (0, 1).
        act = self.fc1(avg_pooled)
        act = F.relu(act)
        act = self.fc2(act)
        act = torch.sigmoid(act)

        # Scale the input feature maps channel-wise by the attention weights.
        return inputs * act.view(act.shape[0], act.shape[1], 1, 1)
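
# A minimal shape check for the attention module (an illustrative sketch, not
# part of the original file; run it in a REPL): channel attention preserves
# the input's shape while reweighting its channels.
#
#     attention = AttentionModule(in_channels=64, reduction_ratio=8)
#     out = attention(torch.randn(2, 64, 16, 16))
#     assert out.shape == (2, 64, 16, 16)
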
class SSPCAB(nn.Module):
    """SSPCAB block.

    Args:
        in_channels (int): Number of input channels.
        kernel_size (int): Size of the receptive fields of the masked convolution kernel.
        dilation (int): Dilation factor of the masked convolution kernel.
        reduction_ratio (int): Reduction ratio of the attention module.
    """

    def __init__(self, in_channels: int, kernel_size: int = 1, dilation: int = 1, reduction_ratio: int = 8) -> None:
        super().__init__()

        self.pad = kernel_size + dilation
        self.crop = kernel_size + 2 * dilation + 1

        # One sub-kernel per corner of the masked convolution; the center of
        # the receptive field is never seen by any of them.
        self.masked_conv1 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
        self.masked_conv2 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
        self.masked_conv3 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
        self.masked_conv4 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)

        self.attention_module = AttentionModule(in_channels=in_channels, reduction_ratio=reduction_ratio)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Forward pass through the SSPCAB block."""
        # Compute the masked convolution: pad the input, then offset each
        # sub-kernel so it sees one corner of the dilated receptive field.
        # Summing the four responses predicts the masked center location.
        padded = F.pad(inputs, (self.pad,) * 4)
        masked_out = torch.zeros_like(inputs)
        masked_out += self.masked_conv1(padded[..., : -self.crop, : -self.crop])
        masked_out += self.masked_conv2(padded[..., : -self.crop, self.crop :])
        masked_out += self.masked_conv3(padded[..., self.crop :, : -self.crop])
        masked_out += self.masked_conv4(padded[..., self.crop :, self.crop :])

        # Apply channel attention to the reconstructed feature map.
        return self.attention_module(masked_out)
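
# Minimal usage sketch (illustrative assumptions, not part of the original
# module): SSPCAB is meant to replace a convolutional layer inside a
# reconstruction-based model, and the paper trains it with an additional
# self-supervised MSE term that asks the block to reproduce its own input,
# i.e. to predict the masked center from the surrounding context. All names
# and sizes below are hypothetical.
if __name__ == "__main__":
    block = SSPCAB(in_channels=64, kernel_size=1, dilation=1, reduction_ratio=8)
    features = torch.randn(2, 64, 32, 32)  # stand-in for intermediate features
    predicted = block(features)
    assert predicted.shape == features.shape  # the block is shape-preserving

    # Self-supervised objective from the paper: in a real model this term is
    # added to the model's reconstruction loss.
    sspcab_loss = F.mse_loss(predicted, features)
    print(sspcab_loss.item())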