"""SSPCAB: Self-Supervised Predictive Convolutional Attention Block for reconstruction-based models.
Paper https://arxiv.org/abs/2111.09099
"""

# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import torch
from torch import nn
from torch.nn import functional as F  # noqa: N812


class AttentionModule(nn.Module):
"""Squeeze and excitation block that acts as the attention module in SSPCAB.
Args:
channels (int): Number of input channels.
reduction_ratio (int): Reduction ratio of the attention module.
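
    Example:
        An illustrative forward pass (the tensor shape below is arbitrary); the output has
        the same shape as the input:

        >>> import torch
        >>> attention = AttentionModule(in_channels=32)
        >>> inputs = torch.randn(4, 32, 16, 16)
        >>> output = attention(inputs)
        >>> output.shape
        torch.Size([4, 32, 16, 16])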
"""
def __init__(self, in_channels: int, reduction_ratio: int = 8) -> None:
super().__init__()
out_channels = in_channels // reduction_ratio
self.fc1 = nn.Linear(in_channels, out_channels)
self.fc2 = nn.Linear(out_channels, in_channels)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
"""Forward pass through the attention module."""
# reduce feature map to 1d vector through global average pooling
avg_pooled = inputs.mean(dim=(2, 3))
# squeeze and excite
act = self.fc1(avg_pooled)
act = F.relu(act)
act = self.fc2(act)
act = F.sigmoid(act)
# multiply with input
return inputs * act.view(act.shape[0], act.shape[1], 1, 1)


class SSPCAB(nn.Module):
    """SSPCAB block.

    Args:
        in_channels (int): Number of input channels.
        kernel_size (int): Size of the receptive fields of the masked convolution kernel. Defaults to ``1``.
        dilation (int): Dilation factor of the masked convolution kernel. Defaults to ``1``.
        reduction_ratio (int): Reduction ratio of the attention module. Defaults to ``8``.
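
    Example:
        An illustrative forward pass (the tensor shape below is arbitrary); the masked
        convolution preserves the spatial size of the input:

        >>> import torch
        >>> sspcab = SSPCAB(in_channels=32)
        >>> inputs = torch.randn(4, 32, 64, 64)
        >>> output = sspcab(inputs)
        >>> output.shape
        torch.Size([4, 32, 64, 64])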
"""
def __init__(self, in_channels: int, kernel_size: int = 1, dilation: int = 1, reduction_ratio: int = 8) -> None:
super().__init__()
self.pad = kernel_size + dilation
self.crop = kernel_size + 2 * dilation + 1
self.masked_conv1 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
self.masked_conv2 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
self.masked_conv3 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
self.masked_conv4 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
self.attention_module = AttentionModule(in_channels=in_channels, reduction_ratio=reduction_ratio)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
"""Forward pass through the SSPCAB block."""
# compute masked convolution
padded = F.pad(inputs, (self.pad,) * 4)
masked_out = torch.zeros_like(inputs)
masked_out += self.masked_conv1(padded[..., : -self.crop, : -self.crop])
masked_out += self.masked_conv2(padded[..., : -self.crop, self.crop :])
masked_out += self.masked_conv3(padded[..., self.crop :, : -self.crop])
masked_out += self.masked_conv4(padded[..., self.crop :, self.crop :])
# apply channel attention module
return self.attention_module(masked_out)
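

if __name__ == "__main__":
    # Illustrative self-supervised training sketch, not part of the original module: the
    # paper linked in the module docstring trains SSPCAB with an MSE loss between the
    # block's output and its input, so the masked convolution learns to predict the masked
    # centre region from its surrounding context. Tensor shapes, the learning rate and the
    # number of steps below are arbitrary.
    sspcab = SSPCAB(in_channels=64)
    optimizer = torch.optim.SGD(sspcab.parameters(), lr=0.1)
    features = torch.randn(8, 64, 32, 32)  # stand-in for activations of a reconstruction model
    for _ in range(5):
        output = sspcab(features)
        loss = F.mse_loss(output, features)  # self-supervised SSPCAB reconstruction loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"final SSPCAB loss: {loss.item():.4f}")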