from typing import List, Optional, Tuple, Union, Iterable, Any
import torch, math
import torch.utils.checkpoint
from torch import nn
from collections import OrderedDict
from flash_attn import flash_attn_varlen_func
from transformers.activations import ACT2FN
import io, fire
from torch.nn import functional as F
from transformers import PreTrainedModel, CLIPImageProcessor, PretrainedConfig
from .configuration_baichuan import RQVAESIGLIPTransformerConfig, RQTransformerConfig, RQVAESiglipConfig, AttentionStackConfig, AttentionBlockConfig, SiglipConfig, SiglipTextConfig, SiglipVisionConfig
from torch.utils.checkpoint import checkpoint
import warnings
from dataclasses import dataclass
from torch.nn.init import _calculate_fan_in_and_fan_out
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
import torch.distributed as dist
import numpy as np
def top_k_logits(logits, k):
v, ix = torch.topk(logits, k)
out = logits.clone()
out[out < v[:, [-1]]] = -float('Inf')
return out
def top_p_probs(probs, p):
sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
cum_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_idx_remove_cond = cum_probs >= p
sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone()
sorted_idx_remove_cond[..., 0] = 0
indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond)
probs = probs.masked_fill(indices_to_remove, 0.0)
norm_probs = probs / torch.sum(probs, dim=-1, keepdim=True)
return norm_probs
def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None):
"""Take a 2-dim tensor, apply softmax along each row, and sample from
each multinomial distribution defined by the rows.
Args:
logits: 2-dim tensor of shape (n_samples, logit_dim)
temperature (float): softmax temperature
top_k (Optional[int]): if given, sample only using `top_k` logits
top_p (Optional[float]): if given, sample only using `top_p` logits
Returns:
samples: 1-dim integer tensor of shape (n_samples,)
"""
logits = logits.to(dtype=torch.float32)
logits = logits / temperature
if top_k is not None:
logits = top_k_logits(logits, top_k)
if torch.isnan(logits).any():
print('WARNING... NaN observed')
logits[torch.isnan(logits)] = -float('Inf')
probs = F.softmax(logits, dim=-1)
if top_p is not None:
probs = top_p_probs(probs, top_p)
samples = torch.multinomial(probs, num_samples=1)
return samples.view(-1)
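# Hedged usage sketch (not part of the original model code): the three helpers above implement
# standard top-k / nucleus (top-p) filtering followed by multinomial sampling. The batch size and
# vocabulary size below are made up purely for illustration, and the function is never called here.
def _demo_sample_from_logits():
    logits = torch.randn(4, 10)  # 4 rows of logits over a 10-way vocabulary
    # keep the 5 largest logits per row, then restrict to the smallest set covering 90% probability mass
    samples = sample_from_logits(logits, temperature=0.8, top_k=5, top_p=0.9)
    assert samples.shape == (4,)  # one sampled index per row
    return samples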
def nonlinearity(x):
return F.silu(x, inplace=True)
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
""" PyTorch Siglip model."""
logger = logging.get_logger(__name__)
# _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
# SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
# "google/siglip-base-patch16-224",
# # See all SigLIP models at https://huggingface.co/models?filter=siglip
# ]
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
def trunc_normal_tf_(
tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
) -> torch.Tensor:
"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \\leq \\text{mean} \\leq b`.
NOTE: this 'tf' variant behaves closer to the TensorFlow / JAX implementation, where the
bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0,
and the result is subsequently scaled and shifted by the mean and std args.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
"""
with torch.no_grad():
_trunc_normal_(tensor, 0, 1.0, a, b)
tensor.mul_(std).add_(mean)
return tensor
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == "fan_in":
denom = fan_in
elif mode == "fan_out":
denom = fan_out
elif mode == "fan_avg":
denom = (fan_in + fan_out) / 2
variance = scale / denom
if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
elif distribution == "normal":
with torch.no_grad():
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
with torch.no_grad():
tensor.uniform_(-bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
def default_flax_embed_init(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="normal")
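# Hedged usage sketch (illustrative only): how the JAX/Flax-style initializers above might be applied
# to freshly created layers. The layer sizes are arbitrary and nothing in this file calls this function.
def _demo_variance_scaling_init():
    linear = nn.Linear(128, 64)
    lecun_normal_(linear.weight)        # truncated normal with variance 1 / fan_in
    nn.init.zeros_(linear.bias)
    embedding = nn.Embedding(1000, 64)
    default_flax_embed_init(embedding.weight)  # plain normal with variance 1 / fan_in
    return linear, embedding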
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
class SiglipVisionModelOutput(ModelOutput):
"""
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
Args:
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
The image embeddings obtained by applying the projection layer to the pooler_output.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
image_embeds: Optional[torch.FloatTensor] = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
class SiglipTextModelOutput(ModelOutput):
"""
Base class for text model's outputs that also contains a pooling of the last hidden states.
Args:
text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
The text embeddings obtained by applying the projection layer to the pooler_output.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
text_embeds: Optional[torch.FloatTensor] = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
class SiglipOutput(ModelOutput):
"""
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
Contrastive loss for image-text similarity.
logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
similarity scores.
logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
similarity scores.
text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
text_model_output(`BaseModelOutputWithPooling`):
The output of the [`SiglipTextModel`].
vision_model_output(`BaseModelOutputWithPooling`):
The output of the [`SiglipVisionModel`].
"""
loss: Optional[torch.FloatTensor] = None
logits_per_image: torch.FloatTensor = None
logits_per_text: torch.FloatTensor = None
text_embeds: torch.FloatTensor = None
image_embeds: torch.FloatTensor = None
text_model_output: BaseModelOutputWithPooling = None
vision_model_output: BaseModelOutputWithPooling = None
def to_tuple(self) -> Tuple[Any]:
return tuple(
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
for k in self.keys()
)
class SiglipVisionEmbeddings(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) # 256, 1024
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
pixel_values = pixel_values.to(device=self.patch_embedding.weight.device, dtype=self.patch_embedding.weight.dtype)
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] 3, 1024, 32, 32
embeddings = patch_embeds.flatten(2).transpose(1, 2) # 3, 1024, 1024
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
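# Hedged shape sketch (illustrative only): SiglipVisionEmbeddings splits the image into non-overlapping
# patches, projects each patch with a strided Conv2d, and adds learned position embeddings. The ad-hoc
# SimpleNamespace below stands in for a real SiglipVisionConfig, with made-up sizes.
def _demo_siglip_vision_embeddings():
    from types import SimpleNamespace
    cfg = SimpleNamespace(hidden_size=32, image_size=64, patch_size=16, num_channels=3)
    embeddings = SiglipVisionEmbeddings(cfg)
    pixel_values = torch.randn(2, 3, 64, 64)
    out = embeddings(pixel_values)
    assert out.shape == (2, 16, 32)  # (64 // 16) ** 2 = 16 patches, each of width hidden_size
    return out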
# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
class SiglipTextEmbeddings(nn.Module):
def __init__(self, config: SiglipTextConfig):
super().__init__()
embed_dim = config.hidden_size
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
class SiglipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
# Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
k_v_seq_len = key_states.shape[-2]
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
raise ValueError(
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
raise ValueError(
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
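# Hedged shape sketch (illustrative only): eager multi-head self-attention as implemented above.
# SimpleNamespace stands in for a real Siglip config; the sizes are arbitrary.
def _demo_siglip_attention():
    from types import SimpleNamespace
    cfg = SimpleNamespace(hidden_size=32, num_attention_heads=4, attention_dropout=0.0)
    attn = SiglipAttention(cfg)
    hidden_states = torch.randn(2, 7, 32)  # (batch, seq_len, hidden_size)
    attn_output, attn_weights = attn(hidden_states)
    assert attn_output.shape == (2, 7, 32)
    assert attn_weights.shape == (2, 4, 7, 7)  # (batch, heads, seq_len, seq_len)
    return attn_output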
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
class SiglipMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
class SiglipEncoderLayer(nn.Module):
def __init__(self, config: SiglipConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = SiglipAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
# Ignore copy
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(batch, seq_len, embed_dim)`.
attention_mask (`torch.FloatTensor`):
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class SiglipPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = SiglipConfig
base_model_prefix = "siglip"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, SiglipVisionEmbeddings):
width = (
self.config.vision_config.hidden_size
if isinstance(self.config, SiglipConfig)
else self.config.hidden_size
)
nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
elif isinstance(module, nn.Embedding):
default_flax_embed_init(module.weight)
elif isinstance(module, SiglipAttention):
nn.init.xavier_uniform_(module.q_proj.weight)
nn.init.xavier_uniform_(module.k_proj.weight)
nn.init.xavier_uniform_(module.v_proj.weight)
nn.init.xavier_uniform_(module.out_proj.weight)
nn.init.zeros_(module.q_proj.bias)
nn.init.zeros_(module.k_proj.bias)
nn.init.zeros_(module.v_proj.bias)
nn.init.zeros_(module.out_proj.bias)
elif isinstance(module, SiglipMLP):
nn.init.xavier_uniform_(module.fc1.weight)
nn.init.xavier_uniform_(module.fc2.weight)
nn.init.normal_(module.fc1.bias, std=1e-6)
nn.init.normal_(module.fc2.bias, std=1e-6)
elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
nn.init.xavier_uniform_(module.probe.data)
nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
nn.init.zeros_(module.attention.in_proj_bias.data)
elif isinstance(module, SiglipModel):
logit_scale_init = torch.log(torch.tensor(1.0))
module.logit_scale.data.fill_(logit_scale_init)
module.logit_bias.data.zero_()
elif isinstance(module, (nn.Linear, nn.Conv2d)):
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
SIGLIP_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
SIGLIP_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
SIGLIP_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
return_loss (`bool`, *optional*):
Whether or not to return the contrastive loss.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
class SiglipEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`SiglipEncoderLayer`].
Args:
config: SiglipConfig
"""
def __init__(self, config: SiglipConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
# Ignore copy
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
class SiglipTextTransformer(nn.Module):
def __init__(self, config: SiglipTextConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipTextEmbeddings(config)
self.encoder = SiglipEncoder(config)
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.head = nn.Linear(embed_dim, embed_dim)
@add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is None:
raise ValueError("You have to specify input_ids")
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
# note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
# expand attention_mask
# if attention_mask is not None:
# # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
# attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.final_layer_norm(last_hidden_state)
# Assuming "sticky" EOS tokenization, last token is always EOS.
pooled_output = last_hidden_state[:, -1, :]
pooled_output = self.head(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@add_start_docstrings(
"""The text model from SigLIP without any head or projection on top.""",
SIGLIP_START_DOCSTRING,
)
class SiglipTextModel(SiglipPreTrainedModel):
config_class = SiglipTextConfig
_no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"]
def __init__(self, config: SiglipTextConfig):
super().__init__(config)
self.text_model = SiglipTextTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.text_model.embeddings.token_embedding
def set_input_embeddings(self, value):
self.text_model.embeddings.token_embedding = value
@add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
Examples:
```python
>>> from transformers import AutoTokenizer, SiglipTextModel
>>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
>>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
>>> # important: make sure to set padding="max_length" as that's how the model was trained
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
class SiglipVisionTransformer(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipVisionEmbeddings(config)
self.encoder = SiglipEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.head = SiglipMultiheadAttentionPoolingHead(config)
@add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
def forward(
self,
pixel_values,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = self.embeddings(pixel_values)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.post_layernorm(last_hidden_state)
pooled_output = self.head(last_hidden_state)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class SiglipMultiheadAttentionPoolingHead(nn.Module):
"""Multihead Attention Pooling."""
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
def forward(self, hidden_state):
batch_size = hidden_state.shape[0]
probe = self.probe.repeat(batch_size, 1, 1)
hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
residual = hidden_state
hidden_state = self.layernorm(hidden_state)
hidden_state = residual + self.mlp(hidden_state)
return hidden_state[:, 0]
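# Hedged shape sketch (illustrative only): the pooling head attends a single learned probe token over
# all patch tokens and returns one pooled vector per image. The ad-hoc config below is hypothetical;
# the activation name is chosen arbitrarily for the sketch.
def _demo_attention_pooling_head():
    from types import SimpleNamespace
    cfg = SimpleNamespace(hidden_size=32, num_attention_heads=4, layer_norm_eps=1e-6,
                          hidden_act="gelu", intermediate_size=64)
    head = SiglipMultiheadAttentionPoolingHead(cfg)
    patch_tokens = torch.randn(2, 16, 32)  # (batch, num_patches, hidden_size)
    pooled = head(patch_tokens)
    assert pooled.shape == (2, 32)
    return pooled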
@add_start_docstrings(
"""The vision model from SigLIP without any head or projection on top.""",
SIGLIP_START_DOCSTRING,
)
class SiglipVisionModel(SiglipPreTrainedModel):
config_class = SiglipVisionConfig
main_input_name = "pixel_values"
def __init__(self, config: SiglipVisionConfig):
super().__init__(config)
self.vision_model = SiglipVisionTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
@add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
def forward(
self,
pixel_values,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, SiglipVisionModel
>>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
>>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled features
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
@add_start_docstrings(SIGLIP_START_DOCSTRING)
class SiglipModel(SiglipPreTrainedModel):
config_class = SiglipConfig
def __init__(self, config: SiglipConfig):
super().__init__(config)
if not isinstance(config.text_config, SiglipTextConfig):
raise ValueError(
"config.text_config is expected to be of type SiglipTextConfig but is of type"
f" {type(config.text_config)}."
)
if not isinstance(config.vision_config, SiglipVisionConfig):
raise ValueError(
"config.vision_config is expected to be of type SiglipVisionConfig but is of type"
f" {type(config.vision_config)}."
)
text_config = config.text_config
vision_config = config.vision_config
self.text_model = SiglipTextTransformer(text_config)
self.vision_model = SiglipVisionTransformer(vision_config)
self.logit_scale = nn.Parameter(torch.randn(1))
self.logit_bias = nn.Parameter(torch.randn(1))
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
def get_text_features(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> torch.FloatTensor:
r"""
Returns:
text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
applying the projection layer to the pooled output of [`SiglipTextModel`].
Examples:
```python
>>> from transformers import AutoTokenizer, AutoModel
>>> import torch
>>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
>>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
>>> # important: make sure to set padding="max_length" as that's how the model was trained
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
>>> with torch.no_grad():
... text_features = model.get_text_features(**inputs)
```"""
# Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = text_outputs[1]
return pooled_output
@add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
def get_image_features(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> torch.FloatTensor:
r"""
Returns:
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
applying the projection layer to the pooled output of [`SiglipVisionModel`].
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, AutoModel
>>> import torch
>>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
>>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> with torch.no_grad():
... image_features = model.get_image_features(**inputs)
```"""
# Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = vision_outputs[1]
return pooled_output
@add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
return_loss: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SiglipOutput]:
r"""
Returns:
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, AutoModel
>>> import torch
>>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
>>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
>>> inputs = processor(text=texts, images=image, return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> logits_per_image = outputs.logits_per_image
>>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
>>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
31.9% that image 0 is 'a photo of 2 cats'
```"""
# Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
image_embeds = vision_outputs[1]
text_embeds = text_outputs[1]
# normalized features
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
# cosine similarity as logits
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias
logits_per_image = logits_per_text.t()
loss = None
if return_loss:
raise NotImplementedError("SigLIP loss to be implemented")
if not return_dict:
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
return ((loss,) + output) if loss is not None else output
return SiglipOutput(
loss=loss,
logits_per_image=logits_per_image,
logits_per_text=logits_per_text,
text_embeds=text_embeds,
image_embeds=image_embeds,
text_model_output=text_outputs,
vision_model_output=vision_outputs,
)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
# interpolate in float32 (nearest-neighbor), then cast back to bfloat16, the dtype the rest of the decoder is assumed to run in
x = torch.nn.functional.interpolate(x.to(torch.float32), scale_factor=2.0, mode="nearest").to(torch.bfloat16)
if self.with_conv:
x = self.conv(x)
return x
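# Hedged shape sketch (illustrative only): Upsample doubles the spatial resolution with nearest-neighbor
# interpolation computed in float32 and cast back to bfloat16. The convolution is disabled here so the
# sketch stays dtype-agnostic with respect to the module's weights.
def _demo_upsample():
    up = Upsample(in_channels=8, with_conv=False)
    x = torch.randn(1, 8, 4, 4, dtype=torch.bfloat16)
    y = up(x)
    assert y.shape == (1, 8, 8, 8) and y.dtype == torch.bfloat16
    return y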
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout, temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.checkpointing = False
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels,
out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = torch.nn.Conv2d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
def _forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x+h
def forward(self, x, temb):
if self.checkpointing and self.training:
out = checkpoint(self._forward, x, temb)
else:
out = self._forward(x, temb)
return out
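# Hedged shape sketch (illustrative only): ResnetBlock is the usual GroupNorm -> SiLU -> Conv pattern
# with an optional timestep embedding and a 1x1 shortcut when the channel count changes. Channel counts
# below are arbitrary but must be divisible by the 32 groups used in Normalize.
def _demo_resnet_block():
    block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.0, temb_channels=0)
    x = torch.randn(2, 64, 8, 8)
    out = block(x, temb=None)
    assert out.shape == (2, 128, 8, 8)
    return out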
class PostQuantResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.checkpointing = False
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = torch.nn.Conv2d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
def _forward(self, x):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x+h
def forward(self, x):
if self.checkpointing and self.training:
out = checkpoint(self._forward, x)
else:
out = self._forward(x)
return out
class ProjectResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.checkpointing = False
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = torch.nn.Conv2d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
def _forward(self, x):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x+h
def forward(self, x):
if self.checkpointing and self.training:
out = checkpoint(self._forward, x)
else:
out = self._forward(x)
return out
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b,c,h,w = q.shape
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1)
k = k.reshape(b,c,h*w)
w_ = torch.bmm(q,k)
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
v = v.reshape(b,c,h*w)
w_ = w_.permute(0,2,1)
h_ = torch.bmm(v,w_)
h_ = h_.reshape(b,c,h,w)
h_ = self.proj_out(h_)
return x+h_
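# Hedged shape sketch (illustrative only): AttnBlock is single-head, non-local self-attention over the
# h*w spatial positions, with 1x1 convolutions as the q/k/v projections and a residual connection.
def _demo_attn_block():
    attn = AttnBlock(in_channels=64)
    x = torch.randn(2, 64, 8, 8)
    out = attn(x)
    assert out.shape == x.shape
    return out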
class VQEmbedding(nn.Embedding):
"""VQ embedding module with ema update."""
def __init__(self, n_embed, embed_dim, ema=False, decay=0.99, restart_unused_codes=True, eps=1e-5):
super().__init__(n_embed + 1, embed_dim, padding_idx=n_embed)
self.ema = ema
self.decay = decay
self.eps = eps
self.restart_unused_codes = restart_unused_codes
self.n_embed = n_embed
@torch.no_grad()
def compute_distances(self, inputs): # 12, 16, 16, 1024
codebook_t = self.weight[:-1, :].t() # 1024, 16384
(embed_dim, _) = codebook_t.shape
inputs_shape = inputs.shape
assert inputs_shape[-1] == embed_dim
inputs_flat = inputs.reshape(-1, embed_dim) # 3072, 1024
inputs_norm_sq = inputs_flat.pow(2.).sum(dim=1, keepdim=True) # 3072, 1
codebook_t_norm_sq = codebook_t.pow(2.).sum(dim=0, keepdim=True) # 1, 16384
distances = torch.addmm(
inputs_norm_sq + codebook_t_norm_sq,
inputs_flat,
codebook_t,
alpha=-2.0,
)
distances = distances.reshape(*inputs_shape[:-1], -1)
return distances # 13, 16, 16, 16384
@torch.no_grad()
def find_nearest_embedding(self, inputs):
distances = self.compute_distances(inputs)
embed_idxs = distances.argmin(dim=-1)
return embed_idxs
def forward(self, inputs):
embed_idxs = self.find_nearest_embedding(inputs)
embeds = self.embed(embed_idxs)
return embeds, embed_idxs
def embed(self, idxs):
embeds = super().forward(idxs)
return embeds
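# Hedged usage sketch (illustrative only): VQEmbedding keeps the codebook in an nn.Embedding with one
# extra padding row, and nearest-code lookup expands ||x - c||^2 as ||x||^2 + ||c||^2 - 2*x.c via
# torch.addmm. Codebook size and feature shape below are made up.
def _demo_vq_embedding():
    codebook = VQEmbedding(n_embed=16, embed_dim=8)
    inputs = torch.randn(2, 4, 4, 8)  # (B, h, w, embed_dim)
    quantized, code_indices = codebook(inputs)
    assert quantized.shape == inputs.shape
    assert code_indices.shape == (2, 4, 4)
    return quantized, code_indices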
class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, decoder_in_channels, give_pre_end=False, **ignorekwargs):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
in_ch_mult = (1,)+tuple(ch_mult)
block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1, decoder_in_channels, curr_res, curr_res)
self.conv_in = torch.nn.Conv2d(decoder_in_channels,
block_in,
kernel_size=3,
stride=1,
padding=1)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock(block_in))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up)
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1)
def forward(self, z):
self.last_z_shape = z.shape
temb = None
h = self.conv_in(z)
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class RQBottleneck(nn.Module):
"""
Quantization bottleneck via Residual Quantization.
Arguments:
latent_shape (Tuple[int, int, int]): the shape of latents, denoted (H, W, D)
code_shape (Tuple[int, int, int]): the shape of codes, denoted (h, w, d)
n_embed (int, List, or Tuple): the number of embeddings (i.e., the size of codebook)
If isinstance(n_embed, int), the sizes of all codebooks are same.
shared_codebook (bool): If True, codebooks are shared across all locations. If False,
uses separate codebooks along the ``depth'' dimension. (default: False)
restart_unused_codes (bool): If True, it randomly assigns a feature vector in the current batch
as the new embedding of unused codes in training. (default: True)
"""
def __init__(self,
latent_shape,
code_shape,
n_embed,
decay=0.99,
shared_codebook=False,
restart_unused_codes=True,
commitment_loss='cumsum'
):
super().__init__()
if not len(code_shape) == len(latent_shape) == 3:
raise ValueError("incompatible code shape or latent shape")
if any([y % x != 0 for x, y in zip(code_shape[:2], latent_shape[:2])]):
raise ValueError("incompatible code shape or latent shape")
embed_dim = np.prod(latent_shape[:2]) // np.prod(code_shape[:2]) * latent_shape[2]
self.latent_shape = torch.Size(latent_shape)
self.code_shape = torch.Size(code_shape) # 16, 16, 4
self.shape_divisor = torch.Size([latent_shape[i] // code_shape[i] for i in range(len(latent_shape))])
self.shared_codebook = shared_codebook
if self.shared_codebook:
if isinstance(n_embed, Iterable) or isinstance(decay, Iterable):
raise ValueError("Shared codebooks are incompatible \
with list types of momentums or sizes: Change it into int")
self.restart_unused_codes = restart_unused_codes
self.n_embed = n_embed if isinstance(n_embed, Iterable) else [n_embed for _ in range(self.code_shape[-1])] # [16384, 16384, 16384, 16384]
self.decay = decay if isinstance(decay, Iterable) else [decay for _ in range(self.code_shape[-1])]
assert len(self.n_embed) == self.code_shape[-1]
assert len(self.decay) == self.code_shape[-1]
if self.shared_codebook:
codebook0 = VQEmbedding(self.n_embed[0],
embed_dim,
decay=self.decay[0],
restart_unused_codes=restart_unused_codes,
)
self.codebooks = nn.ModuleList([codebook0 for _ in range(self.code_shape[-1])])
else:
codebooks = [VQEmbedding(self.n_embed[idx],
embed_dim,
decay=self.decay[idx],
restart_unused_codes=restart_unused_codes,
) for idx in range(self.code_shape[-1])]
self.codebooks = nn.ModuleList(codebooks)
self.commitment_loss = commitment_loss
def to_code_shape(self, x):
(B, H, W, D) = x.shape
(rH, rW, _) = self.shape_divisor
x = x.reshape(B, H//rH, rH, W//rW, rW, D)
x = x.permute(0, 1, 3, 2, 4, 5)
x = x.reshape(B, H//rH, W//rW, -1)
return x
def to_latent_shape(self, x):
(B, h, w, _) = x.shape
(_, _, D) = self.latent_shape
(rH, rW, _) = self.shape_divisor
x = x.reshape(B, h, w, rH, rW, D)
x = x.permute(0, 1, 3, 2, 4, 5)
x = x.reshape(B, h*rH, w*rW, D)
return x
def quantize(self, x):
r"""
Return the list of quantized features and the codewords selected by residual quantization.
At each depth, the code is selected from the residual between x and the features already quantized by the previous codebooks.
Arguments:
x (Tensor): bottleneck feature maps to quantize.
Returns:
quant_list (list): list of sequentially aggregated and quantized feature maps by codebooks.
codes (LongTensor): codeword indices corresponding to the quantized features.
Shape:
- x: (B, h, w, embed_dim)
- quant_list[i]: (B, h, w, embed_dim)
- codes: (B, h, w, d)
"""
B, h, w, embed_dim = x.shape
residual_feature = x.detach().clone() # 13, 16, 16, 1024
quant_list = []
code_list = []
aggregated_quants = torch.zeros_like(x)
for i in range(self.code_shape[-1]): # 4
quant, code = self.codebooks[i](residual_feature) # 13, 16, 16, 1024 13, 16, 16
residual_feature.sub_(quant)
aggregated_quants.add_(quant)
quant_list.append(aggregated_quants.clone())
code_list.append(code.unsqueeze(-1))
codes = torch.cat(code_list, dim=-1)
return quant_list, codes
def forward(self, x):
x_reshaped = self.to_code_shape(x)
quant_list, codes = self.quantize(x_reshaped)
commitment_loss = self.compute_commitment_loss(x_reshaped, quant_list)
quants_trunc = self.to_latent_shape(quant_list[-1])
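# Straight-through estimator: the forward pass returns the quantized values, while the
# gradient flows back to the encoder through x unchanged.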
quants_trunc = x + (quants_trunc - x).detach()
return quants_trunc, commitment_loss, codes
def compute_commitment_loss(self, x, quant_list):
r"""
Compute the commitment loss for the residual quantization.
The loss is iteratively computed by aggregating quantized features.
"""
loss_list = []
for idx, quant in enumerate(quant_list):
partial_loss = (x-quant.detach()).pow(2.0).mean()
loss_list.append(partial_loss)
commitment_loss = torch.mean(torch.stack(loss_list))
return commitment_loss
@torch.no_grad()
def embed_code(self, code):
assert code.shape[1:] == self.code_shape
code_slices = torch.chunk(code, chunks=code.shape[-1], dim=-1)
if self.shared_codebook:
embeds = [self.codebooks[0].embed(code_slice) for i, code_slice in enumerate(code_slices)]
else:
embeds = [self.codebooks[i].embed(code_slice) for i, code_slice in enumerate(code_slices)]
embeds = torch.cat(embeds, dim=-2).sum(-2)
embeds = self.to_latent_shape(embeds)
return embeds
@torch.no_grad()
def embed_code_with_depth(self, code, to_latent_shape=False):
assert code.shape[-1] == self.code_shape[-1] # 4
code_slices = torch.chunk(code, chunks=code.shape[-1], dim=-1)
if self.shared_codebook:
embeds = [self.codebooks[0].embed(code_slice) for i, code_slice in enumerate(code_slices)]
else:
embeds = [self.codebooks[i].embed(code_slice) for i, code_slice in enumerate(code_slices)]
if to_latent_shape:
embeds = [self.to_latent_shape(embed.squeeze(-2)).unsqueeze(-2) for embed in embeds]
embeds = torch.cat(embeds, dim=-2)
return embeds # 16, 16, 4, 1024
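# A minimal usage sketch for RQBottleneck, assuming the shapes hinted at in the comments
# above (16x16x1024 latents, a 16x16x4 code map, 16384-entry codebooks); these are
# hypothetical values, adjust them to your own configuration:
#
# rq = RQBottleneck(latent_shape=[16, 16, 1024], code_shape=[16, 16, 4], n_embed=16384)
# x = torch.randn(2, 16, 16, 1024)          # (B, H, W, D) bottleneck features
# quants, commit_loss, codes = rq(x)        # quants: (2, 16, 16, 1024), codes: (2, 16, 16, 4)
# recon = rq.embed_code(codes)              # sum the per-depth embeddings back to (2, 16, 16, 1024)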
class MultiSelfAttention(nn.Module):
"""
Multi-head self-attention, optimized with batched matmul operations.
"""
def __init__(self, config: AttentionBlockConfig, mask=True):
super().__init__()
assert config.embed_dim % config.n_head == 0
self.key = nn.Linear(config.embed_dim, config.embed_dim, bias=config.attn_bias)
self.query = nn.Linear(config.embed_dim, config.embed_dim, bias=config.attn_bias)
self.value = nn.Linear(config.embed_dim, config.embed_dim, bias=config.attn_bias)
self.attn_drop = nn.Dropout(config.attn_pdrop, inplace=False)
self.resid_drop = nn.Dropout(config.resid_pdrop, inplace=True)
self.proj = nn.Linear(config.embed_dim, config.embed_dim, config.attn_bias)
self.n_head = config.n_head
self.mask = mask
def forward(self, x, caching=False, past_kv=None):
(B, T, C) = x.shape
if not caching:
assert past_kv is None
x = x.transpose(0, 1).contiguous()
k = self.key(x).view(T, B*self.n_head, C//self.n_head).transpose(0, 1)
q = self.query(x).view(T, B*self.n_head, C//self.n_head).transpose(0, 1)
v = self.value(x).view(T, B*self.n_head, C//self.n_head).transpose(0, 1)
if past_kv is not None:
past_key, past_value = past_kv
k = torch.cat([past_key, k], dim=-2)
v = torch.cat([past_value, v], dim=-2)
T_past = past_key.shape[1]
else:
T_past = 0
if caching:
present = torch.stack([k, v])
else:
present = None
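# `present` stacks the keys and values as (2, B*n_head, T_past+T, head_dim); the caller
# feeds it back as past_kv on the next incremental decoding step.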
att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
if self.mask:
mask = torch.tril(torch.ones(T_past+T, T_past+T, device=x.device, dtype=torch.bool))
mask = mask.view(1, T_past+T, T_past+T)
att = att.masked_fill(~mask[:, T_past:T_past+T, :T_past+T], float('-inf'))
att = F.softmax(att, dim=-1)
att = self.attn_drop(att)
y = torch.bmm(att, v)
y = y.transpose(0, 1).contiguous().view(T, B, C)
y = self.resid_drop(self.proj(y))
if caching:
return y.transpose(0, 1).contiguous(), present
else:
return y.transpose(0, 1).contiguous()
class AttentionBlock(nn.Module):
""" an unassuming Transformer block """
def __init__(self, config: AttentionBlockConfig):
super().__init__()
self.ln1 = nn.LayerNorm(config.embed_dim)
self.ln2 = nn.LayerNorm(config.embed_dim)
self.attn = MultiSelfAttention(config, mask=True)
self.mlp = nn.Sequential(
nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=config.mlp_bias),
nn.GELU(),
nn.Linear(4 * config.embed_dim, config.embed_dim, bias=config.mlp_bias),
nn.Dropout(config.resid_pdrop, inplace=True),
)
self._cache = None
def forward(self, x):
attn = self.attn(self.ln1(x))
x = x + attn
x = x + self.mlp(self.ln2(x))
return x
def cached_forward(self, x_present):
attn, present = self.attn(self.ln1(x_present), caching=True, past_kv=self._cache['past_kv'])
self._cache['past_kv'] = present
x_present = x_present + attn
x_present = x_present + self.mlp(self.ln2(x_present))
return x_present
def init_cache(self):
self._cache = {'past_kv': None}
class AttentionStack(nn.Module):
blocks: Iterable[AttentionBlock]
def __init__(self, config: AttentionStackConfig):
super().__init__()
self.blocks = nn.ModuleList([AttentionBlock(config.block) for _ in range(config.n_layer)])
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
def cached_forward(self, x_present):
for block in self.blocks:
x_present = block.cached_forward(x_present)
return x_present
def init_cache(self):
for block in self.blocks:
block.init_cache()
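# A minimal sketch of incremental decoding with the KV cache (hypothetical config values;
# the real ones come from AttentionStackConfig / AttentionBlockConfig):
#
# stack = AttentionStack(stack_config)        # e.g. stack_config.block.embed_dim == 2048
# stack.init_cache()                          # reset past_kv in every block
# for step_embedding in token_embeddings:     # each step_embedding of shape (B, 1, embed_dim)
#     hidden = stack.cached_forward(step_embedding)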
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
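# RMS normalization: x * rsqrt(mean(x^2) + eps), with the variance computed in float32
# and the result rescaled by the learned weight.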
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
class CasualDepthTransformerLayer(nn.Module):
def __init__(self, config, depth):
super().__init__()
self.config = config
embed_size = config.embed_dim # 2048
num_heads = embed_size // 128 # 16
self.self_attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads,batch_first=True)
self.layernorm1 = RMSNorm(embed_size)
self.layernorm2 = RMSNorm(embed_size)
self.linear1 = nn.Linear(embed_size * depth, 2 * embed_size) # 8192, 4096
self.linear2 = nn.Linear(2 * embed_size * depth, embed_size)
def forward(self, x):
# Get the input sequence length
seq_len = x.size(1)
# Build a causal mask so each position only attends to itself and earlier positions
# Self-attention sub-layer
res = x
x = self.layernorm1(x)
src_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(x.device)
_x, _ = self.self_attention(x, x, x, is_causal=True, attn_mask=src_mask)
res = _x + res # (bs, sl, d)
res = self.layernorm2(res)
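# Depth-wise feed-forward: the linear1/linear2 weights are reshaped so that each of the
# `depth` positions gets its own (embed_dim -> 2*embed_dim -> embed_dim) projection,
# applied via einsum over the depth axis (the Linear biases are not used here).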
x = torch.einsum('bld,tld->blt', res, torch.reshape(self.linear1.weight, (2 * self.config.embed_dim, -1, self.config.embed_dim))) # linear1.reshape: 4096, 4, 2048
x = torch.nn.functional.gelu(x)
x = torch.einsum('blt,dlt->bld', x, torch.reshape(self.linear2.weight, (self.config.embed_dim, -1, 2 * self.config.embed_dim)))
return res + x
class RQTransformer(PreTrainedModel):
config_class = RQTransformerConfig
def __init__(self, config: RQTransformerConfig):
super().__init__(config)
self.in_mlp_1 = nn.Linear(config.input_embed_dim_1, config.embed_dim) # 1024, llm_hidden_size(2048)
self.head_transformer = nn.ModuleList([
CasualDepthTransformerLayer(config, config.block_size[-1])
for _ in range(3)
])
self.headnorm = RMSNorm(config.embed_dim)
self.heads = nn.ModuleList([
nn.Linear(config.embed_dim, config.vocab_size)
for i in range(config.block_size[-1])
])
self.gradient_checkpointing = True
def embed_with_model_aux(self, code, model_aux, mode=None):
# mode = "visual" or "semantic"
xs_emb = model_aux.get_code_emb_with_depth(code, mode=mode)
return xs_emb
def forward(self, embed_from_body, code, model_aux=None, mode=None):
B, seq_len, D = code.shape # 59, 256, 4
depth_ctx = self.embed_with_model_aux(code, model_aux, mode=mode)
depth_ctx = torch.cumsum(depth_ctx, dim=-2)
depth_ctx = self.in_mlp_1(depth_ctx) # torch.Size([59, 256, 4, llm_hidden_size(2048)])
depth_ctx_full = torch.cat(
[
embed_from_body.view(B, seq_len, 1, -1),
depth_ctx[:, :, :-1, :],
],
dim=-2,
) # B, 256, 4, 2048
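# Teacher forcing over the depth axis: position 0 sees only the LLM hidden state, and
# position d sees the cumulative embedding of codes 0..d-1, so head d predicts the d-th
# residual code.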
depth_ctx_full = depth_ctx_full.reshape(B * seq_len, D, -1) # B*256, 4, 2048
for i, tlayer in enumerate(self.head_transformer):
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs)
return custom_forward
depth_ctx_full = torch.utils.checkpoint.checkpoint(
create_custom_forward(tlayer), depth_ctx_full,
)
else:
depth_ctx_full = tlayer(
depth_ctx_full,
)
depth_ctx_full = self.headnorm(depth_ctx_full) # B*256, 4, 2048
logits = [head(depth_ctx_full[:, i]) for i, head in enumerate(self.heads)] # logits[0].shape = B*256, 16384(codebook_size)
return logits
def generate(self, embed_from_body, model_aux=None, cfg=3.0, mode=None):
top_k = 900
# top_k = 500
# top_k = 1
top_p = 0.96
# top_p = 0.99
B, seq_len, _ = embed_from_body.shape # 1, 1, 2048
next_token_ids = torch.zeros(B, 1, 4, dtype=torch.long).to(embed_from_body.device)
for i in range(4):
logits = self(embed_from_body, next_token_ids, model_aux, mode=mode)
next_token_logits = logits[i].clone()
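# Classifier-free guidance: the batch is assumed to hold conditional logits in the first
# half and unconditional logits in the second half, combined as uncond + cfg * (cond - uncond).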
next_token_logits = next_token_logits[B//2:, :] + cfg * (next_token_logits[:B//2, :] - next_token_logits[B//2:, :])
next_tokens = sample_from_logits(next_token_logits, temperature=1.0, top_p=top_p, top_k=top_k)
next_tokens = next_tokens.reshape(B//2, seq_len).repeat(2, 1)
next_token_ids[:, :, i] = next_tokens
out_features = self.embed_with_model_aux(next_token_ids, model_aux, mode=mode)
out_features = torch.cumsum(out_features, dim=-2)[:, :, -1, :]
# out_features = self.in_mlp_1(out_features)
return out_features, next_token_ids
def build_projector(dim_in, dim_out, projector_type='mlp2x_gelu'):
if projector_type == 'linear':
linear = nn.Linear(dim_in, dim_out)
linear.reset_parameters()
return linear
elif projector_type == 'nonlinear':
linear = nn.Linear(dim_in, dim_out)
linear.reset_parameters()
modules = [linear, nn.GELU()]
return nn.Sequential(*modules)
elif projector_type == 'conv':
return nn.Conv2d(dim_in, dim_out, 1)
else: # mlp2x_gelu
linear_1 = nn.Linear(dim_in, dim_in)
linear_1.reset_parameters()
modules = [linear_1]
modules.append(nn.GELU())
linear_2 = nn.Linear(dim_in, dim_out)
linear_2.reset_parameters()
modules.append(linear_2)
return nn.Sequential(*modules)
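# Example with hypothetical dimensions: build_projector(1024, 2048) returns the default
# "mlp2x_gelu" projector, i.e. Linear(1024, 1024) -> GELU -> Linear(1024, 2048).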
class RQVAESiglipModel(PreTrainedModel):
config_class = RQVAESiglipConfig
def __init__(self, config: RQVAESiglipConfig):
super().__init__(config)
self.config = config
# self.siglip_model = SiglipModel.from_pretrained(config.pretrained_model)
siglip_config = SiglipModel.config_class.from_pretrained(config.pretrained_model)
self.siglip_model = SiglipModel._from_config(siglip_config)
# self.prequant_semantic = build_projector(config.hidden_size, config.embed_dim, projector_type='linear')
self.prequant_visual = ProjectResnetBlock(in_channels=config.hidden_size,
out_channels=config.embed_dim,
dropout=0.0)
self.prequant_visual_1 = ProjectResnetBlock(in_channels=config.hidden_size,
out_channels=config.embed_dim,
dropout=0.0)
self.layer_norm_visual = nn.LayerNorm(config.embed_dim)
self.layer_norm_semantic = nn.LayerNorm(config.embed_dim)
self.quantizer_semantic = RQBottleneck(
latent_shape=config.latent_shape,
code_shape=config.code_shape_semantic,
n_embed=config.n_embed_semantic,
decay=config.decay,
shared_codebook=config.shared_codebook,
restart_unused_codes=config.restart_unused_codes,
)
self.quantizer = RQBottleneck(
latent_shape=config.latent_shape,
code_shape=config.code_shape_visual,
n_embed=config.n_embed_visual,
decay=config.decay,
shared_codebook=config.shared_codebook,
restart_unused_codes=config.restart_unused_codes,
)
self.postquant_semantic = build_projector(config.embed_dim, config.hidden_size, projector_type='nonlinear')
self.postquant_visual = ProjectResnetBlock(in_channels=config.embed_dim,
out_channels=config.hidden_size,
dropout=0.0)
self.post_quant_conv = PostQuantResnetBlock(in_channels=config.hidden_size,
out_channels=config.ddconfig["decoder_in_channels"],
dropout=0.0)
self.decoder = Decoder(**config.ddconfig)
self.decoder_latent_shape = getattr(config, "decoder_latent_shape", None)
def encode_text(self, text):
# Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
output_attentions, output_hidden_states, return_dict = None, None, None
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
text_model = self.siglip_model.text_model
text_outputs = text_model(
input_ids=text,
attention_mask=None,
position_ids=None,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
text_embeds = text_outputs[1]
return text_embeds
def encode_image(self, image):
vision_model = self.siglip_model.vision_model
hidden_states = vision_model.embeddings(image)
attention_mask = None
output_attentions = None
# visual_n, semantic_n = 20, 2  # take the features of the n-th layer counting from the end
visual_n, semantic_n = self.config.last_n_layer_recon, self.config.last_n_layer_sem
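# The vision encoder is tapped at three intermediate layers: the (visual_n+2)-th and
# visual_n-th layers from the end are projected and mixed into the "visual" (reconstruction)
# branch, while the semantic_n-th layer from the end feeds the "semantic" branch; each
# branch is layer-normed and residual-quantized separately below.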
for i, encoder_layer in enumerate(vision_model.encoder.layers):
if vision_model.encoder.gradient_checkpointing and vision_model.encoder.training:
layer_outputs = vision_model.encoder._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if i == len(vision_model.encoder.layers) - (visual_n+2): # -22
B, L, C = hidden_states.shape
hidden_states_visual_1 = hidden_states.reshape(B, int(L**0.5), int(L**0.5), -1)
hidden_states_visual_1 = hidden_states_visual_1.permute(0, 3, 1, 2).contiguous()
hidden_states_visual_1 = self.prequant_visual_1(hidden_states_visual_1)
hidden_states_visual_1 = hidden_states_visual_1.permute(0, 2, 3, 1).contiguous()
if i == len(vision_model.encoder.layers) - visual_n: # -20
B, L, C = hidden_states.shape
hidden_states_visual = hidden_states.reshape(B, int(L**0.5), int(L**0.5), -1)
hidden_states_visual = hidden_states_visual.permute(0, 3, 1, 2).contiguous()
hidden_states_visual = self.prequant_visual(hidden_states_visual)
hidden_states_visual = hidden_states_visual.permute(0, 2, 3, 1).contiguous()
hidden_states_visual += 0.6 * hidden_states_visual_1
hidden_states_visual = self.layer_norm_visual(hidden_states_visual)
z_q_visual, quant_loss_visual, code_visual = self.quantizer(hidden_states_visual)
if i == len(vision_model.encoder.layers) - semantic_n:
hidden_state_26 = hidden_states
B, L, C = hidden_states.shape
hidden_states_semantic = hidden_states.reshape(B, int(L**0.5), int(L**0.5), -1)
hidden_states_semantic = self.layer_norm_semantic(hidden_states_semantic)
z_q_semantic, quant_loss_semantic, code_semantic = self.quantizer_semantic(hidden_states_semantic)
return z_q_visual, code_visual, z_q_semantic, code_semantic
def decode(self, z_q):
z_q = z_q.permute(0, 3, 1, 2).contiguous()
z_q = self.postquant_visual(z_q)
if self.decoder_latent_shape is not None:
z_q = F.interpolate(z_q.to(torch.float32), size=tuple(self.decoder_latent_shape), mode='bilinear').to(torch.bfloat16)
z_q = self.post_quant_conv(z_q)
out = self.decoder(z_q)
return out
@torch.no_grad()
def get_code_emb_with_depth(self, code, mode=None):
# Look up image embeddings from the corresponding codebook; mode = "visual" or "semantic"
if mode == "visual":
visual_embedding = self.quantizer.embed_code_with_depth(code)
return visual_embedding
elif mode == "semantic":
semantic_embedding = self.quantizer_semantic.embed_code_with_depth(code)
return semantic_embedding
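# Rough encode/decode round trip (a sketch only; in practice the model is built through
# RQVAESIGLIPTransformer / from_pretrained, and shapes depend on the SigLIP backbone):
#
# z_q_visual, code_visual, z_q_semantic, code_semantic = model.encode_image(images)
# recon = model.decode(z_q_visual)          # pixel-space reconstruction from the visual branch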
class RQVAESIGLIPTransformer(PreTrainedModel):
config_class = RQVAESIGLIPTransformerConfig
def __init__(self, config: RQVAESIGLIPTransformerConfig):
super().__init__(config)
rqvaesiglip_config = RQVAESiglipModel.config_class.from_dict(config.rqvaesiglip)
rqtransformer_visual_config = RQTransformer.config_class.from_dict(config.rqtransformer_visual)
rqtransformer_semantic_config = RQTransformer.config_class.from_dict(config.rqtransformer_semantic)
self.rqvaesiglip = RQVAESiglipModel._from_config(rqvaesiglip_config)
self.rqtransformer_visual = RQTransformer._from_config(rqtransformer_visual_config)
self.rqtransformer_semantic = RQTransformer._from_config(rqtransformer_semantic_config)
class RQVAESIGLIPTransformerVisionTower(nn.Module):
def __init__(self, model_name_or_path):
super().__init__()
model_dtype = torch.bfloat16
self.config = RQVAESIGLIPTransformerConfig.from_pretrained(model_name_or_path)
self.vision_tower = RQVAESIGLIPTransformer._from_config(self.config, torch_dtype=model_dtype)
self.is_loaded = True
encoder_path = self.config.rqvaesiglip["pretrained_model"]
if "siglip-so400m-patch14-384" in encoder_path: # SigLIP-SO400M-patch14-384
self.image_processor = CLIPImageProcessor(
size={"height": 384, "width": 384},
crop_size={"height": 384, "width": 384},
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5]
)
self.image_tokens = 729
# self.config.hidden_size == 1152
elif "siglip-large-patch16-256" in encoder_path: # SigLIP-Large-patch16-256
self.image_processor = CLIPImageProcessor(
size={"height": 256, "width": 256},
crop_size={"height": 256, "width": 256},
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5]
)
self.image_tokens = 256
# self.config.hidden_size == 1024
else:
raise NotImplementedError()
def encode(self, images: torch.Tensor):
vision_output = self.vision_tower.rqvaesiglip.encode_image(images)
image_features_visual, tokens_visual = vision_output[0], vision_output[1]
image_features_semantic, tokens_semantic = vision_output[2], vision_output[3]
bs, patch_size, _, dim = image_features_visual.shape
image_features_visual = torch.reshape(image_features_visual, [bs, patch_size**2, dim])
tokens_visual = torch.reshape(tokens_visual, [bs, patch_size**2, -1])
bs, patch_size, _, dim = image_features_semantic.shape
image_features_semantic = torch.reshape(image_features_semantic, [bs, patch_size**2, dim])
tokens_semantic = torch.reshape(tokens_semantic, [bs, patch_size**2, -1])
return image_features_visual, image_features_semantic, tokens_visual, tokens_semantic
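# encode() flattens the (patch, patch) grid: with the 256x256 SigLIP-Large setting the
# features come back as (B, 256, dim) and the token maps as (B, 256, depth), where depth
# is the number of residual-quantization levels.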
class BaichuanVisualEncoder(RQVAESIGLIPTransformerVisionTower):
def __init__(self, config):
super().__init__(config)
self.vision_tower.rqvaesiglip.siglip_model.vision_model.gradient_checkpointing = True # force-enable gradient checkpointing
self.vision_tower.rqvaesiglip.siglip_model.vision_model._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
def forward(
self,
pixel_values # raw pixel values, e.g. 179, 161, 147, ..., with shape (13, 3, 256, 256)
):
pixel_values = pixel_values.to(self.vision_tower.rqvaesiglip.siglip_model.vision_model.embeddings.patch_embedding.weight.dtype)
return self.encode(pixel_values) # visual token indices (text-vocabulary offset not added)
@torch.no_grad()
def fake_input(self, input_ids, merge_size):
fake_image = [torch.zeros([3, self.config.image_size, self.config.image_size], dtype=torch.float32, device=input_ids.device)]
return fake_image
# def test_vision():
# from transformers.models.clip.modeling_clip import CLIPPreTrainedModel
# from transformers import AutoConfig
# config = AutoConfig.from_pretrained("./", trust_remote_code=True)
# ae = BaichuanVisualEncoder(config).cuda().to(torch.bfloat16)
# bg = BaichuanVisualBridge(config).cuda().to(torch.bfloat16)
# print(ae)
# pixel_input = torch.rand([4, 3, config.image_size, config.image_size], dtype=torch.float32).cuda()
# visual_embedding = ae(pixel_input)[0][:, 1:] # drop the class token
# visual_proj = bg(visual_embedding)
# print(visual_proj.shape)
# print(ae.fake_input(visual_proj.device))
if __name__ == '__main__':
fire.Fire()