jupyterjazz committed
Commit 9624180 · 1 Parent(s): 4453d02

refactor: support urls, fast processor, flash attn check

modeling_jina_embeddings_v4.py CHANGED
@@ -5,20 +5,24 @@ import os
 from dataclasses import dataclass
 from enum import Enum
 from functools import partial
+from io import BytesIO
 from typing import Any, Callable, ClassVar, Dict, List, Optional, Union, cast
 
 import numpy as np
+import requests
 import torch
 from huggingface_hub import snapshot_download
-from peft import PeftModel, LoraConfig
+from peft import LoraConfig, PeftModel
 from PIL import Image
 from torch import nn
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import BatchFeature
-from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
+from transformers.utils import is_flash_attn_2_available
+
 from .configuration_jina_embeddings_v4 import JinaEmbeddingsV4Config
 from .custom_lora_module import MultiAdapterLinear
+from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
 
 
 class PromptType(str, Enum):
@@ -140,7 +144,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         self._init_projection_layers(config)
         self.post_init()
         self.processor = JinaEmbeddingsV4Processor.from_pretrained(
-            self.name_or_path, trust_remote_code=True
+            self.name_or_path, trust_remote_code=True, use_fast=True
         )
         self.single_vector_projector_dim = config.single_vector_projector_dim
         self.multi_vector_projector_dim = config.multi_vector_projector_dim
@@ -160,7 +164,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
            task (str): The task name. Must be one of ['retrieval', 'text-matching', 'code']
        """
        if task not in self.config.task_names:
-            raise ValueError(f"Invalid task: {task}. Must be one of {self.config.task_names}.")
+            raise ValueError(
+                f"Invalid task: {task}. Must be one of {self.config.task_names}."
+            )
        self._task = task
 
    def get_last_hidden_states(
@@ -342,7 +348,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
        for batch in tqdm(dataloader, desc=desc):
            with torch.no_grad():
                batch = {k: v.to(self.device) for k, v in batch.items()}
-                with torch.autocast(device_type=torch.device(self.device).type, dtype=torch.bfloat16):
+                with torch.autocast(
+                    device_type=torch.device(self.device).type, dtype=torch.bfloat16
+                ):
                    embeddings = self(**batch, task_label=task_label)
                    if vector_type == "single_vector":
                        embeddings = embeddings.single_vec_emb
@@ -395,7 +403,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
            encode_kwargs["truncate_dim"] = truncate_dim
 
        return encode_kwargs
-
+
    def _validate_task(self, task: Optional[str] = None) -> str:
        if task is None:
            if self.task is None:
@@ -406,7 +414,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
                task = self.task
        else:
            if task not in self.config.task_names:
-                raise ValueError(f"Invalid task: {task}. Must be one of {self.config.task_names}.")
+                raise ValueError(
+                    f"Invalid task: {task}. Must be one of {self.config.task_names}."
+                )
        return task
 
    def encode_texts(
@@ -460,9 +470,23 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
        return embeddings
 
+    def _load_images_if_needed(
+        self, images: List[Union[str, Image.Image]]
+    ) -> List[Image.Image]:
+        loaded_images = []
+        for image in images:
+            if isinstance(image, str):
+                if image.startswith("http"):
+                    response = requests.get(image)
+                    image = Image.open(BytesIO(response.content)).convert("RGB")
+                else:
+                    image = Image.open(image).convert("RGB")
+            loaded_images.append(image)
+        return loaded_images
+
    def encode_images(
        self,
-        images: List[Image.Image],
+        images: List[Union[str, Image.Image]],
        task: Optional[str] = None,
        batch_size: int = 8,
        vector_type: Optional[str] = None,
@@ -474,7 +498,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
        Encodes a list of images into embeddings.
 
        Args:
-            images: List of PIL images to encode
+            images: List of PIL images, URLs, or local file paths to encode
            batch_size: Number of images to process at once
            vector_type: Type of embedding vector to generate ('single_vector' or 'multi_vector')
            return_numpy: Whether to return numpy arrays instead of torch tensors
@@ -489,9 +513,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
        self.processor.image_processor.max_pixels = (
            max_pixels  # change during encoding
        )
-
        encode_kwargs = self._validate_encoding_params(vector_type, truncate_dim)
        task = self._validate_task(task)
+        images = self._load_images_if_needed(images)
        embeddings = self._process_batches(
            data=images,
            processor_fn=self.processor.process_images,
@@ -519,8 +543,10 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
        """
        if "torch_dtype" not in kwargs:
            kwargs["torch_dtype"] = "auto"
-
+
        kwargs["key_mapping"] = super()._checkpoint_conversion_mapping
+        if not is_flash_attn_2_available():
+            kwargs["attn_implementation"] = "sdpa"
 
        base_model = super().from_pretrained(
            pretrained_model_name_or_path, *args, **kwargs
@@ -547,19 +573,19 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
            model_id=adapter_dir,
            config=lora_config,
        )
-
+
        @property
        def task(self):
            return self.model.task
-
+
        @task.setter
        def task(self, value):
            self.model.task = value
-
+
        peft_model.task = property(task.fget, task.fset)
        peft_model.__class__.task = property(
            lambda self: self.model.task,
-            lambda self, value: setattr(self.model, 'task', value)
+            lambda self, value: setattr(self.model, "task", value),
        )
 
        return peft_model
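With this change, encode_images accepts image URLs and local file paths in addition to PIL images; _load_images_if_needed downloads or opens them before batching. A minimal usage sketch (the repo id, URL, and file path below are illustrative assumptions, not part of this commit):

from transformers import AutoModel

# Assumed checkpoint id; any repo shipping this modeling file should behave the same.
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)

embeddings = model.encode_images(
    images=[
        "https://example.com/chart.png",  # URL: fetched with requests, decoded via PIL
        "/data/figures/chart_local.png",  # local path: opened with PIL
    ],
    task="retrieval",  # must be one of config.task_names
)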
 
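The is_flash_attn_2_available() guard makes loading degrade gracefully: when flash-attn 2 is not installed, from_pretrained now forces the SDPA attention implementation instead of leaving whatever the config requests. A short sketch of the loading path, assuming the same repo id as above:

from transformers import AutoModel
from transformers.utils import is_flash_attn_2_available

# On hosts without flash-attn 2 this now falls back to attn_implementation="sdpa".
print("flash-attn 2 available:", is_flash_attn_2_available())
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)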
tokenizer_config.json CHANGED
@@ -202,7 +202,7 @@
   "extra_special_tokens": {},
   "model_max_length": 131072,
   "pad_token": "<|endoftext|>",
-  "processor_class": "ColQwen25DuoProcessor",
+  "processor_class": "JinaEmbeddingsV4Processor",
   "split_special_tokens": false,
   "tokenizer_class": "Qwen2Tokenizer",
   "unk_token": null