# soccer-qa-4b/src/models/attentive_pooler.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math

import torch
import torch.nn as nn

from src.models.utils.modules import Block, CrossAttention, CrossAttentionBlock
from src.utils.tensors import trunc_normal_


class AttentivePooler(nn.Module):
"""Attentive Pooler"""
def __init__(
self,
num_queries=1,
embed_dim=768,
num_heads=12,
mlp_ratio=4.0,
depth=1,
norm_layer=nn.LayerNorm,
init_std=0.02,
qkv_bias=True,
complete_block=True,
use_activation_checkpointing=False,
):
super().__init__()
self.use_activation_checkpointing = use_activation_checkpointing
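        # Learnable query tokens; each query yields one pooled output vector.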
self.query_tokens = nn.Parameter(torch.zeros(1, num_queries, embed_dim))
self.complete_block = complete_block
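        # Pool with either a full cross-attention block (cross-attention + MLP) or a bare cross-attention layer.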
if complete_block:
self.cross_attention_block = CrossAttentionBlock(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, norm_layer=norm_layer
)
else:
self.cross_attention_block = CrossAttention(dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias)
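        # Optional stack of self-attention blocks applied to the input tokens before pooling.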
self.blocks = None
if depth > 1:
self.blocks = nn.ModuleList(
[
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=False,
norm_layer=norm_layer,
)
                for _ in range(depth - 1)
]
)
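        # Truncated-normal initialization for the queries and all submodules, followed by depth-dependent rescaling.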
self.init_std = init_std
trunc_normal_(self.query_tokens, std=self.init_std)
self.apply(self._init_weights)
self._rescale_blocks()

    def _rescale_blocks(self):
        # Scale down each block's residual output projections by 1/sqrt(2 * layer_id)
        # so that deeper layers start with proportionally smaller contributions.
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        layer_id = 0
if self.blocks is not None:
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
if self.complete_block:
rescale(self.cross_attention_block.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        # Truncated-normal init for Linear/Conv2d weights, zero biases, unit LayerNorm scale.
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=self.init_std)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
trunc_normal_(m.weight, std=self.init_std)
if m.bias is not None:
nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # x: [B, N, D] token sequence -> [B, num_queries, D] pooled query embeddings.
if self.blocks is not None:
for blk in self.blocks:
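                # Optionally checkpoint each block to trade recomputation for activation memory.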
if self.use_activation_checkpointing:
x = torch.utils.checkpoint.checkpoint(blk, x, False, None, use_reentrant=False)
else:
x = blk(x)
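        # Broadcast the learned queries across the batch and cross-attend to the tokens.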
q = self.query_tokens.repeat(len(x), 1, 1)
q = self.cross_attention_block(q, x)
        return q


class AttentiveClassifier(nn.Module):
    """Attentive classifier: an AttentivePooler with a single query followed by a
    linear classification head."""

def __init__(
self,
embed_dim=768,
num_heads=12,
mlp_ratio=4.0,
depth=1,
norm_layer=nn.LayerNorm,
init_std=0.02,
qkv_bias=True,
num_classes=1000,
complete_block=True,
use_activation_checkpointing=False,
):
super().__init__()
self.pooler = AttentivePooler(
num_queries=1,
embed_dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
depth=depth,
norm_layer=norm_layer,
init_std=init_std,
qkv_bias=qkv_bias,
complete_block=complete_block,
use_activation_checkpointing=use_activation_checkpointing,
)
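        # Linear classification head on top of the single pooled query token.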
self.linear = nn.Linear(embed_dim, num_classes, bias=True)

    def forward(self, x):
        # Pool with the single query, drop the query dimension, then classify.
        x = self.pooler(x).squeeze(1)
x = self.linear(x)
return x
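

# Minimal usage sketch: the batch size, token count, and class count below are
# arbitrary illustrative values, not defaults of this module.
if __name__ == "__main__":
    feats = torch.randn(2, 196, 768)  # [batch, num_tokens, embed_dim]
    classifier = AttentiveClassifier(embed_dim=768, num_heads=12, depth=1, num_classes=10)
    logits = classifier(feats)
    print(logits.shape)  # torch.Size([2, 10])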