# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn

from src.models.utils.modules import Block, CrossAttention, CrossAttentionBlock
from src.utils.tensors import trunc_normal_


class AttentivePooler(nn.Module):
    """Attentive pooler: pools a token sequence into a fixed number of learnable
    query tokens via cross-attention, optionally preceded by self-attention
    blocks when depth > 1."""

    def __init__(
        self,
        num_queries=1,
        embed_dim=768,
        num_heads=12,
        mlp_ratio=4.0,
        depth=1,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        qkv_bias=True,
        complete_block=True,
        use_activation_checkpointing=False,
    ):
        super().__init__()
        self.use_activation_checkpointing = use_activation_checkpointing
        # Learnable query tokens that attend to the input sequence.
        self.query_tokens = nn.Parameter(torch.zeros(1, num_queries, embed_dim))

        self.complete_block = complete_block
        if complete_block:
            # Cross-attention followed by an MLP (full transformer-style block).
            self.cross_attention_block = CrossAttentionBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, norm_layer=norm_layer
            )
        else:
            # Cross-attention only, without the MLP sub-layer.
            self.cross_attention_block = CrossAttention(dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias)

        # Optional self-attention blocks applied to the input before pooling.
        self.blocks = None
        if depth > 1:
            self.blocks = nn.ModuleList(
                [
                    Block(
                        dim=embed_dim,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=False,
                        norm_layer=norm_layer,
                    )
                    for _ in range(depth - 1)
                ]
            )

        self.init_std = init_std
        trunc_normal_(self.query_tokens, std=self.init_std)
        self.apply(self._init_weights)
        self._rescale_blocks()

    def _rescale_blocks(self):
        # Scale residual-branch projection weights by 1/sqrt(2 * layer_id),
        # a depth-dependent initialization commonly used for transformer blocks.
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        layer_id = 0
        if self.blocks is not None:
            for layer_id, layer in enumerate(self.blocks):
                rescale(layer.attn.proj.weight.data, layer_id + 1)
                rescale(layer.mlp.fc2.weight.data, layer_id + 1)

        if self.complete_block:
            rescale(self.cross_attention_block.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        if self.blocks is not None:
            for blk in self.blocks:
                if self.use_activation_checkpointing:
                    x = torch.utils.checkpoint.checkpoint(blk, x, False, None, use_reentrant=False)
                else:
                    x = blk(x)
        # Broadcast the learnable queries across the batch and pool via cross-attention.
        q = self.query_tokens.repeat(len(x), 1, 1)
        q = self.cross_attention_block(q, x)
        return q


class AttentiveClassifier(nn.Module):
    """Attentive classifier: a single-query AttentivePooler followed by a linear head."""

    def __init__(
        self,
        embed_dim=768,
        num_heads=12,
        mlp_ratio=4.0,
        depth=1,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        qkv_bias=True,
        num_classes=1000,
        complete_block=True,
        use_activation_checkpointing=False,
    ):
        super().__init__()
        self.pooler = AttentivePooler(
            num_queries=1,
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            depth=depth,
            norm_layer=norm_layer,
            init_std=init_std,
            qkv_bias=qkv_bias,
            complete_block=complete_block,
            use_activation_checkpointing=use_activation_checkpointing,
        )
        self.linear = nn.Linear(embed_dim, num_classes, bias=True)

    def forward(self, x):
        # x: (batch, num_tokens, embed_dim) -> pooled (batch, embed_dim) -> class logits.
        x = self.pooler(x).squeeze(1)
        x = self.linear(x)
        return x
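

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of how AttentiveClassifier might be used to probe frozen
# encoder features. The feature shape (batch, num_tokens, embed_dim) and the
# hyperparameter values below are assumptions chosen for illustration; they are
# not prescribed by this file.
if __name__ == "__main__":
    feats = torch.randn(2, 196, 768)  # e.g. patch tokens from a ViT-B/16 encoder for 2 images
    classifier = AttentiveClassifier(embed_dim=768, num_heads=12, depth=1, num_classes=1000)
    logits = classifier(feats)  # pooled by a single learnable query, then projected to class logits
    print(logits.shape)  # expected: torch.Size([2, 1000])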