from typing import List, Optional, Tuple, Union, Iterable, Any
import torch, math
import torch.utils.checkpoint
from torch import nn
from collections import OrderedDict
from flash_attn import flash_attn_varlen_func
from transformers.activations import ACT2FN
import io, fire
from torch.nn import functional as F
from transformers.modeling_utils import PreTrainedModel
from transformers import PreTrainedModel, CLIPImageProcessor, PretrainedConfig
from .configuration_baichuan import RQVAESIGLIPTransformerConfig, RQTransformerConfig, RQVAESiglipConfig, AttentionStackConfig, AttentionBlockConfig, SiglipConfig, SiglipTextConfig, SiglipVisionConfig
from torch.utils.checkpoint import checkpoint
import warnings
from dataclasses import dataclass
from torch.nn.init import _calculate_fan_in_and_fan_out
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
import torch.distributed as dist
import numpy as np


def top_k_logits(logits, k):
    """Keep only the ``k`` largest logits of every row; set the rest to ``-inf``.

    Args:
        logits: 2-dim tensor of shape (n_samples, logit_dim).
        k (int): number of logits to keep per row. Values larger than
            ``logit_dim`` are clamped, i.e. every logit is kept.

    Returns:
        A new tensor (``logits`` is not modified) in which every entry strictly
        smaller than the k-th largest value of its row is ``-inf``. Entries tied
        with the k-th largest value are kept.
    """
    # torch.topk raises when k exceeds the size of the last dimension.
    k = min(k, logits.size(-1))
    kth_largest = torch.topk(logits, k).values[:, [-1]]
    out = logits.clone()
    out[out < kth_largest] = -float('Inf')
    return out


def top_p_probs(probs, p):
    """Nucleus (top-p) filtering of a probability tensor.

    Keeps, per row, the smallest set of highest-probability entries whose
    cumulative probability reaches ``p`` (the entry that crosses the threshold
    is kept, and the most probable entry is never removed), zeroes the rest and
    renormalizes each row to sum to 1.

    Args:
        probs: tensor of probabilities; filtering is applied over the last dim.
        p (float): cumulative-probability threshold.

    Returns:
        Renormalized probabilities with the same shape as ``probs``.
    """
    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
    cum_probs = torch.cumsum(sorted_probs, dim=-1)

    sorted_idx_remove_cond = cum_probs >= p

    # Shift right by one so the entry that first reaches ``p`` is kept.
    sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone()
    sorted_idx_remove_cond[..., 0] = 0

    # Map the removal mask from sorted order back to the original positions.
    indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond)
    probs = probs.masked_fill(indices_to_remove, 0.0)
    norm_probs = probs / torch.sum(probs, dim=-1, keepdim=True)
    return norm_probs


def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None):
    """Take a 2-dim tensor, apply softmax along each row, and sample from
    each multinomial distribution defined by the rows.

    Args:
        logits: 2-dim tensor of shape (n_samples, logit_dim)
        temperature (float): softmax temperature
        top_k (Optional[int]): if given, sample only using `top_k` logits
        top_p (Optional[float]): if given, sample only using `top_p` logits

    Returns:
        samples: 1-dim integer tensor of shape (n_samples,)

    Raises:
        RuntimeError: if sampling fails (e.g. a row has no valid probability
            mass). The underlying error is chained as ``__cause__``.
    """
    logits = logits.to(dtype=torch.float32)
    logits = logits / temperature

    if top_k is not None:
        logits = top_k_logits(logits, top_k)

    nan_mask = torch.isnan(logits)
    if nan_mask.any():
        print('WARNING... NaN observed')
        # NaN logits get zero probability instead of poisoning the softmax.
        logits = logits.masked_fill(nan_mask, -float('Inf'))

    probs = F.softmax(logits, dim=-1)

    if top_p is not None:
        probs = top_p_probs(probs, top_p)

    try:
        samples = torch.multinomial(probs, num_samples=1)
    except Exception as err:
        # Keep the RuntimeError contract, but preserve the original cause and
        # let KeyboardInterrupt / SystemExit propagate untouched.
        raise RuntimeError("sampling from logits failed") from err

    return samples.view(-1)


def nonlinearity(x):
    """Apply SiLU to ``x`` in place (``x`` is modified) and return it."""
    return F.silu(x, inplace=True)


def Normalize(in_channels):
    """Return a 32-group, affine ``GroupNorm`` over ``in_channels`` channels."""
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


# PyTorch Siglip model.

logger = logging.get_logger(__name__)


def _trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with samples from N(mean, std^2) truncated to [a, b].

    Uses the inverse-CDF method; callers are expected to run it under
    ``torch.no_grad()`` because it mutates ``tensor`` in place.
    """
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [l, u], then translate to
    # [2l-1, 2u-1].
    tensor.uniform_(2 * l - 1, 2 * u - 1)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.0))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)


def trunc_normal_tf_(
    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
) -> torch.Tensor:
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are drawn from the normal distribution
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to :math:`[a, b]`
    (sampled via the inverse CDF, not by rejection).

    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
    and the result is subsequently scaled and shifted by the mean and std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value (applied before scaling)
        b: the maximum cutoff value (applied before scaling)

    Returns:
        ``tensor`` itself, filled in place.
    """
    with torch.no_grad():
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)
    return tensor


def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    """Variance-scaling initializer (in place), mirroring the JAX/TF initializer.

    Args:
        tensor: weight tensor to initialize in place.
        scale (float): multiplier on the target variance.
        mode (str): ``"fan_in"``, ``"fan_out"`` or ``"fan_avg"``.
        distribution (str): ``"truncated_normal"``, ``"normal"`` or ``"uniform"``.

    Raises:
        ValueError: on an unknown ``mode`` or ``distribution``.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    """LeCun-normal init: truncated normal with variance 1 / fan_in (in place)."""
    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")


def default_flax_embed_init(tensor):
    """Flax-style embedding init: normal with variance 1 / fan_in (in place)."""
    variance_scaling_(tensor, mode="fan_in", distribution="normal")


@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
class SiglipVisionModelOutput(ModelOutput):
    """
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
class SiglipTextModelOutput(ModelOutput):
    """
    Base class for text model's outputs that also contains a pooling of the last hidden states.

    Args:
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The text embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    text_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
class SiglipOutput(ModelOutput):
    """
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for image-text similarity.
        logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
        image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
            The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
        text_model_output(`BaseModelOutputWithPooling`):
            The output of the [`SiglipTextModel`].
        vision_model_output(`BaseModelOutputWithPooling`):
            The output of the [`SiglipVisionModel`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits_per_image: torch.FloatTensor = None
    logits_per_text: torch.FloatTensor = None
    text_embeds: torch.FloatTensor = None
    image_embeds: torch.FloatTensor = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        # Nested model outputs are flattened to tuples as well.
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )


class SiglipVisionEmbeddings(nn.Module):
    """Patchify images with a strided conv and add learned absolute position embeddings."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # kernel == stride, so each output pixel is one non-overlapping patch.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        # One position per patch: (image_size // patch_size) ** 2.
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed ``pixel_values`` of shape (batch, channels, height, width).

        Returns a tensor of shape (batch, num_patches, embed_dim). The input is
        moved to the conv weight's device AND dtype, so e.g. fp32 pixels work
        with a bf16 model.
        """
        weight = self.patch_embedding.weight
        pixel_values = pixel_values.to(device=weight.device, dtype=weight.dtype)
        patch_embeds = self.patch_embedding(pixel_values)  # (batch, embed_dim, grid, grid)
        embeddings = patch_embeds.flatten(2).transpose(1, 2)  # (batch, grid * grid, embed_dim)

        embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings
Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip class SiglipTextEmbeddings(nn.Module): def __init__(self, config: SiglipTextConfig): super().__init__() embed_dim = config.hidden_size self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) # position_ids (1, len position emb) is contiguous in memory and exported when serialized self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) def forward( self, input_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] if position_ids is None: position_ids = self.position_ids[:, :seq_length] if inputs_embeds is None: inputs_embeds = self.token_embedding(input_ids) position_embeddings = self.position_embedding(position_ids) embeddings = inputs_embeds + position_embeddings return embeddings class SiglipAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ def __init__(self, config): super().__init__() self.config = config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads})." 
) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" batch_size, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) k_v_seq_len = key_states.shape[-2] attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): raise ValueError( f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" f" {attn_weights.size()}" ) if attention_mask is not None: if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): raise ValueError( f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights + attention_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): raise ValueError( 
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, attn_weights # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip class SiglipMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.activation_fn = ACT2FN[config.hidden_act] self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) hidden_states = self.fc2(hidden_states) return hidden_states # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip class SiglipEncoderLayer(nn.Module): def __init__(self, config: SiglipConfig): super().__init__() self.embed_dim = config.hidden_size self.self_attn = SiglipAttention(config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) # Ignore copy def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: Optional[bool] = False, ) -> Tuple[torch.FloatTensor]: """ Args: hidden_states (`torch.FloatTensor`): Input to the layer of shape `(batch, seq_len, embed_dim)`. attention_mask (`torch.FloatTensor`): Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. 
""" residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) return outputs class SiglipPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = SiglipConfig base_model_prefix = "siglip" supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, SiglipVisionEmbeddings): width = ( self.config.vision_config.hidden_size if isinstance(self.config, SiglipConfig) else self.config.hidden_size ) nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) elif isinstance(module, nn.Embedding): default_flax_embed_init(module.weight) elif isinstance(module, SiglipAttention): nn.init.xavier_uniform_(module.q_proj.weight) nn.init.xavier_uniform_(module.k_proj.weight) nn.init.xavier_uniform_(module.v_proj.weight) nn.init.xavier_uniform_(module.out_proj.weight) nn.init.zeros_(module.q_proj.bias) nn.init.zeros_(module.k_proj.bias) nn.init.zeros_(module.v_proj.bias) nn.init.zeros_(module.out_proj.bias) elif isinstance(module, SiglipMLP): nn.init.xavier_uniform_(module.fc1.weight) nn.init.xavier_uniform_(module.fc2.weight) nn.init.normal_(module.fc1.bias, std=1e-6) nn.init.normal_(module.fc2.bias, std=1e-6) elif isinstance(module, SiglipMultiheadAttentionPoolingHead): nn.init.xavier_uniform_(module.probe.data) nn.init.xavier_uniform_(module.attention.in_proj_weight.data) nn.init.zeros_(module.attention.in_proj_bias.data) elif isinstance(module, SiglipModel): 
logit_scale_init = torch.log(torch.tensor(1.0)) module.logit_scale.data.fill_(logit_scale_init) module.logit_bias.data.zero_() elif isinstance(module, (nn.Linear, nn.Conv2d)): lecun_normal_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) SIGLIP_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`SiglipConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ SIGLIP_TEXT_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. 
[What are attention masks?](../glossary#attention-mask) position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`. [What are position IDs?](../glossary#position-ids) output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ SIGLIP_VISION_INPUTS_DOCSTRING = r""" Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ SIGLIP_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. 
[What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`. [What are position IDs?](../glossary#position-ids) pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. return_loss (`bool`, *optional*): Whether or not to return the contrastive loss. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip class SiglipEncoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`SiglipEncoderLayer`]. 
Args: config: SiglipConfig """ def __init__(self, config: SiglipConfig): super().__init__() self.config = config self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False # Ignore copy def forward( self, inputs_embeds, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutput]: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None hidden_states = inputs_embeds for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, output_attentions, ) else: layer_outputs = encoder_layer( hidden_states, attention_mask, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) class SiglipTextTransformer(nn.Module): def __init__(self, config: SiglipTextConfig): super().__init__() self.config = config embed_dim = config.hidden_size self.embeddings = SiglipTextEmbeddings(config) self.encoder = SiglipEncoder(config) self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.head = nn.Linear(embed_dim, embed_dim) @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, 
output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is None: raise ValueError("You have to specify input_ids") input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. # expand attention_mask # if attention_mask is not None: # # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] # attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = encoder_outputs[0] last_hidden_state = self.final_layer_norm(last_hidden_state) # Assuming "sticky" EOS tokenization, last token is always EOS. 
pooled_output = last_hidden_state[:, -1, :] pooled_output = self.head(pooled_output) if not return_dict: return (last_hidden_state, pooled_output) + encoder_outputs[1:] return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) @add_start_docstrings( """The text model from SigLIP without any head or projection on top.""", SIGLIP_START_DOCSTRING, ) class SiglipTextModel(SiglipPreTrainedModel): config_class = SiglipTextConfig _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"] def __init__(self, config: SiglipTextConfig): super().__init__(config) self.text_model = SiglipTextTransformer(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Module: return self.text_model.embeddings.token_embedding def set_input_embeddings(self, value): self.text_model.embeddings.token_embedding = value @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: Examples: ```python >>> from transformers import AutoTokenizer, SiglipTextModel >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224") >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") >>> # important: make sure to set padding="max_length" as that's how the model was trained >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") >>> outputs = model(**inputs) >>> last_hidden_state = 
outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled (EOS token) states ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict return self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) class SiglipVisionTransformer(nn.Module): def __init__(self, config: SiglipVisionConfig): super().__init__() self.config = config embed_dim = config.hidden_size self.embeddings = SiglipVisionEmbeddings(config) self.encoder = SiglipEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.head = SiglipMultiheadAttentionPoolingHead(config) @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) def forward( self, pixel_values, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict hidden_states = self.embeddings(pixel_values) encoder_outputs = self.encoder( inputs_embeds=hidden_states, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = encoder_outputs[0] last_hidden_state = self.post_layernorm(last_hidden_state) pooled_output = self.head(last_hidden_state) if not return_dict: return (last_hidden_state, pooled_output) + encoder_outputs[1:] return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, 
pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) class SiglipMultiheadAttentionPoolingHead(nn.Module): """Multihead Attention Pooling.""" def __init__(self, config: SiglipVisionConfig): super().__init__() self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = SiglipMLP(config) def forward(self, hidden_state): batch_size = hidden_state.shape[0] probe = self.probe.repeat(batch_size, 1, 1) hidden_state = self.attention(probe, hidden_state, hidden_state)[0] residual = hidden_state hidden_state = self.layernorm(hidden_state) hidden_state = residual + self.mlp(hidden_state) return hidden_state[:, 0] @add_start_docstrings( """The vision model from SigLIP without any head or projection on top.""", SIGLIP_START_DOCSTRING, ) class SiglipVisionModel(SiglipPreTrainedModel): config_class = SiglipVisionConfig main_input_name = "pixel_values" def __init__(self, config: SiglipVisionConfig): super().__init__(config) self.vision_model = SiglipVisionTransformer(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) def forward( self, pixel_values, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: Examples: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, SiglipVisionModel >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224") 
>>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor(images=image, return_tensors="pt") >>> outputs = model(**inputs) >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled features ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict return self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) @add_start_docstrings(SIGLIP_START_DOCSTRING) class SiglipModel(SiglipPreTrainedModel): config_class = SiglipConfig def __init__(self, config: SiglipConfig): super().__init__(config) if not isinstance(config.text_config, SiglipTextConfig): raise ValueError( "config.text_config is expected to be of type SiglipTextConfig but is of type" f" {type(config.text_config)}." ) if not isinstance(config.vision_config, SiglipVisionConfig): raise ValueError( "config.vision_config is expected to be of type SiglipVisionConfig but is of type" f" {type(config.vision_config)}." 
) text_config = config.text_config vision_config = config.vision_config self.text_model = SiglipTextTransformer(text_config) self.vision_model = SiglipVisionTransformer(vision_config) self.logit_scale = nn.Parameter(torch.randn(1)) self.logit_bias = nn.Parameter(torch.randn(1)) # Initialize weights and apply final processing self.post_init() @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) def get_text_features( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> torch.FloatTensor: r""" Returns: text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`]. Examples: ```python >>> from transformers import AutoTokenizer, AutoModel >>> import torch >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") >>> # important: make sure to set padding="max_length" as that's how the model was trained >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") >>> with torch.no_grad(): ... text_features = model.get_text_features(**inputs) ```""" # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. 
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict text_outputs = self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) pooled_output = text_outputs[1] return pooled_output @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) def get_image_features( self, pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> torch.FloatTensor: r""" Returns: image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`]. Examples: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, AutoModel >>> import torch >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor(images=image, return_tensors="pt") >>> with torch.no_grad(): ... image_features = model.get_image_features(**inputs) ```""" # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components. 
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict vision_outputs = self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) pooled_output = vision_outputs[1] return pooled_output @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig) def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, return_loss: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, SiglipOutput]: r""" Returns: Examples: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, AutoModel >>> import torch >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] >>> inputs = processor(text=texts, images=image, return_tensors="pt") >>> with torch.no_grad(): ... 
outputs = model(**inputs) >>> logits_per_image = outputs.logits_per_image >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") 31.9% that image 0 is 'a photo of 2 cats' ```""" # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict vision_outputs = self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) text_outputs = self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) image_embeds = vision_outputs[1] text_embeds = text_outputs[1] # normalized features image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) # cosine similarity as logits logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias logits_per_image = logits_per_text.t() loss = None if return_loss: raise NotImplementedError("SigLIP loss to be implemented") if not return_dict: output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) return ((loss,) + output) if loss is not None else output return SiglipOutput( loss=loss, logits_per_image=logits_per_image, logits_per_text=logits_per_text, text_embeds=text_embeds, image_embeds=image_embeds, text_model_output=text_outputs, vision_model_output=vision_outputs, ) class Upsample(nn.Module): def 
__init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): x = torch.nn.functional.interpolate(x.to(torch.float32), scale_factor=2.0, mode="nearest").to(torch.bfloat16) if self.with_conv: x = self.conv(x) return x class ResnetBlock(nn.Module): def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.checkpointing = False self.norm1 = Normalize(in_channels) self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels > 0: self.temb_proj = torch.nn.Linear(temb_channels, out_channels) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout, inplace=True) self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) else: self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def _forward(self, x, temb): h = x h = self.norm1(h) h = nonlinearity(h) h = self.conv1(h) if temb is not None: h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] h = self.norm2(h) h = nonlinearity(h) h = self.dropout(h) h = self.conv2(h) if self.in_channels != self.out_channels: if self.use_conv_shortcut: x = self.conv_shortcut(x) else: x = self.nin_shortcut(x) return x+h def forward(self, x, temb): if self.checkpointing and self.training: out = checkpoint(self._forward, x, temb) else: out = self._forward(x, temb) return out class PostQuantResnetBlock(nn.Module): def __init__(self, *, 
in_channels, out_channels=None, conv_shortcut=False, dropout): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.checkpointing = False self.norm1 = Normalize(in_channels) self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout, inplace=True) self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) else: self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def _forward(self, x): h = x h = self.norm1(h) h = nonlinearity(h) h = self.conv1(h) h = self.norm2(h) h = nonlinearity(h) h = self.dropout(h) h = self.conv2(h) if self.in_channels != self.out_channels: if self.use_conv_shortcut: x = self.conv_shortcut(x) else: x = self.nin_shortcut(x) return x+h def forward(self, x): if self.checkpointing and self.training: out = checkpoint(self._forward, x) else: out = self._forward(x) return out class ProjectResnetBlock(nn.Module): def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.checkpointing = False self.norm1 = Normalize(in_channels) self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout, inplace=True) self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: 
if self.use_conv_shortcut: self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) else: self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def _forward(self, x): h = x h = self.norm1(h) h = nonlinearity(h) h = self.conv1(h) h = self.norm2(h) h = nonlinearity(h) h = self.dropout(h) h = self.conv2(h) if self.in_channels != self.out_channels: if self.use_conv_shortcut: x = self.conv_shortcut(x) else: x = self.nin_shortcut(x) return x+h def forward(self, x): if self.checkpointing and self.training: out = checkpoint(self._forward, x) else: out = self._forward(x) return out class AttnBlock(nn.Module): def __init__(self, in_channels): super().__init__() self.in_channels = in_channels self.norm = Normalize(in_channels) self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h_ = x h_ = self.norm(h_) q = self.q(h_) k = self.k(h_) v = self.v(h_) b,c,h,w = q.shape q = q.reshape(b,c,h*w) q = q.permute(0,2,1) k = k.reshape(b,c,h*w) w_ = torch.bmm(q,k) w_ = w_ * (int(c)**(-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) v = v.reshape(b,c,h*w) w_ = w_.permute(0,2,1) h_ = torch.bmm(v,w_) h_ = h_.reshape(b,c,h,w) h_ = self.proj_out(h_) return x+h_ class VQEmbedding(nn.Embedding): """VQ embedding module with ema update.""" def __init__(self, n_embed, embed_dim, ema=False, decay=0.99, restart_unused_codes=True, eps=1e-5): super().__init__(n_embed + 1, embed_dim, padding_idx=n_embed) self.ema = ema self.decay = decay self.eps = eps self.restart_unused_codes = restart_unused_codes self.n_embed = n_embed @torch.no_grad() def compute_distances(self, inputs): # 12, 16, 16, 1024 
codebook_t = self.weight[:-1, :].t() # 1024, 16384 (embed_dim, _) = codebook_t.shape inputs_shape = inputs.shape assert inputs_shape[-1] == embed_dim inputs_flat = inputs.reshape(-1, embed_dim) # 3072, 1024 inputs_norm_sq = inputs_flat.pow(2.).sum(dim=1, keepdim=True) # 3072, 1 codebook_t_norm_sq = codebook_t.pow(2.).sum(dim=0, keepdim=True) # 1, 16384 distances = torch.addmm( inputs_norm_sq + codebook_t_norm_sq, inputs_flat, codebook_t, alpha=-2.0, ) distances = distances.reshape(*inputs_shape[:-1], -1) return distances # 13, 16, 16, 16384 @torch.no_grad() def find_nearest_embedding(self, inputs): distances = self.compute_distances(inputs) embed_idxs = distances.argmin(dim=-1) return embed_idxs def forward(self, inputs): embed_idxs = self.find_nearest_embedding(inputs) embeds = self.embed(embed_idxs) return embeds, embed_idxs def embed(self, idxs): embeds = super().forward(idxs) return embeds class Decoder(nn.Module): def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, decoder_in_channels, give_pre_end=False, **ignorekwargs): super().__init__() self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels self.give_pre_end = give_pre_end in_ch_mult = (1,)+tuple(ch_mult) block_in = ch*ch_mult[self.num_resolutions-1] curr_res = resolution // 2**(self.num_resolutions-1) self.z_shape = (1, decoder_in_channels, curr_res, curr_res) self.conv_in = torch.nn.Conv2d(decoder_in_channels, block_in, kernel_size=3, stride=1, padding=1) self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) self.mid.attn_1 = AttnBlock(block_in) self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) self.up = nn.ModuleList() for i_level in 
reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() block_out = ch*ch_mult[i_level] for i_block in range(self.num_res_blocks+1): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(AttnBlock(block_in)) up = nn.Module() up.block = block up.attn = attn if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 self.up.insert(0, up) self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, z): self.last_z_shape = z.shape temb = None h = self.conv_in(z) h = self.mid.block_1(h, temb) h = self.mid.attn_1(h) h = self.mid.block_2(h, temb) for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks+1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) if i_level != 0: h = self.up[i_level].upsample(h) if self.give_pre_end: return h h = self.norm_out(h) h = nonlinearity(h) h = self.conv_out(h) return h class RQBottleneck(nn.Module): """ Quantization bottleneck via Residual Quantization. Arguments: latent_shape (Tuple[int, int, int]): the shape of latents, denoted (H, W, D) code_shape (Tuple[int, int, int]): the shape of codes, denoted (h, w, d) n_embed (int, List, or Tuple): the number of embeddings (i.e., the size of codebook) If isinstance(n_embed, int), the sizes of all codebooks are same. shared_codebook (bool): If True, codebooks are shared in all location. If False, uses separate codebooks along the ``depth'' dimension. (default: False) restart_unused_codes (bool): If True, it randomly assigns a feature vector in the curruent batch as the new embedding of unused codes in training. 
(default: True) """ def __init__(self, latent_shape, code_shape, n_embed, decay=0.99, shared_codebook=False, restart_unused_codes=True, commitment_loss='cumsum' ): super().__init__() if not len(code_shape) == len(latent_shape) == 3: raise ValueError("incompatible code shape or latent shape") if any([y % x != 0 for x, y in zip(code_shape[:2], latent_shape[:2])]): raise ValueError("incompatible code shape or latent shape") embed_dim = np.prod(latent_shape[:2]) // np.prod(code_shape[:2]) * latent_shape[2] self.latent_shape = torch.Size(latent_shape) self.code_shape = torch.Size(code_shape) # 16, 16, 4 self.shape_divisor = torch.Size([latent_shape[i] // code_shape[i] for i in range(len(latent_shape))]) self.shared_codebook = shared_codebook if self.shared_codebook: if isinstance(n_embed, Iterable) or isinstance(decay, Iterable): raise ValueError("Shared codebooks are incompatible \ with list types of momentums or sizes: Change it into int") self.restart_unused_codes = restart_unused_codes self.n_embed = n_embed if isinstance(n_embed, Iterable) else [n_embed for _ in range(self.code_shape[-1])] # [16384, 16384, 16384, 16384] self.decay = decay if isinstance(decay, Iterable) else [decay for _ in range(self.code_shape[-1])] assert len(self.n_embed) == self.code_shape[-1] assert len(self.decay) == self.code_shape[-1] if self.shared_codebook: codebook0 = VQEmbedding(self.n_embed[0], embed_dim, decay=self.decay[0], restart_unused_codes=restart_unused_codes, ) self.codebooks = nn.ModuleList([codebook0 for _ in range(self.code_shape[-1])]) else: codebooks = [VQEmbedding(self.n_embed[idx], embed_dim, decay=self.decay[idx], restart_unused_codes=restart_unused_codes, ) for idx in range(self.code_shape[-1])] self.codebooks = nn.ModuleList(codebooks) self.commitment_loss = commitment_loss def to_code_shape(self, x): (B, H, W, D) = x.shape (rH, rW, _) = self.shape_divisor x = x.reshape(B, H//rH, rH, W//rW, rW, D) x = x.permute(0, 1, 3, 2, 4, 5) x = x.reshape(B, H//rH, W//rW, -1) 
return x def to_latent_shape(self, x): (B, h, w, _) = x.shape (_, _, D) = self.latent_shape (rH, rW, _) = self.shape_divisor x = x.reshape(B, h, w, rH, rW, D) x = x.permute(0, 1, 3, 2, 4, 5) x = x.reshape(B, h*rH, w*rW, D) return x def quantize(self, x): r""" Return list of quantized features and the selected codewords by the residual quantization. The code is selected by the residuals between x and quantized features by the previous codebooks. Arguments: x (Tensor): bottleneck feature maps to quantize. Returns: quant_list (list): list of sequentially aggregated and quantized feature maps by codebooks. codes (LongTensor): codewords index, corresponding to quants. Shape: - x: (B, h, w, embed_dim) - quant_list[i]: (B, h, w, embed_dim) - codes: (B, h, w, d) """ B, h, w, embed_dim = x.shape residual_feature = x.detach().clone() # 13, 16, 16, 1024 quant_list = [] code_list = [] aggregated_quants = torch.zeros_like(x) for i in range(self.code_shape[-1]): # 4 quant, code = self.codebooks[i](residual_feature) # 13, 16, 16, 1024 13, 16, 16 residual_feature.sub_(quant) aggregated_quants.add_(quant) quant_list.append(aggregated_quants.clone()) code_list.append(code.unsqueeze(-1)) codes = torch.cat(code_list, dim=-1) return quant_list, codes def forward(self, x): x_reshaped = self.to_code_shape(x) quant_list, codes = self.quantize(x_reshaped) commitment_loss = self.compute_commitment_loss(x_reshaped, quant_list) quants_trunc = self.to_latent_shape(quant_list[-1]) quants_trunc = x + (quants_trunc - x).detach() return quants_trunc, commitment_loss, codes def compute_commitment_loss(self, x, quant_list): r""" Compute the commitment loss for the residual quantization. The loss is iteratively computed by aggregating quantized features. 
""" loss_list = [] for idx, quant in enumerate(quant_list): partial_loss = (x-quant.detach()).pow(2.0).mean() loss_list.append(partial_loss) commitment_loss = torch.mean(torch.stack(loss_list)) return commitment_loss @torch.no_grad() def embed_code(self, code): assert code.shape[1:] == self.code_shape code_slices = torch.chunk(code, chunks=code.shape[-1], dim=-1) if self.shared_codebook: embeds = [self.codebooks[0].embed(code_slice) for i, code_slice in enumerate(code_slices)] else: embeds = [self.codebooks[i].embed(code_slice) for i, code_slice in enumerate(code_slices)] embeds = torch.cat(embeds, dim=-2).sum(-2) embeds = self.to_latent_shape(embeds) return embeds @torch.no_grad() def embed_code_with_depth(self, code, to_latent_shape=False): assert code.shape[-1] == self.code_shape[-1] # 4 code_slices = torch.chunk(code, chunks=code.shape[-1], dim=-1) # print(f"self.shared_codebook: {self.shared_codebook}") if self.shared_codebook: embeds = [self.codebooks[0].embed(code_slice) for i, code_slice in enumerate(code_slices)] else: embeds = [self.codebooks[i].embed(code_slice) for i, code_slice in enumerate(code_slices)] if to_latent_shape: embeds = [self.to_latent_shape(embed.squeeze(-2)).unsqueeze(-2) for embed in embeds] embeds = torch.cat(embeds, dim=-2) return embeds # 16, 16, 4, 1024 class MultiSelfAttention(nn.Module): """ Optimized by batched matmul operations """ def __init__(self, config: AttentionBlockConfig, mask=True): super().__init__() assert config.embed_dim % config.n_head == 0 self.key = nn.Linear(config.embed_dim, config.embed_dim, bias=config.attn_bias) self.query = nn.Linear(config.embed_dim, config.embed_dim, bias=config.attn_bias) self.value = nn.Linear(config.embed_dim, config.embed_dim, bias=config.attn_bias) self.attn_drop = nn.Dropout(config.attn_pdrop, inplace=False) self.resid_drop = nn.Dropout(config.resid_pdrop, inplace=True) self.proj = nn.Linear(config.embed_dim, config.embed_dim, config.attn_bias) self.n_head = config.n_head self.mask 
= mask def forward(self, x, caching=False, past_kv=None): (B, T, C) = x.shape if not caching: assert past_kv is None x = x.transpose(0, 1).contiguous() k = self.key(x).view(T, B*self.n_head, C//self.n_head).transpose(0, 1) q = self.query(x).view(T, B*self.n_head, C//self.n_head).transpose(0, 1) v = self.value(x).view(T, B*self.n_head, C//self.n_head).transpose(0, 1) if past_kv is not None: past_key, past_value = past_kv k = torch.cat([past_key, k], dim=-2) v = torch.cat([past_value, v], dim=-2) T_past = past_key.shape[1] else: T_past = 0 if caching: present = torch.stack([k, v]) else: present = None att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))) if self.mask: mask = torch.tril(torch.ones(T_past+T, T_past+T, device=x.device, dtype=torch.bool)) mask = mask.view(1, T_past+T, T_past+T) att = att.masked_fill(~mask[:, T_past:T_past+T, :T_past+T], float('-inf')) att = F.softmax(att, dim=-1) att = self.attn_drop(att) y = torch.bmm(att, v) y = y.transpose(0, 1).contiguous().view(T, B, C) y = self.resid_drop(self.proj(y)) if caching: return y.transpose(0, 1).contiguous(), present else: return y.transpose(0, 1).contiguous() class AttentionBlock(nn.Module): """ an unassuming Transformer block """ def __init__(self, config: AttentionBlockConfig): super().__init__() self.ln1 = nn.LayerNorm(config.embed_dim) self.ln2 = nn.LayerNorm(config.embed_dim) self.attn = MultiSelfAttention(config, mask=True) self.mlp = nn.Sequential( nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=config.mlp_bias), nn.GELU(), nn.Linear(4 * config.embed_dim, config.embed_dim, bias=config.mlp_bias), nn.Dropout(config.resid_pdrop, inplace=True), ) self._cache = None def forward(self, x): attn = self.attn(self.ln1(x)) x = x + attn x = x + self.mlp(self.ln2(x)) return x def cached_forward(self, x_present): attn, present = self.attn(self.ln1(x_present), caching=True, past_kv=self._cache['past_kv']) self._cache['past_kv'] = present x_present = x_present + attn x_present = 
x_present + self.mlp(self.ln2(x_present)) return x_present def init_cache(self): self._cache = {'past_kv': None} class AttentionStack(nn.Module): blocks: Iterable[AttentionBlock] def __init__(self, config: AttentionStackConfig): super().__init__() self.blocks = nn.ModuleList([AttentionBlock(config.block) for _ in range(config.n_layer)]) def forward(self, x): for block in self.blocks: x = block(x) return x def cached_forward(self, x_present): for block in self.blocks: x_present = block.cached_forward(x_present) return x_present def init_cache(self): for block in self.blocks: block.init_cache() class RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ RMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) # convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: hidden_states = hidden_states.to(self.weight.dtype) return self.weight * hidden_states class CasualDepthTransformerLayer(nn.Module): def __init__(self, config, depth): super().__init__() self.config = config embed_size = config.embed_dim # 2048 num_heads = embed_size // 128 # 16 self.self_attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads,batch_first=True) self.layernorm1 = RMSNorm(embed_size) self.layernorm2 = RMSNorm(embed_size) self.linear1 = nn.Linear(embed_size * depth, 2 * embed_size) # 8192, 4096 self.linear2 = nn.Linear(2 * embed_size * depth, embed_size) def forward(self, x): # 获取输入的序列长度 seq_len = x.size(1) # 创建因果掩码,确保只能看到当前和过去的信息 # 自注意力层 res = x x = self.layernorm1(x) src_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(x.device) _x, _ = self.self_attention(x, x, x, is_causal=True, attn_mask=src_mask) res = _x + res # (bs, sl, d) res 
= self.layernorm2(res) x = torch.einsum('bld,tld->blt', res, torch.reshape(self.linear1.weight, (2 * self.config.embed_dim, -1, self.config.embed_dim))) # linear1.reshape: 4096, 4, 2048 x = torch.nn.functional.gelu(x) x = torch.einsum('blt,dlt->bld', x, torch.reshape(self.linear2.weight, (self.config.embed_dim, -1, 2 * self.config.embed_dim))) return res + x class RQTransformer(PreTrainedModel): config_class = RQTransformerConfig def __init__(self, config: RQTransformerConfig): super().__init__(config) self.in_mlp_1 = nn.Linear(config.input_embed_dim_1, config.embed_dim) # 1024, llm_hidden_size(2048) self.head_transformer = nn.ModuleList([ CasualDepthTransformerLayer(config, config.block_size[-1]) for _ in range(3) ]) self.headnorm = RMSNorm(config.embed_dim) self.heads = nn.ModuleList([ nn.Linear(config.embed_dim, config.vocab_size) for i in range(config.block_size[-1]) ]) self.gradient_checkpointing = True def embed_with_model_aux(self, code, model_aux, mode=None): # mode = "visual" or "semantic" xs_emb = model_aux.get_code_emb_with_depth(code, mode=mode) return xs_emb def forward(self, embed_from_body, code, model_aux=None, mode=None): B, seq_len, D = code.shape # 59, 256, 4 depth_ctx = self.embed_with_model_aux(code, model_aux, mode=mode) depth_ctx = torch.cumsum(depth_ctx, dim=-2) depth_ctx = self.in_mlp_1(depth_ctx) # torch.Size([59, 256, 4, llm_hidden_size(2048)]) depth_ctx_full = torch.cat( [ embed_from_body.view(B, seq_len, 1, -1), depth_ctx[:, :, :-1, :], ], dim=-2, ) # B, 256, 4, 2048 depth_ctx_full = depth_ctx_full.reshape(B * seq_len, D, -1) # B*256, 4, 2048 for i, tlayer in enumerate(self.head_transformer): if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value return module(*inputs) return custom_forward depth_ctx_full = torch.utils.checkpoint.checkpoint( create_custom_forward(tlayer), depth_ctx_full, ) else: depth_ctx_full = tlayer( depth_ctx_full, ) depth_ctx_full 
= self.headnorm(depth_ctx_full) # B*256, 4, 2048 logits = [head(depth_ctx_full[:, i]) for i, head in enumerate(self.heads)] # logits[0].shape = B*256, 16384(codebook_size) return logits def generate(self, embed_from_body, model_aux=None, cfg=3.0, mode=None): top_k = 900 # top_k = 500 # top_k = 1 top_p = 0.96 # top_p = 0.99 B, seq_len, _ = embed_from_body.shape # 1, 1, 2048 next_token_ids = torch.zeros(B, 1, 4, dtype=torch.long).to(embed_from_body.device) for i in range(4): logits = self(embed_from_body, next_token_ids, model_aux, mode=mode) next_token_logits = logits[i].clone() next_token_logits = next_token_logits[B//2:, :] + cfg * (next_token_logits[:B//2, :] - next_token_logits[B//2:, :]) next_tokens = sample_from_logits(next_token_logits, temperature=1.0, top_p=top_p, top_k=top_k) next_tokens = next_tokens.reshape(B//2, seq_len).repeat(2, 1) next_token_ids[:, :, i] = next_tokens out_features = self.embed_with_model_aux(next_token_ids, model_aux, mode=mode) out_features = torch.cumsum(out_features, dim=-2)[:, :, -1, :] # out_features = self.in_mlp_1(out_features) return out_features, next_token_ids def build_projector(dim_in, dim_out, projector_type='mlp2x_gelu'): if projector_type == 'linear': linear = nn.Linear(dim_in, dim_out) linear.reset_parameters() return linear elif projector_type == 'nonlinear': linear = nn.Linear(dim_in, dim_out) linear.reset_parameters() modules = [linear, nn.GELU()] return nn.Sequential(*modules) elif projector_type == 'conv': return nn.Conv2d(dim_in, dim_out, 1) else: # mlp2x_gelu linear_1 = nn.Linear(dim_in, dim_in) linear_1.reset_parameters() modules = [linear_1] modules.append(nn.GELU()) linear_2 = nn.Linear(dim_in, dim_out) linear_2.reset_parameters() modules.append(linear_2) return nn.Sequential(*modules) class RQVAESiglipModel(PreTrainedModel): config_class = RQVAESiglipConfig def __init__(self, config: RQVAESiglipConfig): super().__init__(config) self.config = config # self.siglip_model = 
SiglipModel.from_pretrained(config.pretrained_model) siglip_config = SiglipModel.config_class.from_pretrained(config.pretrained_model) self.siglip_model = SiglipModel._from_config(siglip_config) # self.prequant_semantic = build_projector(config.hidden_size, config.embed_dim, projector_type='linear') self.prequant_visual = ProjectResnetBlock(in_channels=config.hidden_size, out_channels=config.embed_dim, dropout=0.0) self.prequant_visual_1 = ProjectResnetBlock(in_channels=config.hidden_size, out_channels=config.embed_dim, dropout=0.0) self.layer_norm_visual = nn.LayerNorm(config.embed_dim) self.layer_norm_semantic = nn.LayerNorm(config.embed_dim) self.quantizer_semantic = RQBottleneck( latent_shape=config.latent_shape, code_shape=config.code_shape_semantic, n_embed=config.n_embed_semantic, decay=config.decay, shared_codebook=config.shared_codebook, restart_unused_codes=config.restart_unused_codes, ) self.quantizer = RQBottleneck( latent_shape=config.latent_shape, code_shape=config.code_shape_visual, n_embed=config.n_embed_visual, decay=config.decay, shared_codebook=config.shared_codebook, restart_unused_codes=config.restart_unused_codes, ) self.postquant_semantic = build_projector(config.embed_dim, config.hidden_size, projector_type='nonlinear') self.postquant_visual = ProjectResnetBlock(in_channels=config.embed_dim, out_channels=config.hidden_size, dropout=0.0) self.post_quant_conv = PostQuantResnetBlock(in_channels=config.hidden_size, out_channels=config.ddconfig["decoder_in_channels"], dropout=0.0) self.decoder = Decoder(**config.ddconfig) try: self.decoder_latent_shape = config.decoder_latent_shape except: self.decoder_latent_shape = None def encode_text(self, text): # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. 
output_attentions, output_hidden_states, return_dict = None, None, None output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict text_model = self.siglip_model.text_model text_outputs = text_model( input_ids=text, attention_mask=None, position_ids=None, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) text_embeds = text_outputs[1] return text_embeds def encode_image(self, image): vision_model = self.siglip_model.vision_model hidden_states = vision_model.embeddings(image) attention_mask = None output_attentions = None # visual_n, semantic_n = 20, 2 # 取到倒数第n层的特征 visual_n, semantic_n = self.config.last_n_layer_recon, self.config.last_n_layer_sem for i, encoder_layer in enumerate(vision_model.encoder.layers): if vision_model.encoder.gradient_checkpointing and vision_model.encoder.training: layer_outputs = vision_model.encoder._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, output_attentions, ) else: layer_outputs = encoder_layer( hidden_states, attention_mask, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if i == len(vision_model.encoder.layers) - (visual_n+2): # -22 B, L, C = hidden_states.shape hidden_states_visual_1 = hidden_states.reshape(B, int(L**0.5), int(L**0.5), -1) hidden_states_visual_1 = hidden_states_visual_1.permute(0, 3, 1, 2).contiguous() hidden_states_visual_1 = self.prequant_visual_1(hidden_states_visual_1) hidden_states_visual_1 = hidden_states_visual_1.permute(0, 2, 3, 1).contiguous() if i == len(vision_model.encoder.layers) - visual_n: # -20 B, L, C = hidden_states.shape hidden_states_visual = hidden_states.reshape(B, int(L**0.5), int(L**0.5), -1) hidden_states_visual = 
hidden_states_visual.permute(0, 3, 1, 2).contiguous() hidden_states_visual = self.prequant_visual(hidden_states_visual) hidden_states_visual = hidden_states_visual.permute(0, 2, 3, 1).contiguous() hidden_states_visual += 0.6 * hidden_states_visual_1 hidden_states_visual = self.layer_norm_visual(hidden_states_visual) z_q_visual, quant_loss_visual, code_visual = self.quantizer(hidden_states_visual) if i == len(vision_model.encoder.layers) - semantic_n: hidden_state_26 = hidden_states B, L, C = hidden_states.shape hidden_states_semantic = hidden_states.reshape(B, int(L**0.5), int(L**0.5), -1) hidden_states_semantic = self.layer_norm_semantic(hidden_states_semantic) z_q_semantic, quant_loss_semantic, code_semantic = self.quantizer_semantic(hidden_states_semantic) return z_q_visual, code_visual, z_q_semantic, code_semantic def decode(self, z_q): z_q = z_q.permute(0, 3, 1, 2).contiguous() z_q = self.postquant_visual(z_q) if self.decoder_latent_shape is not None: z_q = F.interpolate(z_q.to(torch.float32), size=tuple(self.decoder_latent_shape), mode='bilinear').to(torch.bfloat16) z_q = self.post_quant_conv(z_q) out = self.decoder(z_q) return out @torch.no_grad() def get_code_emb_with_depth(self, code, mode=None): # 分 visual codebook 取 image embedding, mode = "visual" or "semantic" if mode == "visual": visual_embedding = self.quantizer.embed_code_with_depth(code) return visual_embedding elif mode == "semantic": semantic_embedding = self.quantizer_semantic.embed_code_with_depth(code) return semantic_embedding class RQVAESIGLIPTransformer(PreTrainedModel): config_class = RQVAESIGLIPTransformerConfig def __init__(self, config: RQVAESIGLIPTransformerConfig): super().__init__(config) rqvaesiglip_config = RQVAESiglipModel.config_class.from_dict(config.rqvaesiglip) rqtransformer_visual_config = RQTransformer.config_class.from_dict(config.rqtransformer_visual) rqtransformer_semantic_config = RQTransformer.config_class.from_dict(config.rqtransformer_semantic) self.rqvaesiglip = 
RQVAESiglipModel._from_config(rqvaesiglip_config) self.rqtransformer_visual = RQTransformer._from_config(rqtransformer_visual_config) self.rqtransformer_semantic = RQTransformer._from_config(rqtransformer_semantic_config) class RQVAESIGLIPTransformerVisionTower(nn.Module): def __init__(self, model_name_or_path): super().__init__() model_dtype = torch.bfloat16 self.config = RQVAESIGLIPTransformerConfig.from_pretrained(model_name_or_path) self.vision_tower = RQVAESIGLIPTransformer._from_config(self.config, torch_dtype=model_dtype) self.is_loaded = True encoder_path = self.config.rqvaesiglip["pretrained_model"] if "siglip-so400m-patch14-384" in encoder_path: # SigLIP-SO400M-patch14-384 self.image_processor = CLIPImageProcessor( size={"height": 384, "width": 384}, crop_size={"height": 384, "width": 384}, image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5] ) self.image_tokens = 729 # self.config.hidden_size == 1152 elif "siglip-large-patch16-256" in encoder_path: # SigLIP-Large-patch16-256 self.image_processor = CLIPImageProcessor( size={"height": 256, "width": 256}, crop_size={"height": 256, "width": 256}, image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5] ) self.image_tokens = 256 # self.config.hidden_size == 1024 else: raise NotImplementedError() def encode(self, images: torch.Tensor): vision_output = self.vision_tower.rqvaesiglip.encode_image(images) image_features_visual, tokens_visual = vision_output[0], vision_output[1] image_features_semantic, tokens_semantic = vision_output[2], vision_output[3] bs, patch_size, _, dim = image_features_visual.shape image_features_visual = torch.reshape(image_features_visual, [bs, patch_size**2, dim]) tokens_visual = torch.reshape(tokens_visual, [bs, patch_size**2, -1]) bs, patch_size, _, dim = image_features_semantic.shape image_features_semantic = torch.reshape(image_features_semantic, [bs, patch_size**2, dim]) tokens_semantic = torch.reshape(tokens_semantic, [bs, patch_size**2, -1]) return image_features_visual, 
image_features_semantic, tokens_visual, tokens_semantic class BaichuanVisualEncoder(RQVAESIGLIPTransformerVisionTower): def __init__(self, config): super().__init__(config) self.vision_tower.rqvaesiglip.siglip_model.vision_model.gradient_checkpointing = True # 强制开启 self.vision_tower.rqvaesiglip.siglip_model.vision_model._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint def forward( self, pixel_values # pixel_values = 179, 161, 147..., shape=13, 3, 256, 256 ): pixel_values = pixel_values.to(self.vision_tower.rqvaesiglip.siglip_model.vision_model.embeddings.patch_embedding.weight.dtype) return self.encode(pixel_values) # visual_idx (not add text_vocab) @torch.no_grad() def fake_input(self, input_ids, merge_size): fake_image = [torch.zeros([3, self.config.image_size, self.config.image_size], dtype=torch.float32, device=input_ids.device)] return fake_image # def test_vision(): # from transformers.models.clip.modeling_clip import CLIPPreTrainedModel # from transformers import AutoConfig # config = AutoConfig.from_pretrained("./", trust_remote_code=True) # ae = BaichuanVisualEncoder(config).cuda().to(torch.bfloat16) # bg = BaichuanVisualBridge(config).cuda().to(torch.bfloat16) # print(ae) # pixel_input = torch.rand([4, 3, config.image_size, config.image_size], dtype=torch.float32).cuda() # visual_embedding = ae(pixel_input)[0][:, 1:] # 删除class token # visual_proj = bg(visual_embedding) # print(visual_proj.shape) # print(ae.fake_input(visual_proj.device)) if __name__ == '__main__': fire.Fire()