import math
import os
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import Any, Callable, ClassVar, Dict, List, Optional, Union, cast

import numpy as np
import torch
from huggingface_hub import snapshot_download
from peft import PeftModel
from peft.utils.hotswap import hotswap_adapter
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BatchFeature
from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen2_5_vl import (Qwen2_5_VLForConditionalGeneration,
                                            Qwen2_5_VLProcessor)

from .configuration_jina_embeddings_v4 import JinaEmbeddingsV4Config


class PromptType(str, Enum):
    query = "query"
    passage = "passage"


class TaskType(str, Enum):
    retrieval = "retrieval"
    code = "code"
    text_matching = "text-matching"
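# Note: the TaskType values above also name the LoRA adapter sub-directories
# ("adapters/retrieval", "adapters/code", "adapters/text-matching") that
# JinaEmbeddingsV4Model.from_pretrained and the attached set_task helper load from.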


class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
    def __init__(self, *args, **kwargs) -> None:
        Qwen2_5_VLProcessor.__init__(self, *args, **kwargs)
        # Number of characters of the rendered chat-template preamble that is
        # stripped from multi-image prompts in process_images.
        self.assistant_prefix_len = 58
        self.text_max_length = 8192

    @staticmethod
    def round_by_factor(number: float, factor: int) -> int:
        """Returns the closest integer to 'number' that is divisible by 'factor'."""
        return round(number / factor) * factor

    @staticmethod
    def ceil_by_factor(number: float, factor: int) -> int:
        """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
        return math.ceil(number / factor) * factor

    @staticmethod
    def floor_by_factor(number: float, factor: int) -> int:
        """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
        return math.floor(number / factor) * factor
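    # The three *_by_factor helpers above are typically used to snap image
    # dimensions to the vision patch grid, e.g. with factor=28 (Qwen2.5-VL):
    #   round_by_factor(250, 28) == 252
    #   ceil_by_factor(250, 28) == 252
    #   floor_by_factor(250, 28) == 224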

    def process_images(
        self,
        images: Union[List[Image.Image], List[List[Image.Image]]],
    ) -> BatchFeature:
        if isinstance(images[0], list):
            # Multi-image documents: render a chat-template prompt per document
            # and strip the fixed-length preamble.
            images = cast(List[List[Image.Image]], images)
            text_doc = []
            for i in range(len(images)):
                conversation = [
                    {"role": "user", "content": [{"type": "image"}] * len(images[i])}
                ]
                template = self.apply_chat_template(
                    conversation, add_generation_prompt=False
                )
                text_doc.append(template[self.assistant_prefix_len :])
        else:
            # Single image per document: use a fixed prompt.
            images = cast(List[Image.Image], images)
            text_doc = [
                "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n"
            ] * len(images)

        batch_doc = self(text=text_doc, images=images, padding="longest", return_tensors="pt")

        # Split the flat patch tensor per image, zero-pad each image to the
        # longest one, and stack so the batch collates into a single tensor
        # (get_last_hidden_states removes this padding again).
        offsets = batch_doc["image_grid_thw"][:, 1] * batch_doc["image_grid_thw"][:, 2]
        pixel_values = torch.split(batch_doc["pixel_values"], offsets.tolist())
        max_length = max([len(pv) for pv in pixel_values])
        pixel_values = [
            torch.cat(
                [
                    pv,
                    torch.zeros(
                        (max_length - len(pv), pv.shape[1]),
                        dtype=pv.dtype,
                        device=pv.device,
                    ),
                ]
            )
            for pv in pixel_values
        ]
        batch_doc["pixel_values"] = torch.stack(pixel_values)
        return batch_doc

    def process_texts(
        self,
        texts: List[str],
        max_length: Optional[int] = None,
        prefix: Optional[str] = None,
        padding: Optional[str] = None,
    ) -> BatchFeature:
        max_length = (
            self.text_max_length
            if max_length is None
            else min(max_length, self.text_max_length)
        )
        padded_texts: List[str] = []
        for text in texts:
            # Prepend the instruction prefix (e.g. "Query: ...") when one is given.
            if prefix:
                text = f"{prefix}: {text}"
            padded_texts.append(text)

        text_batch = self(
            text=padded_texts,
            return_tensors="pt",
            padding=padding or "longest",
            max_length=max_length,
            truncation=True,
        )
        return text_batch


@dataclass
class JinaEmbeddingsV4ModelOutput:
    """
    Output type of JinaEmbeddingsV4Model, bundling the optional VLM hidden states
    with the single- and multi-vector embeddings.

    Args:
        vlm_last_hidden_states (torch.Tensor, optional): Last hidden states of the VLM.
        single_vec_emb (torch.Tensor, optional): Single-vector embeddings.
        multi_vec_emb (torch.Tensor, optional): Multi-vector embeddings.
    """

    vlm_last_hidden_states: Optional[torch.Tensor] = None
    single_vec_emb: Optional[torch.Tensor] = None
    multi_vec_emb: Optional[torch.Tensor] = None


class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
    config_class = JinaEmbeddingsV4Config
    main_input_name: ClassVar[str] = "doc_input_ids"

    def __init__(self, config: JinaEmbeddingsV4Config):
        Qwen2_5_VLForConditionalGeneration.__init__(self, config)
        self._init_projection_layers(config)
        self.post_init()
        self.processor = JinaEmbeddingsV4Processor.from_pretrained(
            self.name_or_path, trust_remote_code=True
        )
        self.single_vector_projector_dim = config.single_vector_projector_dim
        self.multi_vector_projector_dim = config.multi_vector_projector_dim

    def get_last_hidden_states(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        if "pixel_values" in kwargs:
            # Undo the per-image padding added in process_images before the
            # vision tower consumes the patches.
            offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2]
            kwargs["pixel_values"] = torch.cat(
                [pv[:o] for pv, o in zip(kwargs["pixel_values"], offsets)], dim=0
            )

        position_ids, rope_deltas = super().get_rope_index(
            input_ids=input_ids,
            image_grid_thw=kwargs.get("image_grid_thw", None),
            attention_mask=attention_mask,
        )

        kwargs["output_hidden_states"] = True

        outputs = super().forward(
            input_ids,
            attention_mask,
            **kwargs,
            position_ids=position_ids,
            rope_deltas=rope_deltas,
            use_cache=False,
        )

        hidden_states = outputs.hidden_states
        if not hidden_states:
            raise ValueError("Hidden states not found in model output")

        return hidden_states[-1]

    def _init_projection_layers(self, config) -> None:
        """
        Initializes projection layers.
        """
        self.config.single_vector_projector_dim = config.single_vector_projector_dim
        self.config.multi_vector_projector_dim = config.multi_vector_projector_dim

        self.single_vector_projector = nn.Linear(
            in_features=self.config.hidden_size,
            out_features=self.config.single_vector_projector_dim,
        )

        self.multi_vector_projector = nn.Linear(
            in_features=self.config.hidden_size,
            out_features=self.config.multi_vector_projector_dim,
        )

    def project_to_single_vector_embeddings(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        input_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """
        Project the hidden states to single-vector embeddings.
        """
        if self._input_has_image(input_ids[0]):
            # Image inputs: mean-pool the hidden states over each sequence's
            # vision-token span (from <|vision_start|> to <|vision_end|>).
            pooled_outputs = []
            for seq_ids, seq_states in zip(input_ids, hidden_states):
                img_start_pos = torch.where(
                    seq_ids == self.config.vision_start_token_id
                )[0][0]
                img_end_pos = torch.where(
                    seq_ids == self.config.vision_end_token_id
                )[0][0]
                pooled_outputs.append(
                    seq_states[img_start_pos : img_end_pos + 1].mean(dim=0)
                )
            pooled_output = torch.stack(pooled_outputs)
        else:
            # Text inputs: attention-masked mean-pooling over all tokens.
            pooled_output = torch.sum(
                hidden_states * attention_mask.unsqueeze(-1), dim=1
            ) / torch.sum(attention_mask, dim=1, keepdim=True)
        single_vec_emb = self.single_vector_projector(pooled_output)
        return torch.nn.functional.normalize(single_vec_emb, dim=-1)

    def project_to_multi_vector_embeddings(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Project the hidden states to multi-vector embeddings.
        """
        multi_vec_emb = self.multi_vector_projector(hidden_states)
        multi_vec_emb = torch.nn.functional.normalize(multi_vec_emb, dim=-1)
        return multi_vec_emb * attention_mask.unsqueeze(-1)

    def _input_has_image(self, input_ids):
        return self.config.vision_start_token_id in input_ids

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor,
        output_vlm_last_hidden_states: bool = False,
        **kwargs,
    ) -> JinaEmbeddingsV4ModelOutput:
        """
        Forward pass through JinaEmbeddingsV4Model. Returns both single-vector and multi-vector embeddings.

        Args:
            input_ids (torch.LongTensor): The input tokens tensor.
            attention_mask (torch.LongTensor): The attention mask tensor.
            output_vlm_last_hidden_states (bool): Whether to also return the VLM's last hidden states.

        Returns:
            JinaEmbeddingsV4ModelOutput:
                single_vec_emb (torch.Tensor): Single-vector embeddings of shape (batch_size, dim).
                multi_vec_emb (torch.Tensor): Multi-vector embeddings of shape (batch_size, num_tokens, dim).
        """
        hidden_states = self.get_last_hidden_states(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )

        single_vec_emb = self.project_to_single_vector_embeddings(
            hidden_states, attention_mask, input_ids=input_ids
        )
        multi_vec_emb = self.project_to_multi_vector_embeddings(
            hidden_states, attention_mask
        )

        return JinaEmbeddingsV4ModelOutput(
            vlm_last_hidden_states=(
                hidden_states if output_vlm_last_hidden_states else None
            ),
            single_vec_emb=single_vec_emb,
            multi_vec_emb=multi_vec_emb,
        )

    def _process_batches(
        self,
        data: List[Union[str, Image.Image]],
        processor_fn: Callable,
        desc: str,
        vector_type: Optional[str] = None,
        return_numpy: bool = False,
        **kwargs,
    ) -> Union[np.ndarray, List[torch.Tensor]]:
        # The processor doubles as the collate_fn, so raw strings/images are
        # tokenized and batched lazily by the DataLoader.
        dataloader = DataLoader(
            dataset=data,
            batch_size=kwargs.get("batch_size", 32),
            shuffle=False,
            collate_fn=processor_fn,
        )
        vector_type = vector_type or "single_vector"
        results = []
        self.eval()
        for batch in tqdm(dataloader, desc=desc):
            with torch.no_grad():
                batch = {k: v.to(self.device) for k, v in batch.items()}
                with torch.autocast(device_type=torch.device(self.device).type):
                    embeddings = self(**batch)
                    if vector_type == "single_vector":
                        embeddings = embeddings.single_vec_emb
                    else:
                        embeddings = embeddings.multi_vec_emb
                    results.append(
                        embeddings.cpu()
                        if return_numpy
                        else list(torch.unbind(embeddings))
                    )
        if return_numpy:
            return np.concatenate([result.numpy() for result in results], axis=0)
        return [item for sublist in results for item in sublist]

    def encode_texts(
        self,
        queries: List[str],
        max_length: int = 8192,
        batch_size: int = 8,
        vector_type: Optional[str] = None,
        desc: Optional[str] = None,
        **kwargs,
    ) -> List[torch.Tensor]:
        processor_fn = partial(
            self.processor.process_texts, max_length=max_length, prefix="Query"
        )
        return self._process_batches(
            data=queries,
            processor_fn=processor_fn,
            desc=desc or "Encode queries...",
            vector_type=vector_type,
            batch_size=batch_size,
            **kwargs,
        )

    def encode_images(
        self,
        documents: List[Image.Image],
        batch_size: int = 8,
        vector_type: Optional[str] = None,
        desc: Optional[str] = None,
        **kwargs,
    ) -> List[torch.Tensor]:
        return self._process_batches(
            data=documents,
            processor_fn=self.processor.process_images,
            desc=desc or "Encode documents...",
            vector_type=vector_type,
            batch_size=batch_size,
            **kwargs,
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *args,
        **kwargs,
    ):
        if "torch_dtype" not in kwargs:
            kwargs["torch_dtype"] = "auto"

        task = kwargs.pop("task", TaskType.retrieval)

        base_model = super().from_pretrained(
            pretrained_model_name_or_path, *args, **kwargs
        )

        # Resolve the directory that holds the task-specific LoRA adapters,
        # either locally or from the Hugging Face Hub.
        if os.path.isdir(base_model.name_or_path):
            adapter_dir = os.path.join(base_model.name_or_path, "adapters")
        else:
            adapter_cache_path = snapshot_download(
                repo_id=base_model.name_or_path, allow_patterns=["adapters/*"]
            )
            adapter_dir = os.path.join(adapter_cache_path, "adapters")

        base_model.adapter_dir = adapter_dir

        # Wrap the base model with the adapter for the requested task.
        peft_model = PeftModel.from_pretrained(
            base_model, os.path.join(adapter_dir, task)
        )

        def set_task_method(self, task_name: Union[str, TaskType]):
            """
            Set the task adapter for the model.

            Args:
                task_name (Union[str, TaskType]): The task to switch to. Must be a
                    TaskType or one of 'retrieval', 'text-matching', 'code'.
            """
            if isinstance(task_name, str):
                try:
                    task_name = TaskType(task_name)
                except ValueError:
                    valid_tasks = [t.value for t in TaskType]
                    raise ValueError(
                        f"Invalid task: {task_name}. Must be one of {valid_tasks}"
                    )

            adapter_path = os.path.join(self.adapter_dir, task_name.value)
            hotswap_adapter(self, adapter_path, adapter_name="default")

        # Bind set_task to the PEFT-wrapped model so adapters can be hot-swapped
        # in place without reloading the base weights.
        peft_model.set_task = set_task_method.__get__(peft_model, type(peft_model))

        return peft_model
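

# Usage sketch (illustrative only; the checkpoint path, image file, and CUDA
# device below are placeholders, and the checkpoint must ship the "adapters/"
# directory expected by from_pretrained):
#
#     model = JinaEmbeddingsV4Model.from_pretrained(
#         "path/or/repo-id/of/jina-embeddings-v4", task="retrieval"
#     ).to("cuda")
#     query_embs = model.encode_texts(queries=["quick brown fox"])
#     doc_embs = model.encode_images(
#         documents=[Image.open("page.png")], vector_type="multi_vector"
#     )
#     model.set_task("text-matching")  # hot-swaps the LoRA adapter in place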