|
from typing import List, Optional, Tuple, Union, Iterable, Any |
|
import torch, math |
|
import torch.utils.checkpoint |
|
from torch import nn |
|
from collections import OrderedDict |
|
from flash_attn import flash_attn_varlen_func |
|
from transformers.activations import ACT2FN |
|
import io, fire |
|
from torch.nn import functional as F |
|
from transformers.modeling_utils import PreTrainedModel

from transformers import CLIPImageProcessor, PretrainedConfig
|
from .configuration_baichuan import RQVAESIGLIPTransformerConfig, RQTransformerConfig, RQVAESiglipConfig, AttentionStackConfig, AttentionBlockConfig, SiglipConfig, SiglipTextConfig, SiglipVisionConfig |
|
from torch.utils.checkpoint import checkpoint |
|
import warnings |
|
from dataclasses import dataclass |
|
from torch.nn.init import _calculate_fan_in_and_fan_out |
|
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling |
|
from transformers.utils import ( |
|
ModelOutput, |
|
add_start_docstrings, |
|
add_start_docstrings_to_model_forward, |
|
logging, |
|
replace_return_docstrings, |
|
) |
|
import torch.distributed as dist |
|
import numpy as np |
|
|
|
def top_k_logits(logits, k): |
|
v, ix = torch.topk(logits, k) |
|
out = logits.clone() |
|
out[out < v[:, [-1]]] = -float('Inf') |
|
|
|
return out |
|
|
|
|
|
def top_p_probs(probs, p): |
|
sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) |
|
cum_probs = torch.cumsum(sorted_probs, dim=-1) |
|
|
|
sorted_idx_remove_cond = cum_probs >= p |
|
|
|
sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone() |
|
sorted_idx_remove_cond[..., 0] = 0 |
|
|
|
indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond) |
|
probs = probs.masked_fill(indices_to_remove, 0.0) |
|
norm_probs = probs / torch.sum(probs, dim=-1, keepdim=True) |
|
|
|
return norm_probs |
|
|
|
|
|
def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None): |
|
"""Take a 2-dim tensor, apply softmax along each row, and sample from |
|
each multinomial distribution defined by the rows. |
|
|
|
Args: |
|
logits: 2-dim tensor of shape (n_samples, logit_dim) |
|
temperature (float): softmax temperature |
|
top_k (Optional[int]): if given, sample only using `top_k` logits |
|
top_p (Optional[float]): if given, sample only using `top_p` logits |
|
|
|
Returns: |
|
samples: 1-dim integer tensor of shape (n_samples,) |
|
""" |
|
|
|
logits = logits.to(dtype=torch.float32) |
|
logits = logits / temperature |
|
|
|
if top_k is not None: |
|
logits = top_k_logits(logits, top_k) |
|
|
|
    if torch.isnan(logits).any():

        print('WARNING: NaN observed in logits; replacing NaN entries with -inf before sampling')

        logits[torch.isnan(logits)] = -float('Inf')
|
|
|
probs = F.softmax(logits, dim=-1) |
|
|
|
if top_p is not None: |
|
probs = top_p_probs(probs, top_p) |
|
|
|
    try:

        samples = torch.multinomial(probs, num_samples=1)

    except RuntimeError as err:

        raise RuntimeError("torch.multinomial failed; check that `probs` contains finite, non-negative values with a nonzero sum") from err
|
|
|
return samples.view(-1) |
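
# Illustrative usage sketch (not part of the model code; shapes are hypothetical):
# the helpers above combine temperature scaling, top-k and top-p (nucleus)
# filtering before drawing one index per row of a 2-D logits tensor.
#
#   logits = torch.randn(4, 16384)                                  # (n_samples, vocab_size)
#   idx = sample_from_logits(logits, temperature=0.8, top_k=50, top_p=0.9)
#   assert idx.shape == (4,)                                        # one sampled index per row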
|
|
|
|
|
def nonlinearity(x): |
|
return F.silu(x, inplace=True) |
|
|
|
|
|
def Normalize(in_channels): |
|
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) |
|
|
|
|
|
|
|
""" PyTorch Siglip model.""" |
|
|
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _trunc_normal_(tensor, mean, std, a, b): |
|
|
|
|
|
def norm_cdf(x): |
|
|
|
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 |
|
|
|
if (mean < a - 2 * std) or (mean > b + 2 * std): |
|
warnings.warn( |
|
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " |
|
"The distribution of values may be incorrect.", |
|
stacklevel=2, |
|
) |
|
|
|
|
|
|
|
|
|
l = norm_cdf((a - mean) / std) |
|
u = norm_cdf((b - mean) / std) |
|
|
|
|
|
|
|
tensor.uniform_(2 * l - 1, 2 * u - 1) |
|
|
|
|
|
|
|
tensor.erfinv_() |
|
|
|
|
|
tensor.mul_(std * math.sqrt(2.0)) |
|
tensor.add_(mean) |
|
|
|
|
|
tensor.clamp_(min=a, max=b) |
|
|
|
|
|
def trunc_normal_tf_( |
|
tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 |
|
) -> torch.Tensor: |
|
"""Fills the input Tensor with values drawn from a truncated |
|
normal distribution. The values are effectively drawn from the |
|
normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` |
|
with values outside :math:`[a, b]` redrawn until they are within |
|
the bounds. The method used for generating the random values works |
|
best when :math:`a \\leq \text{mean} \\leq b`. |
|
|
|
NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the |
|
bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 |
|
    and the result is subsequently scaled and shifted by the mean and std args.
|
|
|
Args: |
|
tensor: an n-dimensional `torch.Tensor` |
|
mean: the mean of the normal distribution |
|
std: the standard deviation of the normal distribution |
|
a: the minimum cutoff value |
|
b: the maximum cutoff value |
|
""" |
|
with torch.no_grad(): |
|
_trunc_normal_(tensor, 0, 1.0, a, b) |
|
tensor.mul_(std).add_(mean) |
|
|
|
|
|
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): |
|
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) |
|
if mode == "fan_in": |
|
denom = fan_in |
|
elif mode == "fan_out": |
|
denom = fan_out |
|
elif mode == "fan_avg": |
|
denom = (fan_in + fan_out) / 2 |
|
|
|
variance = scale / denom |
|
|
|
if distribution == "truncated_normal": |
|
|
|
trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) |
|
elif distribution == "normal": |
|
with torch.no_grad(): |
|
tensor.normal_(std=math.sqrt(variance)) |
|
elif distribution == "uniform": |
|
bound = math.sqrt(3 * variance) |
|
with torch.no_grad(): |
|
tensor.uniform_(-bound, bound) |
|
else: |
|
raise ValueError(f"invalid distribution {distribution}") |
|
|
|
|
|
def lecun_normal_(tensor): |
|
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") |
|
|
|
|
|
def default_flax_embed_init(tensor): |
|
variance_scaling_(tensor, mode="fan_in", distribution="normal") |
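
# Illustrative sketch (assumed layer sizes): these helpers mirror the JAX/Flax-style
# initializers used by the reference SigLIP implementation. `lecun_normal_` draws from
# a truncated normal with variance 1/fan_in, while `default_flax_embed_init` uses a
# plain normal with the same variance.
#
#   proj = nn.Linear(768, 768)
#   lecun_normal_(proj.weight)            # fan_in-scaled truncated normal
#   nn.init.zeros_(proj.bias)
#
#   emb = nn.Embedding(32000, 768)
#   default_flax_embed_init(emb.weight)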
|
|
|
|
|
@dataclass |
|
|
|
class SiglipVisionModelOutput(ModelOutput): |
|
""" |
|
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. |
|
|
|
Args: |
|
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
|
The image embeddings obtained by applying the projection layer to the pooler_output. |
|
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): |
|
Sequence of hidden-states at the output of the last layer of the model. |
|
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): |
|
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + |
|
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. |
|
|
|
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. |
|
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): |
|
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, |
|
sequence_length)`. |
|
|
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention |
|
heads. |
|
""" |
|
|
|
image_embeds: Optional[torch.FloatTensor] = None |
|
last_hidden_state: torch.FloatTensor = None |
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
|
attentions: Optional[Tuple[torch.FloatTensor]] = None |
|
|
|
|
|
@dataclass |
|
|
|
class SiglipTextModelOutput(ModelOutput): |
|
""" |
|
Base class for text model's outputs that also contains a pooling of the last hidden states. |
|
|
|
Args: |
|
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
|
The text embeddings obtained by applying the projection layer to the pooler_output. |
|
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): |
|
Sequence of hidden-states at the output of the last layer of the model. |
|
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): |
|
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + |
|
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. |
|
|
|
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. |
|
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): |
|
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, |
|
sequence_length)`. |
|
|
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention |
|
heads. |
|
""" |
|
|
|
text_embeds: Optional[torch.FloatTensor] = None |
|
last_hidden_state: torch.FloatTensor = None |
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
|
attentions: Optional[Tuple[torch.FloatTensor]] = None |
|
|
|
|
|
@dataclass |
|
|
|
class SiglipOutput(ModelOutput): |
|
""" |
|
Args: |
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): |
|
Contrastive loss for image-text similarity. |
|
        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
|
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text |
|
similarity scores. |
|
        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
|
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image |
|
similarity scores. |
|
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
|
The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`]. |
|
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
|
The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`]. |
|
text_model_output(`BaseModelOutputWithPooling`): |
|
The output of the [`SiglipTextModel`]. |
|
vision_model_output(`BaseModelOutputWithPooling`): |
|
The output of the [`SiglipVisionModel`]. |
|
""" |
|
|
|
loss: Optional[torch.FloatTensor] = None |
|
logits_per_image: torch.FloatTensor = None |
|
logits_per_text: torch.FloatTensor = None |
|
text_embeds: torch.FloatTensor = None |
|
image_embeds: torch.FloatTensor = None |
|
text_model_output: BaseModelOutputWithPooling = None |
|
vision_model_output: BaseModelOutputWithPooling = None |
|
|
|
def to_tuple(self) -> Tuple[Any]: |
|
return tuple( |
|
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() |
|
for k in self.keys() |
|
) |
|
|
|
|
|
class SiglipVisionEmbeddings(nn.Module): |
|
def __init__(self, config: SiglipVisionConfig): |
|
super().__init__() |
|
self.config = config |
|
self.embed_dim = config.hidden_size |
|
self.image_size = config.image_size |
|
self.patch_size = config.patch_size |
|
|
|
self.patch_embedding = nn.Conv2d( |
|
in_channels=config.num_channels, |
|
out_channels=self.embed_dim, |
|
kernel_size=self.patch_size, |
|
stride=self.patch_size, |
|
padding="valid", |
|
) |
|
|
|
self.num_patches = (self.image_size // self.patch_size) ** 2 |
|
self.num_positions = self.num_patches |
|
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) |
|
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) |
|
|
|
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: |
|
pixel_values = pixel_values.to(self.patch_embedding.weight.device) |
|
patch_embeds = self.patch_embedding(pixel_values) |
|
embeddings = patch_embeds.flatten(2).transpose(1, 2) |
|
|
|
embeddings = embeddings + self.position_embedding(self.position_ids) |
|
return embeddings |
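
# Shape walkthrough (example values only, assuming image_size=224, patch_size=16 and
# hidden_size=768, as in google/siglip-base-patch16-224):
#   pixel_values: (B, 3, 224, 224)
#   patch_embedding -> (B, 768, 14, 14)        # 224 / 16 = 14 patches per side
#   flatten(2).transpose(1, 2) -> (B, 196, 768)
#   + position_embedding(position_ids)         # 196 learned position vectors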
|
|
|
|
|
|
|
class SiglipTextEmbeddings(nn.Module): |
|
def __init__(self, config: SiglipTextConfig): |
|
super().__init__() |
|
embed_dim = config.hidden_size |
|
|
|
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) |
|
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) |
|
|
|
|
|
self.register_buffer( |
|
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False |
|
) |
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
) -> torch.Tensor: |
|
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] |
|
|
|
if position_ids is None: |
|
position_ids = self.position_ids[:, :seq_length] |
|
|
|
if inputs_embeds is None: |
|
inputs_embeds = self.token_embedding(input_ids) |
|
|
|
position_embeddings = self.position_embedding(position_ids) |
|
embeddings = inputs_embeds + position_embeddings |
|
|
|
return embeddings |
|
|
|
|
|
class SiglipAttention(nn.Module): |
|
"""Multi-headed attention from 'Attention Is All You Need' paper""" |
|
|
|
|
|
def __init__(self, config): |
|
super().__init__() |
|
self.config = config |
|
self.embed_dim = config.hidden_size |
|
self.num_heads = config.num_attention_heads |
|
self.head_dim = self.embed_dim // self.num_heads |
|
if self.head_dim * self.num_heads != self.embed_dim: |
|
raise ValueError( |
|
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" |
|
f" {self.num_heads})." |
|
) |
|
self.scale = self.head_dim**-0.5 |
|
self.dropout = config.attention_dropout |
|
|
|
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) |
|
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) |
|
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) |
|
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
output_attentions: Optional[bool] = False, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
"""Input shape: Batch x Time x Channel""" |
|
|
|
batch_size, q_len, _ = hidden_states.size() |
|
|
|
query_states = self.q_proj(hidden_states) |
|
key_states = self.k_proj(hidden_states) |
|
value_states = self.v_proj(hidden_states) |
|
|
|
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
|
|
k_v_seq_len = key_states.shape[-2] |
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale |
|
|
|
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): |
|
raise ValueError( |
|
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" |
|
f" {attn_weights.size()}" |
|
) |
|
|
|
if attention_mask is not None: |
|
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): |
|
raise ValueError( |
|
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" |
|
) |
|
attn_weights = attn_weights + attention_mask |
|
|
|
|
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) |
|
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) |
|
attn_output = torch.matmul(attn_weights, value_states) |
|
|
|
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): |
|
raise ValueError( |
|
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" |
|
f" {attn_output.size()}" |
|
) |
|
|
|
attn_output = attn_output.transpose(1, 2).contiguous() |
|
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) |
|
|
|
attn_output = self.out_proj(attn_output) |
|
|
|
return attn_output, attn_weights |
|
|
|
|
|
|
|
class SiglipMLP(nn.Module): |
|
def __init__(self, config): |
|
super().__init__() |
|
self.config = config |
|
self.activation_fn = ACT2FN[config.hidden_act] |
|
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) |
|
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) |
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
|
hidden_states = self.fc1(hidden_states) |
|
hidden_states = self.activation_fn(hidden_states) |
|
hidden_states = self.fc2(hidden_states) |
|
return hidden_states |
|
|
|
|
|
|
|
class SiglipEncoderLayer(nn.Module): |
|
def __init__(self, config: SiglipConfig): |
|
super().__init__() |
|
self.embed_dim = config.hidden_size |
|
self.self_attn = SiglipAttention(config) |
|
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) |
|
self.mlp = SiglipMLP(config) |
|
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) |
|
|
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: torch.Tensor, |
|
output_attentions: Optional[bool] = False, |
|
) -> Tuple[torch.FloatTensor]: |
|
""" |
|
Args: |
|
hidden_states (`torch.FloatTensor`): |
|
Input to the layer of shape `(batch, seq_len, embed_dim)`. |
|
attention_mask (`torch.FloatTensor`): |
|
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. |
|
output_attentions (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
|
returned tensors for more detail. |
|
""" |
|
residual = hidden_states |
|
|
|
hidden_states = self.layer_norm1(hidden_states) |
|
hidden_states, attn_weights = self.self_attn( |
|
hidden_states=hidden_states, |
|
attention_mask=attention_mask, |
|
output_attentions=output_attentions, |
|
) |
|
hidden_states = residual + hidden_states |
|
|
|
residual = hidden_states |
|
hidden_states = self.layer_norm2(hidden_states) |
|
hidden_states = self.mlp(hidden_states) |
|
hidden_states = residual + hidden_states |
|
|
|
outputs = (hidden_states,) |
|
|
|
if output_attentions: |
|
outputs += (attn_weights,) |
|
|
|
return outputs |
|
|
|
|
|
class SiglipPreTrainedModel(PreTrainedModel): |
|
""" |
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained |
|
models. |
|
""" |
|
|
|
config_class = SiglipConfig |
|
base_model_prefix = "siglip" |
|
supports_gradient_checkpointing = True |
|
|
|
def _init_weights(self, module): |
|
"""Initialize the weights""" |
|
if isinstance(module, SiglipVisionEmbeddings): |
|
width = ( |
|
self.config.vision_config.hidden_size |
|
if isinstance(self.config, SiglipConfig) |
|
else self.config.hidden_size |
|
) |
|
nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) |
|
elif isinstance(module, nn.Embedding): |
|
default_flax_embed_init(module.weight) |
|
elif isinstance(module, SiglipAttention): |
|
nn.init.xavier_uniform_(module.q_proj.weight) |
|
nn.init.xavier_uniform_(module.k_proj.weight) |
|
nn.init.xavier_uniform_(module.v_proj.weight) |
|
nn.init.xavier_uniform_(module.out_proj.weight) |
|
nn.init.zeros_(module.q_proj.bias) |
|
nn.init.zeros_(module.k_proj.bias) |
|
nn.init.zeros_(module.v_proj.bias) |
|
nn.init.zeros_(module.out_proj.bias) |
|
elif isinstance(module, SiglipMLP): |
|
nn.init.xavier_uniform_(module.fc1.weight) |
|
nn.init.xavier_uniform_(module.fc2.weight) |
|
nn.init.normal_(module.fc1.bias, std=1e-6) |
|
nn.init.normal_(module.fc2.bias, std=1e-6) |
|
elif isinstance(module, SiglipMultiheadAttentionPoolingHead): |
|
nn.init.xavier_uniform_(module.probe.data) |
|
nn.init.xavier_uniform_(module.attention.in_proj_weight.data) |
|
nn.init.zeros_(module.attention.in_proj_bias.data) |
|
elif isinstance(module, SiglipModel): |
|
logit_scale_init = torch.log(torch.tensor(1.0)) |
|
module.logit_scale.data.fill_(logit_scale_init) |
|
module.logit_bias.data.zero_() |
|
elif isinstance(module, (nn.Linear, nn.Conv2d)): |
|
lecun_normal_(module.weight) |
|
if module.bias is not None: |
|
nn.init.zeros_(module.bias) |
|
elif isinstance(module, nn.LayerNorm): |
|
module.bias.data.zero_() |
|
module.weight.data.fill_(1.0) |
|
|
|
|
|
SIGLIP_START_DOCSTRING = r""" |
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the |
|
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads |
|
etc.) |
|
|
|
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. |
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage |
|
and behavior. |
|
|
|
Parameters: |
|
config ([`SiglipConfig`]): Model configuration class with all the parameters of the model. |
|
Initializing with a config file does not load the weights associated with the model, only the |
|
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. |
|
""" |
|
|
|
SIGLIP_TEXT_INPUTS_DOCSTRING = r""" |
|
Args: |
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide |
|
it. |
|
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
|
[`PreTrainedTokenizer.__call__`] for details. |
|
|
|
[What are input IDs?](../glossary#input-ids) |
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, |
|
config.max_position_embeddings - 1]`. |
|
|
|
[What are position IDs?](../glossary#position-ids) |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
|
tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
|
more detail. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
""" |
|
|
|
SIGLIP_VISION_INPUTS_DOCSTRING = r""" |
|
Args: |
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): |
|
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using |
|
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
|
tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
|
more detail. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
""" |
|
|
|
SIGLIP_INPUTS_DOCSTRING = r""" |
|
Args: |
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide |
|
it. |
|
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
|
[`PreTrainedTokenizer.__call__`] for details. |
|
|
|
[What are input IDs?](../glossary#input-ids) |
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, |
|
config.max_position_embeddings - 1]`. |
|
|
|
[What are position IDs?](../glossary#position-ids) |
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): |
|
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using |
|
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. |
|
return_loss (`bool`, *optional*): |
|
Whether or not to return the contrastive loss. |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
|
tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
|
more detail. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
""" |
|
|
|
|
|
|
|
class SiglipEncoder(nn.Module): |
|
""" |
|
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a |
|
[`SiglipEncoderLayer`]. |
|
|
|
Args: |
|
config: SiglipConfig |
|
""" |
|
|
|
def __init__(self, config: SiglipConfig): |
|
super().__init__() |
|
self.config = config |
|
self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) |
|
self.gradient_checkpointing = False |
|
|
|
|
|
def forward( |
|
self, |
|
inputs_embeds, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[Tuple, BaseModelOutput]: |
|
r""" |
|
Args: |
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): |
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. |
|
This is useful if you want more control over how to convert `input_ids` indices into associated vectors |
|
than the model's internal embedding lookup matrix. |
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
|
returned tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors |
|
for more detail. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
""" |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
encoder_states = () if output_hidden_states else None |
|
all_attentions = () if output_attentions else None |
|
|
|
hidden_states = inputs_embeds |
|
for encoder_layer in self.layers: |
|
if output_hidden_states: |
|
encoder_states = encoder_states + (hidden_states,) |
|
if self.gradient_checkpointing and self.training: |
|
layer_outputs = self._gradient_checkpointing_func( |
|
encoder_layer.__call__, |
|
hidden_states, |
|
attention_mask, |
|
output_attentions, |
|
) |
|
else: |
|
layer_outputs = encoder_layer( |
|
hidden_states, |
|
attention_mask, |
|
output_attentions=output_attentions, |
|
) |
|
|
|
hidden_states = layer_outputs[0] |
|
|
|
if output_attentions: |
|
all_attentions = all_attentions + (layer_outputs[1],) |
|
|
|
if output_hidden_states: |
|
encoder_states = encoder_states + (hidden_states,) |
|
|
|
if not return_dict: |
|
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) |
|
return BaseModelOutput( |
|
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions |
|
) |
|
|
|
|
|
class SiglipTextTransformer(nn.Module): |
|
def __init__(self, config: SiglipTextConfig): |
|
super().__init__() |
|
self.config = config |
|
embed_dim = config.hidden_size |
|
self.embeddings = SiglipTextEmbeddings(config) |
|
self.encoder = SiglipEncoder(config) |
|
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) |
|
|
|
self.head = nn.Linear(embed_dim, embed_dim) |
|
|
|
@add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) |
|
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) |
|
def forward( |
|
self, |
|
input_ids: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.Tensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[Tuple, BaseModelOutputWithPooling]: |
|
r""" |
|
Returns: |
|
|
|
""" |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
if input_ids is None: |
|
raise ValueError("You have to specify input_ids") |
|
|
|
input_shape = input_ids.size() |
|
input_ids = input_ids.view(-1, input_shape[-1]) |
|
|
|
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
encoder_outputs = self.encoder( |
|
inputs_embeds=hidden_states, |
|
attention_mask=attention_mask, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
last_hidden_state = encoder_outputs[0] |
|
last_hidden_state = self.final_layer_norm(last_hidden_state) |
|
|
|
|
|
pooled_output = last_hidden_state[:, -1, :] |
|
pooled_output = self.head(pooled_output) |
|
|
|
if not return_dict: |
|
return (last_hidden_state, pooled_output) + encoder_outputs[1:] |
|
|
|
return BaseModelOutputWithPooling( |
|
last_hidden_state=last_hidden_state, |
|
pooler_output=pooled_output, |
|
hidden_states=encoder_outputs.hidden_states, |
|
attentions=encoder_outputs.attentions, |
|
) |
|
|
|
|
|
@add_start_docstrings( |
|
"""The text model from SigLIP without any head or projection on top.""", |
|
SIGLIP_START_DOCSTRING, |
|
) |
|
class SiglipTextModel(SiglipPreTrainedModel): |
|
config_class = SiglipTextConfig |
|
|
|
_no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"] |
|
|
|
def __init__(self, config: SiglipTextConfig): |
|
super().__init__(config) |
|
self.text_model = SiglipTextTransformer(config) |
|
|
|
self.post_init() |
|
|
|
def get_input_embeddings(self) -> nn.Module: |
|
return self.text_model.embeddings.token_embedding |
|
|
|
def set_input_embeddings(self, value): |
|
self.text_model.embeddings.token_embedding = value |
|
|
|
@add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) |
|
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) |
|
def forward( |
|
self, |
|
input_ids: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.Tensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[Tuple, BaseModelOutputWithPooling]: |
|
r""" |
|
Returns: |
|
|
|
Examples: |
|
|
|
```python |
|
>>> from transformers import AutoTokenizer, SiglipTextModel |
|
|
|
>>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224") |
|
>>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") |
|
|
|
>>> # important: make sure to set padding="max_length" as that's how the model was trained |
|
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") |
|
|
|
>>> outputs = model(**inputs) |
|
>>> last_hidden_state = outputs.last_hidden_state |
|
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states |
|
```""" |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
return self.text_model( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
|
|
class SiglipVisionTransformer(nn.Module): |
|
def __init__(self, config: SiglipVisionConfig): |
|
super().__init__() |
|
self.config = config |
|
embed_dim = config.hidden_size |
|
|
|
self.embeddings = SiglipVisionEmbeddings(config) |
|
self.encoder = SiglipEncoder(config) |
|
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) |
|
self.head = SiglipMultiheadAttentionPoolingHead(config) |
|
|
|
@add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) |
|
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) |
|
def forward( |
|
self, |
|
pixel_values, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[Tuple, BaseModelOutputWithPooling]: |
|
r""" |
|
Returns: |
|
|
|
""" |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
hidden_states = self.embeddings(pixel_values) |
|
|
|
encoder_outputs = self.encoder( |
|
inputs_embeds=hidden_states, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
last_hidden_state = encoder_outputs[0] |
|
last_hidden_state = self.post_layernorm(last_hidden_state) |
|
|
|
pooled_output = self.head(last_hidden_state) |
|
|
|
if not return_dict: |
|
return (last_hidden_state, pooled_output) + encoder_outputs[1:] |
|
|
|
return BaseModelOutputWithPooling( |
|
last_hidden_state=last_hidden_state, |
|
pooler_output=pooled_output, |
|
hidden_states=encoder_outputs.hidden_states, |
|
attentions=encoder_outputs.attentions, |
|
) |
|
|
|
|
|
class SiglipMultiheadAttentionPoolingHead(nn.Module): |
|
"""Multihead Attention Pooling.""" |
|
|
|
def __init__(self, config: SiglipVisionConfig): |
|
super().__init__() |
|
|
|
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) |
|
self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) |
|
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
|
self.mlp = SiglipMLP(config) |
|
|
|
def forward(self, hidden_state): |
|
batch_size = hidden_state.shape[0] |
|
probe = self.probe.repeat(batch_size, 1, 1) |
|
|
|
hidden_state = self.attention(probe, hidden_state, hidden_state)[0] |
|
|
|
residual = hidden_state |
|
hidden_state = self.layernorm(hidden_state) |
|
hidden_state = residual + self.mlp(hidden_state) |
|
|
|
return hidden_state[:, 0] |
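
# Pooling sketch: a single learned "probe" query attends over all patch tokens, so the
# (B, L, D) encoder output is reduced to a (B, D) pooled vector.
#   probe:        (B, 1, D)   queries
#   hidden_state: (B, L, D)   keys/values
#   attention(probe, hidden_state, hidden_state)[0] -> (B, 1, D)
#   hidden_state[:, 0] -> (B, D)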
|
|
|
|
|
@add_start_docstrings( |
|
"""The vision model from SigLIP without any head or projection on top.""", |
|
SIGLIP_START_DOCSTRING, |
|
) |
|
class SiglipVisionModel(SiglipPreTrainedModel): |
|
config_class = SiglipVisionConfig |
|
main_input_name = "pixel_values" |
|
|
|
def __init__(self, config: SiglipVisionConfig): |
|
super().__init__(config) |
|
|
|
self.vision_model = SiglipVisionTransformer(config) |
|
|
|
|
|
self.post_init() |
|
|
|
def get_input_embeddings(self) -> nn.Module: |
|
return self.vision_model.embeddings.patch_embedding |
|
|
|
@add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) |
|
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) |
|
def forward( |
|
self, |
|
pixel_values, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[Tuple, BaseModelOutputWithPooling]: |
|
r""" |
|
Returns: |
|
|
|
Examples: |
|
|
|
```python |
|
>>> from PIL import Image |
|
>>> import requests |
|
>>> from transformers import AutoProcessor, SiglipVisionModel |
|
|
|
>>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224") |
|
>>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") |
|
|
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" |
|
>>> image = Image.open(requests.get(url, stream=True).raw) |
|
|
|
>>> inputs = processor(images=image, return_tensors="pt") |
|
|
|
>>> outputs = model(**inputs) |
|
>>> last_hidden_state = outputs.last_hidden_state |
|
>>> pooled_output = outputs.pooler_output # pooled features |
|
```""" |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
return self.vision_model( |
|
pixel_values=pixel_values, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
@add_start_docstrings(SIGLIP_START_DOCSTRING) |
|
class SiglipModel(SiglipPreTrainedModel): |
|
config_class = SiglipConfig |
|
|
|
def __init__(self, config: SiglipConfig): |
|
super().__init__(config) |
|
|
|
if not isinstance(config.text_config, SiglipTextConfig): |
|
raise ValueError( |
|
"config.text_config is expected to be of type SiglipTextConfig but is of type" |
|
f" {type(config.text_config)}." |
|
) |
|
|
|
if not isinstance(config.vision_config, SiglipVisionConfig): |
|
raise ValueError( |
|
"config.vision_config is expected to be of type SiglipVisionConfig but is of type" |
|
f" {type(config.vision_config)}." |
|
) |
|
|
|
text_config = config.text_config |
|
vision_config = config.vision_config |
|
|
|
self.text_model = SiglipTextTransformer(text_config) |
|
self.vision_model = SiglipVisionTransformer(vision_config) |
|
|
|
self.logit_scale = nn.Parameter(torch.randn(1)) |
|
self.logit_bias = nn.Parameter(torch.randn(1)) |
|
|
|
|
|
self.post_init() |
|
|
|
@add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) |
|
def get_text_features( |
|
self, |
|
input_ids: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.Tensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> torch.FloatTensor: |
|
r""" |
|
Returns: |
|
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
|
applying the projection layer to the pooled output of [`SiglipTextModel`]. |
|
|
|
Examples: |
|
|
|
```python |
|
>>> from transformers import AutoTokenizer, AutoModel |
|
>>> import torch |
|
|
|
>>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") |
|
>>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") |
|
|
|
>>> # important: make sure to set padding="max_length" as that's how the model was trained |
|
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") |
|
>>> with torch.no_grad(): |
|
... text_features = model.get_text_features(**inputs) |
|
```""" |
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
text_outputs = self.text_model( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
pooled_output = text_outputs[1] |
|
|
|
return pooled_output |
|
|
|
@add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) |
|
def get_image_features( |
|
self, |
|
pixel_values: Optional[torch.FloatTensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> torch.FloatTensor: |
|
r""" |
|
Returns: |
|
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
|
applying the projection layer to the pooled output of [`SiglipVisionModel`]. |
|
|
|
Examples: |
|
|
|
```python |
|
>>> from PIL import Image |
|
>>> import requests |
|
>>> from transformers import AutoProcessor, AutoModel |
|
>>> import torch |
|
|
|
>>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") |
|
>>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") |
|
|
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" |
|
>>> image = Image.open(requests.get(url, stream=True).raw) |
|
|
|
>>> inputs = processor(images=image, return_tensors="pt") |
|
|
|
>>> with torch.no_grad(): |
|
... image_features = model.get_image_features(**inputs) |
|
```""" |
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
vision_outputs = self.vision_model( |
|
pixel_values=pixel_values, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
pooled_output = vision_outputs[1] |
|
|
|
return pooled_output |
|
|
|
@add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING) |
|
@replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig) |
|
def forward( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
pixel_values: Optional[torch.FloatTensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
return_loss: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[Tuple, SiglipOutput]: |
|
r""" |
|
Returns: |
|
|
|
Examples: |
|
|
|
```python |
|
>>> from PIL import Image |
|
>>> import requests |
|
>>> from transformers import AutoProcessor, AutoModel |
|
>>> import torch |
|
|
|
>>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") |
|
>>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") |
|
|
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" |
|
>>> image = Image.open(requests.get(url, stream=True).raw) |
|
|
|
>>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] |
|
>>> inputs = processor(text=texts, images=image, return_tensors="pt") |
|
|
|
>>> with torch.no_grad(): |
|
... outputs = model(**inputs) |
|
|
|
>>> logits_per_image = outputs.logits_per_image |
|
>>> probs = torch.sigmoid(logits_per_image) # these are the probabilities |
|
>>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") |
|
31.9% that image 0 is 'a photo of 2 cats' |
|
```""" |
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
vision_outputs = self.vision_model( |
|
pixel_values=pixel_values, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
text_outputs = self.text_model( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
image_embeds = vision_outputs[1] |
|
text_embeds = text_outputs[1] |
|
|
|
|
|
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) |
|
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) |
|
|
|
|
|
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias |
|
logits_per_image = logits_per_text.t() |
|
|
|
loss = None |
|
if return_loss: |
|
raise NotImplementedError("SigLIP loss to be implemented") |
|
|
|
if not return_dict: |
|
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) |
|
return ((loss,) + output) if loss is not None else output |
|
|
|
return SiglipOutput( |
|
loss=loss, |
|
logits_per_image=logits_per_image, |
|
logits_per_text=logits_per_text, |
|
text_embeds=text_embeds, |
|
image_embeds=image_embeds, |
|
text_model_output=text_outputs, |
|
vision_model_output=vision_outputs, |
|
) |
|
|
|
|
|
class Upsample(nn.Module): |
|
def __init__(self, in_channels, with_conv): |
|
super().__init__() |
|
self.with_conv = with_conv |
|
if self.with_conv: |
|
self.conv = torch.nn.Conv2d(in_channels, |
|
in_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1) |
|
|
|
def forward(self, x): |
|
        # Nearest-neighbor upsampling is done in float32, then cast back to the input dtype.
        x = torch.nn.functional.interpolate(x.to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
|
if self.with_conv: |
|
x = self.conv(x) |
|
return x |
|
|
|
|
|
class ResnetBlock(nn.Module): |
|
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, |
|
dropout, temb_channels=512): |
|
super().__init__() |
|
self.in_channels = in_channels |
|
out_channels = in_channels if out_channels is None else out_channels |
|
self.out_channels = out_channels |
|
self.use_conv_shortcut = conv_shortcut |
|
self.checkpointing = False |
|
|
|
self.norm1 = Normalize(in_channels) |
|
self.conv1 = torch.nn.Conv2d(in_channels, |
|
out_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1) |
|
if temb_channels > 0: |
|
self.temb_proj = torch.nn.Linear(temb_channels, |
|
out_channels) |
|
self.norm2 = Normalize(out_channels) |
|
self.dropout = torch.nn.Dropout(dropout, inplace=True) |
|
self.conv2 = torch.nn.Conv2d(out_channels, |
|
out_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1) |
|
if self.in_channels != self.out_channels: |
|
if self.use_conv_shortcut: |
|
self.conv_shortcut = torch.nn.Conv2d(in_channels, |
|
out_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1) |
|
else: |
|
self.nin_shortcut = torch.nn.Conv2d(in_channels, |
|
out_channels, |
|
kernel_size=1, |
|
stride=1, |
|
padding=0) |
|
|
|
def _forward(self, x, temb): |
|
h = x |
|
h = self.norm1(h) |
|
h = nonlinearity(h) |
|
h = self.conv1(h) |
|
|
|
if temb is not None: |
|
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] |
|
|
|
h = self.norm2(h) |
|
h = nonlinearity(h) |
|
h = self.dropout(h) |
|
h = self.conv2(h) |
|
|
|
if self.in_channels != self.out_channels: |
|
if self.use_conv_shortcut: |
|
x = self.conv_shortcut(x) |
|
else: |
|
x = self.nin_shortcut(x) |
|
|
|
return x+h |
|
|
|
def forward(self, x, temb): |
|
if self.checkpointing and self.training: |
|
out = checkpoint(self._forward, x, temb) |
|
else: |
|
out = self._forward(x, temb) |
|
return out |
|
|
|
|
|
class PostQuantResnetBlock(nn.Module): |
|
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout): |
|
super().__init__() |
|
self.in_channels = in_channels |
|
out_channels = in_channels if out_channels is None else out_channels |
|
self.out_channels = out_channels |
|
self.use_conv_shortcut = conv_shortcut |
|
self.checkpointing = False |
|
|
|
self.norm1 = Normalize(in_channels) |
|
self.conv1 = torch.nn.Conv2d(in_channels, |
|
out_channels, |
|
kernel_size=1, |
|
stride=1, |
|
padding=0) |
|
self.norm2 = Normalize(out_channels) |
|
self.dropout = torch.nn.Dropout(dropout, inplace=True) |
|
self.conv2 = torch.nn.Conv2d(out_channels, |
|
out_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1) |
|
if self.in_channels != self.out_channels: |
|
if self.use_conv_shortcut: |
|
self.conv_shortcut = torch.nn.Conv2d(in_channels, |
|
out_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1) |
|
else: |
|
self.nin_shortcut = torch.nn.Conv2d(in_channels, |
|
out_channels, |
|
kernel_size=1, |
|
stride=1, |
|
padding=0) |
|
|
|
def _forward(self, x): |
|
h = x |
|
h = self.norm1(h) |
|
h = nonlinearity(h) |
|
h = self.conv1(h) |
|
|
|
h = self.norm2(h) |
|
h = nonlinearity(h) |
|
h = self.dropout(h) |
|
h = self.conv2(h) |
|
|
|
if self.in_channels != self.out_channels: |
|
if self.use_conv_shortcut: |
|
x = self.conv_shortcut(x) |
|
else: |
|
x = self.nin_shortcut(x) |
|
|
|
return x+h |
|
|
|
def forward(self, x): |
|
if self.checkpointing and self.training: |
|
out = checkpoint(self._forward, x) |
|
else: |
|
out = self._forward(x) |
|
return out |
|
|
|
|
|
class ProjectResnetBlock(nn.Module): |
|
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout): |
|
super().__init__() |
|
self.in_channels = in_channels |
|
out_channels = in_channels if out_channels is None else out_channels |
|
self.out_channels = out_channels |
|
self.use_conv_shortcut = conv_shortcut |
|
self.checkpointing = False |
|
|
|
self.norm1 = Normalize(in_channels) |
|
self.conv1 = torch.nn.Conv2d(in_channels, |
|
out_channels, |
|
kernel_size=1, |
|
stride=1, |
|
padding=0) |
|
self.norm2 = Normalize(out_channels) |
|
self.dropout = torch.nn.Dropout(dropout, inplace=True) |
|
self.conv2 = torch.nn.Conv2d(out_channels, |
|
out_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1) |
|
if self.in_channels != self.out_channels: |
|
if self.use_conv_shortcut: |
|
self.conv_shortcut = torch.nn.Conv2d(in_channels, |
|
out_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1) |
|
else: |
|
self.nin_shortcut = torch.nn.Conv2d(in_channels, |
|
out_channels, |
|
kernel_size=1, |
|
stride=1, |
|
padding=0) |
|
|
|
def _forward(self, x): |
|
h = x |
|
h = self.norm1(h) |
|
h = nonlinearity(h) |
|
h = self.conv1(h) |
|
|
|
h = self.norm2(h) |
|
h = nonlinearity(h) |
|
h = self.dropout(h) |
|
h = self.conv2(h) |
|
|
|
if self.in_channels != self.out_channels: |
|
if self.use_conv_shortcut: |
|
x = self.conv_shortcut(x) |
|
else: |
|
x = self.nin_shortcut(x) |
|
|
|
return x+h |
|
|
|
def forward(self, x): |
|
if self.checkpointing and self.training: |
|
out = checkpoint(self._forward, x) |
|
else: |
|
out = self._forward(x) |
|
return out |
|
|
|
|
|
class AttnBlock(nn.Module): |
|
def __init__(self, in_channels): |
|
super().__init__() |
|
self.in_channels = in_channels |
|
|
|
self.norm = Normalize(in_channels) |
|
self.q = torch.nn.Conv2d(in_channels, |
|
in_channels, |
|
kernel_size=1, |
|
stride=1, |
|
padding=0) |
|
self.k = torch.nn.Conv2d(in_channels, |
|
in_channels, |
|
kernel_size=1, |
|
stride=1, |
|
padding=0) |
|
self.v = torch.nn.Conv2d(in_channels, |
|
in_channels, |
|
kernel_size=1, |
|
stride=1, |
|
padding=0) |
|
self.proj_out = torch.nn.Conv2d(in_channels, |
|
in_channels, |
|
kernel_size=1, |
|
stride=1, |
|
padding=0) |
|
|
|
|
|
def forward(self, x): |
|
h_ = x |
|
h_ = self.norm(h_) |
|
q = self.q(h_) |
|
k = self.k(h_) |
|
v = self.v(h_) |
|
|
|
b,c,h,w = q.shape |
|
q = q.reshape(b,c,h*w) |
|
q = q.permute(0,2,1) |
|
k = k.reshape(b,c,h*w) |
|
w_ = torch.bmm(q,k) |
|
w_ = w_ * (int(c)**(-0.5)) |
|
w_ = torch.nn.functional.softmax(w_, dim=2) |
|
|
|
v = v.reshape(b,c,h*w) |
|
w_ = w_.permute(0,2,1) |
|
h_ = torch.bmm(v,w_) |
|
h_ = h_.reshape(b,c,h,w) |
|
|
|
h_ = self.proj_out(h_) |
|
|
|
return x+h_ |
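
# Note: this block is plain (non-causal) single-head self-attention over the H*W spatial
# positions, written with 1x1 convolutions for the q/k/v/output projections, as in the
# VQGAN/taming-transformers decoder. Illustrative shapes:
#   q, k, v: (B, C, H, W) reshaped to (B, H*W, C) / (B, C, H*W)
#   w_ = softmax(q @ k / sqrt(C)): (B, H*W, H*W) attention over spatial positions
#   output: (B, C, H, W), added residually to the input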
|
|
|
class VQEmbedding(nn.Embedding): |
|
"""VQ embedding module with ema update.""" |
|
|
|
def __init__(self, n_embed, embed_dim, ema=False, decay=0.99, restart_unused_codes=True, eps=1e-5): |
|
super().__init__(n_embed + 1, embed_dim, padding_idx=n_embed) |
|
|
|
self.ema = ema |
|
self.decay = decay |
|
self.eps = eps |
|
self.restart_unused_codes = restart_unused_codes |
|
self.n_embed = n_embed |
|
|
|
@torch.no_grad() |
|
def compute_distances(self, inputs): |
|
codebook_t = self.weight[:-1, :].t() |
|
|
|
(embed_dim, _) = codebook_t.shape |
|
inputs_shape = inputs.shape |
|
assert inputs_shape[-1] == embed_dim |
|
inputs_flat = inputs.reshape(-1, embed_dim) |
|
|
|
inputs_norm_sq = inputs_flat.pow(2.).sum(dim=1, keepdim=True) |
|
codebook_t_norm_sq = codebook_t.pow(2.).sum(dim=0, keepdim=True) |
|
distances = torch.addmm( |
|
inputs_norm_sq + codebook_t_norm_sq, |
|
inputs_flat, |
|
codebook_t, |
|
alpha=-2.0, |
|
) |
|
distances = distances.reshape(*inputs_shape[:-1], -1) |
|
return distances |
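
    # compute_distances expands the squared Euclidean distance
    #   ||x - c||^2 = ||x||^2 + ||c||^2 - 2 * x . c
    # and evaluates the cross term with a single torch.addmm call (alpha=-2.0), which avoids
    # materialising a (num_inputs, n_embed, embed_dim) difference tensor.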
|
|
|
@torch.no_grad() |
|
def find_nearest_embedding(self, inputs): |
|
distances = self.compute_distances(inputs) |
|
embed_idxs = distances.argmin(dim=-1) |
|
|
|
return embed_idxs |
|
|
|
def forward(self, inputs): |
|
embed_idxs = self.find_nearest_embedding(inputs) |
|
embeds = self.embed(embed_idxs) |
|
return embeds, embed_idxs |
|
|
|
def embed(self, idxs): |
|
embeds = super().forward(idxs) |
|
return embeds |
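
# Illustrative quantisation sketch (hypothetical sizes): map a (B, H, W, D) feature map to
# its nearest codebook entries and the corresponding code indices.
#
#   vq = VQEmbedding(n_embed=1024, embed_dim=64)
#   feats = torch.randn(2, 8, 8, 64)
#   quantised, idxs = vq(feats)       # quantised: (2, 8, 8, 64), idxs: (2, 8, 8)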
|
|
|
|
|
class Decoder(nn.Module): |
|
def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, |
|
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, |
|
resolution, decoder_in_channels, give_pre_end=False, **ignorekwargs): |
|
super().__init__() |
|
self.ch = ch |
|
self.temb_ch = 0 |
|
self.num_resolutions = len(ch_mult) |
|
self.num_res_blocks = num_res_blocks |
|
self.resolution = resolution |
|
self.in_channels = in_channels |
|
self.give_pre_end = give_pre_end |
|
|
|
in_ch_mult = (1,)+tuple(ch_mult) |
|
block_in = ch*ch_mult[self.num_resolutions-1] |
|
curr_res = resolution // 2**(self.num_resolutions-1) |
|
self.z_shape = (1, decoder_in_channels, curr_res, curr_res) |
|
|
|
self.conv_in = torch.nn.Conv2d(decoder_in_channels, |
|
block_in, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1) |
|
|
|
self.mid = nn.Module() |
|
self.mid.block_1 = ResnetBlock(in_channels=block_in, |
|
out_channels=block_in, |
|
temb_channels=self.temb_ch, |
|
dropout=dropout) |
|
self.mid.attn_1 = AttnBlock(block_in) |
|
self.mid.block_2 = ResnetBlock(in_channels=block_in, |
|
out_channels=block_in, |
|
temb_channels=self.temb_ch, |
|
dropout=dropout) |
|
|
|
self.up = nn.ModuleList() |
|
for i_level in reversed(range(self.num_resolutions)): |
|
block = nn.ModuleList() |
|
attn = nn.ModuleList() |
|
block_out = ch*ch_mult[i_level] |
|
for i_block in range(self.num_res_blocks+1): |
|
block.append(ResnetBlock(in_channels=block_in, |
|
out_channels=block_out, |
|
temb_channels=self.temb_ch, |
|
dropout=dropout)) |
|
block_in = block_out |
|
if curr_res in attn_resolutions: |
|
attn.append(AttnBlock(block_in)) |
|
up = nn.Module() |
|
up.block = block |
|
up.attn = attn |
|
if i_level != 0: |
|
up.upsample = Upsample(block_in, resamp_with_conv) |
|
curr_res = curr_res * 2 |
|
self.up.insert(0, up) |
|
|
|
self.norm_out = Normalize(block_in) |
|
self.conv_out = torch.nn.Conv2d(block_in, |
|
out_ch, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1) |
|
|
|
def forward(self, z): |
|
self.last_z_shape = z.shape |
|
|
|
temb = None |
|
|
|
h = self.conv_in(z) |
|
|
|
h = self.mid.block_1(h, temb) |
|
h = self.mid.attn_1(h) |
|
h = self.mid.block_2(h, temb) |
|
|
|
for i_level in reversed(range(self.num_resolutions)): |
|
for i_block in range(self.num_res_blocks+1): |
|
h = self.up[i_level].block[i_block](h, temb) |
|
if len(self.up[i_level].attn) > 0: |
|
h = self.up[i_level].attn[i_block](h) |
|
if i_level != 0: |
|
h = self.up[i_level].upsample(h) |
|
|
|
if self.give_pre_end: |
|
return h |
|
|
|
h = self.norm_out(h) |
|
h = nonlinearity(h) |
|
h = self.conv_out(h) |
|
return h |
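
    # Shape note: conv_in consumes a latent of shape self.z_shape =
    # (1, decoder_in_channels, resolution // 2**(num_resolutions-1), same), and each of the
    # num_resolutions-1 Upsample stages doubles the spatial size, so conv_out emits
    # (B, out_ch, resolution, resolution). For example, with resolution=256 and
    # ch_mult=(1, 2, 4, 8), the latent is 32x32 and the reconstruction is 256x256.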
|
|
|
|
|
class RQBottleneck(nn.Module): |
|
""" |
|
Quantization bottleneck via Residual Quantization. |
|
|
|
    Arguments:
        latent_shape (Tuple[int, int, int]): the shape of latents, denoted (H, W, D)
        code_shape (Tuple[int, int, int]): the shape of codes, denoted (h, w, d)
        n_embed (int, List, or Tuple): the number of embeddings (i.e., the codebook size).
            If n_embed is an int, all codebooks have the same size.
        shared_codebook (bool): If True, a single codebook is shared across all depths. If False,
            a separate codebook is used for each position along the "depth" dimension. (default: False)
        restart_unused_codes (bool): If True, unused codes are randomly re-assigned to feature
            vectors from the current batch during training. (default: True)
    """
|
|
|
def __init__(self, |
|
latent_shape, |
|
code_shape, |
|
n_embed, |
|
decay=0.99, |
|
shared_codebook=False, |
|
restart_unused_codes=True, |
|
commitment_loss='cumsum' |
|
): |
|
super().__init__() |
|
|
|
if not len(code_shape) == len(latent_shape) == 3: |
|
raise ValueError("incompatible code shape or latent shape") |
|
if any([y % x != 0 for x, y in zip(code_shape[:2], latent_shape[:2])]): |
|
raise ValueError("incompatible code shape or latent shape") |
|
|
|
embed_dim = np.prod(latent_shape[:2]) // np.prod(code_shape[:2]) * latent_shape[2] |
|
|
|
self.latent_shape = torch.Size(latent_shape) |
|
self.code_shape = torch.Size(code_shape) |
|
self.shape_divisor = torch.Size([latent_shape[i] // code_shape[i] for i in range(len(latent_shape))]) |
|
|
|
self.shared_codebook = shared_codebook |
|
if self.shared_codebook: |
|
if isinstance(n_embed, Iterable) or isinstance(decay, Iterable): |
|
raise ValueError("Shared codebooks are incompatible \ |
|
with list types of momentums or sizes: Change it into int") |
|
|
|
self.restart_unused_codes = restart_unused_codes |
|
self.n_embed = n_embed if isinstance(n_embed, Iterable) else [n_embed for _ in range(self.code_shape[-1])] |
|
self.decay = decay if isinstance(decay, Iterable) else [decay for _ in range(self.code_shape[-1])] |
|
assert len(self.n_embed) == self.code_shape[-1] |
|
assert len(self.decay) == self.code_shape[-1] |
|
|
|
if self.shared_codebook: |
|
codebook0 = VQEmbedding(self.n_embed[0], |
|
embed_dim, |
|
decay=self.decay[0], |
|
restart_unused_codes=restart_unused_codes, |
|
) |
|
self.codebooks = nn.ModuleList([codebook0 for _ in range(self.code_shape[-1])]) |
|
else: |
|
codebooks = [VQEmbedding(self.n_embed[idx], |
|
embed_dim, |
|
decay=self.decay[idx], |
|
restart_unused_codes=restart_unused_codes, |
|
) for idx in range(self.code_shape[-1])] |
|
self.codebooks = nn.ModuleList(codebooks) |
|
|
|
self.commitment_loss = commitment_loss |
|
|
|
def to_code_shape(self, x): |
|
(B, H, W, D) = x.shape |
|
(rH, rW, _) = self.shape_divisor |
|
|
|
x = x.reshape(B, H//rH, rH, W//rW, rW, D) |
|
x = x.permute(0, 1, 3, 2, 4, 5) |
|
x = x.reshape(B, H//rH, W//rW, -1) |
|
|
|
return x |
|
|
|
def to_latent_shape(self, x): |
|
(B, h, w, _) = x.shape |
|
(_, _, D) = self.latent_shape |
|
(rH, rW, _) = self.shape_divisor |
|
|
|
x = x.reshape(B, h, w, rH, rW, D) |
|
x = x.permute(0, 1, 3, 2, 4, 5) |
|
x = x.reshape(B, h*rH, w*rW, D) |
|
|
|
return x |
|
|
|
def quantize(self, x): |
|
r""" |
|
        Return the list of quantized features and the codewords selected by residual quantization.
        Each code is chosen from the residual between x and the features quantized by the previous codebooks.

        Arguments:
            x (Tensor): bottleneck feature maps to quantize.

        Returns:
            quant_list (list): list of feature maps, sequentially aggregated over the codebooks.
            codes (LongTensor): codeword indices corresponding to the quantized features.
|
|
|
Shape: |
|
- x: (B, h, w, embed_dim) |
|
- quant_list[i]: (B, h, w, embed_dim) |
|
- codes: (B, h, w, d) |
|
""" |
|
B, h, w, embed_dim = x.shape |
|
|
|
residual_feature = x.detach().clone() |
|
|
|
quant_list = [] |
|
code_list = [] |
|
aggregated_quants = torch.zeros_like(x) |
|
for i in range(self.code_shape[-1]): |
|
quant, code = self.codebooks[i](residual_feature) |
|
|
|
residual_feature.sub_(quant) |
|
aggregated_quants.add_(quant) |
|
|
|
quant_list.append(aggregated_quants.clone()) |
|
code_list.append(code.unsqueeze(-1)) |
|
|
|
codes = torch.cat(code_list, dim=-1) |
|
return quant_list, codes |
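
    # Worked example (scalar, illustrative only): with x = 0.93, suppose codebook 1 quantizes the
    # input to 0.9 (code 9) and codebook 2 quantizes the residual 0.03 to 0.03 (code 3). Then
    # quant_list = [0.9, 0.93] (cumulative reconstructions) and codes = [9, 3]; deeper codebooks
    # refine the residual error left by the shallower ones.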
|
|
|
def forward(self, x): |
|
x_reshaped = self.to_code_shape(x) |
|
quant_list, codes = self.quantize(x_reshaped) |
|
|
|
commitment_loss = self.compute_commitment_loss(x_reshaped, quant_list) |
|
        quants_trunc = self.to_latent_shape(quant_list[-1])

        # straight-through estimator: the forward pass uses the quantized values,
        # the backward pass routes the gradient to x unchanged
        quants_trunc = x + (quants_trunc - x).detach()
|
|
|
return quants_trunc, commitment_loss, codes |
|
|
|
def compute_commitment_loss(self, x, quant_list): |
|
r""" |
|
Compute the commitment loss for the residual quantization. |
|
The loss is iteratively computed by aggregating quantized features. |
|
""" |
|
loss_list = [] |
|
|
|
for idx, quant in enumerate(quant_list): |
|
partial_loss = (x-quant.detach()).pow(2.0).mean() |
|
loss_list.append(partial_loss) |
|
|
|
commitment_loss = torch.mean(torch.stack(loss_list)) |
|
return commitment_loss |
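
    # The cumulative commitment loss averages, over quantization depth d, the MSE between the
    # (reshaped) encoder output x and the stop-gradient of each partial reconstruction:
    #
    #     L_commit = (1/d) * sum_{i=1}^{d} || x - sg[ sum_{j<=i} q_j ] ||^2
    #
    # where sg[.] is the stop-gradient and q_j is the quantized residual from the j-th codebook.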
|
|
|
@torch.no_grad() |
|
def embed_code(self, code): |
|
assert code.shape[1:] == self.code_shape |
|
|
|
code_slices = torch.chunk(code, chunks=code.shape[-1], dim=-1) |
|
|
|
if self.shared_codebook: |
|
embeds = [self.codebooks[0].embed(code_slice) for i, code_slice in enumerate(code_slices)] |
|
else: |
|
embeds = [self.codebooks[i].embed(code_slice) for i, code_slice in enumerate(code_slices)] |
|
|
|
embeds = torch.cat(embeds, dim=-2).sum(-2) |
|
embeds = self.to_latent_shape(embeds) |
|
|
|
return embeds |
|
|
|
@torch.no_grad() |
|
def embed_code_with_depth(self, code, to_latent_shape=False): |
|
assert code.shape[-1] == self.code_shape[-1] |
|
|
|
code_slices = torch.chunk(code, chunks=code.shape[-1], dim=-1) |
|
|
|
|
|
if self.shared_codebook: |
|
embeds = [self.codebooks[0].embed(code_slice) for i, code_slice in enumerate(code_slices)] |
|
else: |
|
embeds = [self.codebooks[i].embed(code_slice) for i, code_slice in enumerate(code_slices)] |
|
|
|
if to_latent_shape: |
|
embeds = [self.to_latent_shape(embed.squeeze(-2)).unsqueeze(-2) for embed in embeds] |
|
embeds = torch.cat(embeds, dim=-2) |
|
|
|
return embeds |
|
|
|
|
|
class MultiSelfAttention(nn.Module): |
|
""" |
|
Optimized by batched matmul operations |
|
""" |
|
|
|
def __init__(self, config: AttentionBlockConfig, mask=True): |
|
super().__init__() |
|
assert config.embed_dim % config.n_head == 0 |
|
|
|
self.key = nn.Linear(config.embed_dim, config.embed_dim, bias=config.attn_bias) |
|
self.query = nn.Linear(config.embed_dim, config.embed_dim, bias=config.attn_bias) |
|
self.value = nn.Linear(config.embed_dim, config.embed_dim, bias=config.attn_bias) |
|
|
|
self.attn_drop = nn.Dropout(config.attn_pdrop, inplace=False) |
|
self.resid_drop = nn.Dropout(config.resid_pdrop, inplace=True) |
|
|
|
        self.proj = nn.Linear(config.embed_dim, config.embed_dim, bias=config.attn_bias)
|
|
|
self.n_head = config.n_head |
|
self.mask = mask |
|
|
|
def forward(self, x, caching=False, past_kv=None): |
|
(B, T, C) = x.shape |
|
|
|
if not caching: |
|
assert past_kv is None |
|
|
|
x = x.transpose(0, 1).contiguous() |
|
|
|
k = self.key(x).view(T, B*self.n_head, C//self.n_head).transpose(0, 1) |
|
q = self.query(x).view(T, B*self.n_head, C//self.n_head).transpose(0, 1) |
|
v = self.value(x).view(T, B*self.n_head, C//self.n_head).transpose(0, 1) |
|
|
|
if past_kv is not None: |
|
past_key, past_value = past_kv |
|
k = torch.cat([past_key, k], dim=-2) |
|
v = torch.cat([past_value, v], dim=-2) |
|
T_past = past_key.shape[1] |
|
else: |
|
T_past = 0 |
|
|
|
if caching: |
|
present = torch.stack([k, v]) |
|
else: |
|
present = None |
|
|
|
att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))) |
|
if self.mask: |
|
mask = torch.tril(torch.ones(T_past+T, T_past+T, device=x.device, dtype=torch.bool)) |
|
mask = mask.view(1, T_past+T, T_past+T) |
|
att = att.masked_fill(~mask[:, T_past:T_past+T, :T_past+T], float('-inf')) |
|
att = F.softmax(att, dim=-1) |
|
att = self.attn_drop(att) |
|
|
|
y = torch.bmm(att, v) |
|
y = y.transpose(0, 1).contiguous().view(T, B, C) |
|
|
|
y = self.resid_drop(self.proj(y)) |
|
|
|
if caching: |
|
return y.transpose(0, 1).contiguous(), present |
|
else: |
|
return y.transpose(0, 1).contiguous() |
|
|
|
class AttentionBlock(nn.Module): |
|
""" an unassuming Transformer block """ |
|
|
|
def __init__(self, config: AttentionBlockConfig): |
|
super().__init__() |
|
|
|
self.ln1 = nn.LayerNorm(config.embed_dim) |
|
self.ln2 = nn.LayerNorm(config.embed_dim) |
|
|
|
self.attn = MultiSelfAttention(config, mask=True) |
|
self.mlp = nn.Sequential( |
|
nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=config.mlp_bias), |
|
nn.GELU(), |
|
nn.Linear(4 * config.embed_dim, config.embed_dim, bias=config.mlp_bias), |
|
nn.Dropout(config.resid_pdrop, inplace=True), |
|
) |
|
self._cache = None |
|
|
|
def forward(self, x): |
|
|
|
attn = self.attn(self.ln1(x)) |
|
|
|
x = x + attn |
|
x = x + self.mlp(self.ln2(x)) |
|
|
|
return x |
|
|
|
def cached_forward(self, x_present): |
|
|
|
attn, present = self.attn(self.ln1(x_present), caching=True, past_kv=self._cache['past_kv']) |
|
self._cache['past_kv'] = present |
|
|
|
x_present = x_present + attn |
|
x_present = x_present + self.mlp(self.ln2(x_present)) |
|
|
|
return x_present |
|
|
|
def init_cache(self): |
|
self._cache = {'past_kv': None} |
|
|
|
class AttentionStack(nn.Module): |
|
|
|
blocks: Iterable[AttentionBlock] |
|
|
|
def __init__(self, config: AttentionStackConfig): |
|
super().__init__() |
|
|
|
self.blocks = nn.ModuleList([AttentionBlock(config.block) for _ in range(config.n_layer)]) |
|
|
|
def forward(self, x): |
|
for block in self.blocks: |
|
x = block(x) |
|
|
|
return x |
|
|
|
def cached_forward(self, x_present): |
|
for block in self.blocks: |
|
x_present = block.cached_forward(x_present) |
|
|
|
return x_present |
|
|
|
def init_cache(self): |
|
for block in self.blocks: |
|
block.init_cache() |
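
# Hedged usage sketch for incremental decoding with the KV cache (stack_config is a hypothetical
# AttentionStackConfig instance; this snippet is not executed by the module):
#
#   stack = AttentionStack(stack_config)
#   stack.init_cache()                           # must precede cached_forward
#   y0 = stack.cached_forward(x[:, :1, :])       # first step: keys/values are cached
#   y1 = stack.cached_forward(x[:, 1:2, :])      # next step attends to the cached past
#
# forward() ignores the cache and always recomputes full causal attention over its input.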
|
|
|
|
|
class RMSNorm(nn.Module): |
|
def __init__(self, hidden_size, eps=1e-6): |
|
""" |
|
RMSNorm is equivalent to T5LayerNorm |
|
""" |
|
super().__init__() |
|
self.weight = nn.Parameter(torch.ones(hidden_size)) |
|
self.variance_epsilon = eps |
|
|
|
def forward(self, hidden_states): |
|
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) |
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) |
|
|
|
|
|
if self.weight.dtype in [torch.float16, torch.bfloat16]: |
|
hidden_states = hidden_states.to(self.weight.dtype) |
|
|
|
return self.weight * hidden_states |
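
    # RMSNorm:  y = w * x / sqrt(mean(x^2) + eps), with the mean-square term computed in float32
    # and the result cast back to the weight dtype when the weights are fp16/bf16.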
|
|
|
|
|
class CasualDepthTransformerLayer(nn.Module): |
|
def __init__(self, config, depth): |
|
super().__init__() |
|
self.config = config |
|
embed_size = config.embed_dim |
|
num_heads = embed_size // 128 |
|
self.self_attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads,batch_first=True) |
|
self.layernorm1 = RMSNorm(embed_size) |
|
self.layernorm2 = RMSNorm(embed_size) |
|
self.linear1 = nn.Linear(embed_size * depth, 2 * embed_size) |
|
self.linear2 = nn.Linear(2 * embed_size * depth, embed_size) |
|
|
|
def forward(self, x): |
|
|
|
seq_len = x.size(1) |
|
|
|
|
|
res = x |
|
x = self.layernorm1(x) |
|
        src_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1)
        _x, _ = self.self_attention(x, x, x, is_causal=True, attn_mask=src_mask)
|
res = _x + res |
|
res = self.layernorm2(res) |
|
x = torch.einsum('bld,tld->blt', res, torch.reshape(self.linear1.weight, (2 * self.config.embed_dim, -1, self.config.embed_dim))) |
|
x = torch.nn.functional.gelu(x) |
|
x = torch.einsum('blt,dlt->bld', x, torch.reshape(self.linear2.weight, (self.config.embed_dim, -1, 2 * self.config.embed_dim))) |
|
return res + x |
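
    # The two einsums implement a per-depth-position MLP: linear1.weight (2E, depth*E) is viewed
    # as (2E, depth, E) so each of the `depth` positions gets its own E -> 2E projection, and
    # linear2.weight (E, 2E*depth) viewed as (E, depth, 2E) maps each position back 2E -> E.
    # Equivalent (illustrative) formulation with an explicit loop over depth positions l:
    #
    #   out[:, l, :] = gelu(res[:, l, :] @ W1[:, l, :].T) @ W2[:, l, :].T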
|
|
|
|
|
class RQTransformer(PreTrainedModel): |
|
config_class = RQTransformerConfig |
|
def __init__(self, config: RQTransformerConfig): |
|
super().__init__(config) |
|
self.in_mlp_1 = nn.Linear(config.input_embed_dim_1, config.embed_dim) |
|
|
|
self.head_transformer = nn.ModuleList([ |
|
CasualDepthTransformerLayer(config, config.block_size[-1]) |
|
for _ in range(3) |
|
]) |
|
self.headnorm = RMSNorm(config.embed_dim) |
|
self.heads = nn.ModuleList([ |
|
nn.Linear(config.embed_dim, config.vocab_size) |
|
for i in range(config.block_size[-1]) |
|
]) |
|
self.gradient_checkpointing = True |
|
|
|
|
|
def embed_with_model_aux(self, code, model_aux, mode=None): |
|
|
|
xs_emb = model_aux.get_code_emb_with_depth(code, mode=mode) |
|
return xs_emb |
|
|
|
def forward(self, embed_from_body, code, model_aux=None, mode=None): |
|
B, seq_len, D = code.shape |
|
|
|
depth_ctx = self.embed_with_model_aux(code, model_aux, mode=mode) |
|
depth_ctx = torch.cumsum(depth_ctx, dim=-2) |
|
depth_ctx = self.in_mlp_1(depth_ctx) |
|
|
|
depth_ctx_full = torch.cat( |
|
[ |
|
embed_from_body.view(B, seq_len, 1, -1), |
|
depth_ctx[:, :, :-1, :], |
|
], |
|
dim=-2, |
|
) |
|
|
|
depth_ctx_full = depth_ctx_full.reshape(B * seq_len, D, -1) |
|
|
|
for i, tlayer in enumerate(self.head_transformer): |
|
if self.gradient_checkpointing and self.training: |
|
|
|
def create_custom_forward(module): |
|
def custom_forward(*inputs): |
|
|
|
return module(*inputs) |
|
|
|
return custom_forward |
|
|
|
depth_ctx_full = torch.utils.checkpoint.checkpoint( |
|
create_custom_forward(tlayer), depth_ctx_full, |
|
) |
|
else: |
|
depth_ctx_full = tlayer( |
|
depth_ctx_full, |
|
) |
|
|
|
depth_ctx_full = self.headnorm(depth_ctx_full) |
|
logits = [head(depth_ctx_full[:, i]) for i, head in enumerate(self.heads)] |
|
|
|
return logits |
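
    # Forward-pass layout: for each spatial position, the depth transformer sees
    # [embed_from_body, cumsum(code embeddings)[:-1]] along the depth axis, i.e. shape
    # (B*seq_len, D, embed_dim) after the reshape, so head i predicts the code at depth i from the
    # body embedding plus the already-chosen shallower codes (teacher forcing during training).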
|
|
|
def generate(self, embed_from_body, model_aux=None, cfg=3.0, mode=None): |
|
        top_k = 900
        top_p = 0.96
|
|
|
|
|
B, seq_len, _ = embed_from_body.shape |
|
        # generate the 4 residual-quantization depths for the next position, one depth at a time
        next_token_ids = torch.zeros(B, 1, 4, dtype=torch.long, device=embed_from_body.device)
        for i in range(4):
|
logits = self(embed_from_body, next_token_ids, model_aux, mode=mode) |
|
next_token_logits = logits[i].clone() |
|
|
|
next_token_logits = next_token_logits[B//2:, :] + cfg * (next_token_logits[:B//2, :] - next_token_logits[B//2:, :]) |
|
next_tokens = sample_from_logits(next_token_logits, temperature=1.0, top_p=top_p, top_k=top_k) |
|
next_tokens = next_tokens.reshape(B//2, seq_len).repeat(2, 1) |
|
next_token_ids[:, :, i] = next_tokens |
|
|
|
out_features = self.embed_with_model_aux(next_token_ids, model_aux, mode=mode) |
|
out_features = torch.cumsum(out_features, dim=-2)[:, :, -1, :] |
|
|
|
|
|
return out_features, next_token_ids |
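
    # generate() apparently assumes a classifier-free-guidance batch layout with the first B//2
    # rows conditional and the last B//2 unconditional: guided = uncond + cfg * (cond - uncond),
    # sampled once per depth, and the sampled codes are copied to both halves before the next depth.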
|
|
|
|
|
def build_projector(dim_in, dim_out, projector_type='mlp2x_gelu'): |
|
if projector_type == 'linear': |
|
linear = nn.Linear(dim_in, dim_out) |
|
linear.reset_parameters() |
|
return linear |
|
elif projector_type == 'nonlinear': |
|
linear = nn.Linear(dim_in, dim_out) |
|
linear.reset_parameters() |
|
modules = [linear, nn.GELU()] |
|
return nn.Sequential(*modules) |
|
elif projector_type == 'conv': |
|
return nn.Conv2d(dim_in, dim_out, 1) |
|
else: |
|
linear_1 = nn.Linear(dim_in, dim_in) |
|
linear_1.reset_parameters() |
|
modules = [linear_1] |
|
modules.append(nn.GELU()) |
|
linear_2 = nn.Linear(dim_in, dim_out) |
|
linear_2.reset_parameters() |
|
modules.append(linear_2) |
|
|
|
return nn.Sequential(*modules) |
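
# Note: the default projector_type, 'mlp2x_gelu', is handled by the final else branch and yields
# Linear(dim_in, dim_in) -> GELU -> Linear(dim_in, dim_out). Illustrative calls (hypothetical
# dimensions, not executed here):
#
#   proj = build_projector(1152, 4096)                      # 2-layer GELU MLP
#   proj_lin = build_projector(1152, 4096, 'linear')        # single linear layer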
|
|
|
|
|
class RQVAESiglipModel(PreTrainedModel): |
|
config_class = RQVAESiglipConfig |
|
def __init__(self, config: RQVAESiglipConfig): |
|
super().__init__(config) |
|
|
|
self.config = config |
|
|
|
|
|
siglip_config = SiglipModel.config_class.from_pretrained(config.pretrained_model) |
|
self.siglip_model = SiglipModel._from_config(siglip_config) |
|
|
|
|
|
self.prequant_visual = ProjectResnetBlock(in_channels=config.hidden_size, |
|
out_channels=config.embed_dim, |
|
dropout=0.0) |
|
self.prequant_visual_1 = ProjectResnetBlock(in_channels=config.hidden_size, |
|
out_channels=config.embed_dim, |
|
dropout=0.0) |
|
self.layer_norm_visual = nn.LayerNorm(config.embed_dim) |
|
self.layer_norm_semantic = nn.LayerNorm(config.embed_dim) |
|
|
|
self.quantizer_semantic = RQBottleneck( |
|
latent_shape=config.latent_shape, |
|
code_shape=config.code_shape_semantic, |
|
n_embed=config.n_embed_semantic, |
|
decay=config.decay, |
|
shared_codebook=config.shared_codebook, |
|
restart_unused_codes=config.restart_unused_codes, |
|
) |
|
|
|
self.quantizer = RQBottleneck( |
|
latent_shape=config.latent_shape, |
|
code_shape=config.code_shape_visual, |
|
n_embed=config.n_embed_visual, |
|
decay=config.decay, |
|
shared_codebook=config.shared_codebook, |
|
restart_unused_codes=config.restart_unused_codes, |
|
) |
|
self.postquant_semantic = build_projector(config.embed_dim, config.hidden_size, projector_type='nonlinear') |
|
self.postquant_visual = ProjectResnetBlock(in_channels=config.embed_dim, |
|
out_channels=config.hidden_size, |
|
dropout=0.0) |
|
|
|
self.post_quant_conv = PostQuantResnetBlock(in_channels=config.hidden_size, |
|
out_channels=config.ddconfig["decoder_in_channels"], |
|
dropout=0.0) |
|
self.decoder = Decoder(**config.ddconfig) |
|
|
|
        self.decoder_latent_shape = getattr(config, "decoder_latent_shape", None)
|
|
|
|
|
def encode_text(self, text): |
|
|
|
        output_attentions = self.config.output_attentions
        output_hidden_states = self.config.output_hidden_states
        return_dict = self.config.use_return_dict
|
|
|
text_model = self.siglip_model.text_model |
|
text_outputs = text_model( |
|
input_ids=text, |
|
attention_mask=None, |
|
position_ids=None, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
text_embeds = text_outputs[1] |
|
return text_embeds |
|
|
|
def encode_image(self, image): |
|
vision_model = self.siglip_model.vision_model |
|
hidden_states = vision_model.embeddings(image) |
|
|
|
attention_mask = None |
|
output_attentions = None |
|
|
|
visual_n, semantic_n = self.config.last_n_layer_recon, self.config.last_n_layer_sem |
|
for i, encoder_layer in enumerate(vision_model.encoder.layers): |
|
if vision_model.encoder.gradient_checkpointing and vision_model.encoder.training: |
|
layer_outputs = vision_model.encoder._gradient_checkpointing_func( |
|
encoder_layer.__call__, |
|
hidden_states, |
|
attention_mask, |
|
output_attentions, |
|
) |
|
else: |
|
layer_outputs = encoder_layer( |
|
hidden_states, |
|
attention_mask, |
|
output_attentions=output_attentions, |
|
) |
|
hidden_states = layer_outputs[0] |
|
|
|
if i == len(vision_model.encoder.layers) - (visual_n+2): |
|
B, L, C = hidden_states.shape |
|
hidden_states_visual_1 = hidden_states.reshape(B, int(L**0.5), int(L**0.5), -1) |
|
|
|
hidden_states_visual_1 = hidden_states_visual_1.permute(0, 3, 1, 2).contiguous() |
|
hidden_states_visual_1 = self.prequant_visual_1(hidden_states_visual_1) |
|
hidden_states_visual_1 = hidden_states_visual_1.permute(0, 2, 3, 1).contiguous() |
|
|
|
if i == len(vision_model.encoder.layers) - visual_n: |
|
B, L, C = hidden_states.shape |
|
hidden_states_visual = hidden_states.reshape(B, int(L**0.5), int(L**0.5), -1) |
|
|
|
hidden_states_visual = hidden_states_visual.permute(0, 3, 1, 2).contiguous() |
|
hidden_states_visual = self.prequant_visual(hidden_states_visual) |
|
hidden_states_visual = hidden_states_visual.permute(0, 2, 3, 1).contiguous() |
|
|
|
                hidden_states_visual += 0.6 * hidden_states_visual_1  # blend in the shallower layer's features (fixed 0.6 weight)
|
|
|
hidden_states_visual = self.layer_norm_visual(hidden_states_visual) |
|
z_q_visual, quant_loss_visual, code_visual = self.quantizer(hidden_states_visual) |
|
|
|
if i == len(vision_model.encoder.layers) - semantic_n: |
|
hidden_state_26 = hidden_states |
|
B, L, C = hidden_states.shape |
|
|
|
hidden_states_semantic = hidden_states.reshape(B, int(L**0.5), int(L**0.5), -1) |
|
hidden_states_semantic = self.layer_norm_semantic(hidden_states_semantic) |
|
z_q_semantic, quant_loss_semantic, code_semantic = self.quantizer_semantic(hidden_states_semantic) |
|
|
|
return z_q_visual, code_visual, z_q_semantic, code_semantic |
|
|
|
def decode(self, z_q): |
|
z_q = z_q.permute(0, 3, 1, 2).contiguous() |
|
z_q = self.postquant_visual(z_q) |
|
|
|
if self.decoder_latent_shape is not None: |
|
z_q = F.interpolate(z_q.to(torch.float32), size=tuple(self.decoder_latent_shape), mode='bilinear').to(torch.bfloat16) |
|
z_q = self.post_quant_conv(z_q) |
|
out = self.decoder(z_q) |
|
return out |
|
|
|
@torch.no_grad() |
|
def get_code_emb_with_depth(self, code, mode=None): |
|
|
|
if mode == "visual": |
|
visual_embedding = self.quantizer.embed_code_with_depth(code) |
|
return visual_embedding |
|
elif mode == "semantic": |
|
semantic_embedding = self.quantizer_semantic.embed_code_with_depth(code) |
|
return semantic_embedding |
|
|
|
|
|
class RQVAESIGLIPTransformer(PreTrainedModel): |
|
config_class = RQVAESIGLIPTransformerConfig |
|
def __init__(self, config: RQVAESIGLIPTransformerConfig): |
|
super().__init__(config) |
|
|
|
rqvaesiglip_config = RQVAESiglipModel.config_class.from_dict(config.rqvaesiglip) |
|
rqtransformer_visual_config = RQTransformer.config_class.from_dict(config.rqtransformer_visual) |
|
rqtransformer_semantic_config = RQTransformer.config_class.from_dict(config.rqtransformer_semantic) |
|
|
|
self.rqvaesiglip = RQVAESiglipModel._from_config(rqvaesiglip_config) |
|
self.rqtransformer_visual = RQTransformer._from_config(rqtransformer_visual_config) |
|
self.rqtransformer_semantic = RQTransformer._from_config(rqtransformer_semantic_config) |
|
|
|
|
|
class RQVAESIGLIPTransformerVisionTower(nn.Module): |
|
def __init__(self, model_name_or_path): |
|
super().__init__() |
|
model_dtype = torch.bfloat16 |
|
|
|
self.config = RQVAESIGLIPTransformerConfig.from_pretrained(model_name_or_path) |
|
self.vision_tower = RQVAESIGLIPTransformer._from_config(self.config, torch_dtype=model_dtype) |
|
self.is_loaded = True |
|
|
|
encoder_path = self.config.rqvaesiglip["pretrained_model"] |
|
if "siglip-so400m-patch14-384" in encoder_path: |
|
self.image_processor = CLIPImageProcessor( |
|
size={"height": 384, "width": 384}, |
|
crop_size={"height": 384, "width": 384}, |
|
image_mean=[0.5, 0.5, 0.5], |
|
image_std=[0.5, 0.5, 0.5] |
|
) |
|
self.image_tokens = 729 |
|
|
|
elif "siglip-large-patch16-256" in encoder_path: |
|
self.image_processor = CLIPImageProcessor( |
|
size={"height": 256, "width": 256}, |
|
crop_size={"height": 256, "width": 256}, |
|
image_mean=[0.5, 0.5, 0.5], |
|
image_std=[0.5, 0.5, 0.5] |
|
) |
|
self.image_tokens = 256 |
|
|
|
else: |
|
raise NotImplementedError() |
|
|
|
def encode(self, images: torch.Tensor): |
|
vision_output = self.vision_tower.rqvaesiglip.encode_image(images) |
|
|
|
image_features_visual, tokens_visual = vision_output[0], vision_output[1] |
|
image_features_semantic, tokens_semantic = vision_output[2], vision_output[3] |
|
|
|
bs, patch_size, _, dim = image_features_visual.shape |
|
image_features_visual = torch.reshape(image_features_visual, [bs, patch_size**2, dim]) |
|
tokens_visual = torch.reshape(tokens_visual, [bs, patch_size**2, -1]) |
|
|
|
bs, patch_size, _, dim = image_features_semantic.shape |
|
image_features_semantic = torch.reshape(image_features_semantic, [bs, patch_size**2, dim]) |
|
tokens_semantic = torch.reshape(tokens_semantic, [bs, patch_size**2, -1]) |
|
|
|
return image_features_visual, image_features_semantic, tokens_visual, tokens_semantic |
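
    # encode() flattens each (B, p, p, D) quantized feature map into a (B, p*p, D) token sequence;
    # e.g. with the siglip-so400m-patch14-384 encoder, p = 27 so p*p = 729 matches
    # self.image_tokens, while siglip-large-patch16-256 gives p = 16 and 256 tokens.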
|
|
|
|
|
class BaichuanVisualEncoder(RQVAESIGLIPTransformerVisionTower): |
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.vision_tower.rqvaesiglip.siglip_model.vision_model.gradient_checkpointing = True |
|
self.vision_tower.rqvaesiglip.siglip_model.vision_model._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint |
|
|
|
def forward( |
|
self, |
|
pixel_values |
|
): |
|
pixel_values = pixel_values.to(self.vision_tower.rqvaesiglip.siglip_model.vision_model.embeddings.patch_embedding.weight.dtype) |
|
return self.encode(pixel_values) |
|
|
|
@torch.no_grad() |
|
def fake_input(self, input_ids, merge_size): |
|
fake_image = [torch.zeros([3, self.config.image_size, self.config.image_size], dtype=torch.float32, device=input_ids.device)] |
|
return fake_image |
|
|
|
|
if __name__ == '__main__': |
|
fire.Fire() |
|
|