jupyterjazz committed on
Commit 8a9e9ed · 1 Parent(s): 85f64e2

feat: finalized implementation


Signed-off-by: jupyterjazz <[email protected]>

Files changed (4)
  1. config.json +3 -1
  2. custom_lora_module.py +73 -197
  3. modeling_jina_embeddings_v4.py +112 -76
  4. qwen2_5_vl.py +18 -85
config.json CHANGED
@@ -54,5 +54,7 @@
  "vision_start_token_id": 151652,
  "vision_token_id": 151654,
  "vocab_size": 151936,
- "truncate_dim": null
+ "truncate_dim": null,
+ "task_names": ["retrieval", "text-matching", "code"],
+ "matryoshka_dims": [128, 256, 512, 1024]
  }
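Note on the two new keys: task_names is the list of task adapters that custom_lora_module.MultiAdapterLinear instantiates for every wrapped linear layer, and matryoshka_dims replaces the hard-coded TRUNCATE_DIMS list that modeling_jina_embeddings_v4.py previously validated truncate_dim against. Below is a minimal sketch of the validation these fields drive, assuming the config is loaded through transformers with trust_remote_code; the checkpoint id is illustrative, not part of this commit.

from transformers import AutoConfig

# Illustrative repo id; any checkpoint shipping this config.json exposes the same attributes.
config = AutoConfig.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)

def check_encode_args(task, truncate_dim=None):
    # Mirrors the checks added in modeling_jina_embeddings_v4.py further down this commit.
    if task not in config.task_names:
        raise ValueError(f"Invalid task: {task}. Must be one of {config.task_names}.")
    if truncate_dim is not None and truncate_dim not in config.matryoshka_dims:
        raise ValueError(
            f"Invalid truncate_dim: {truncate_dim}. Must be one of {config.matryoshka_dims}."
        )

check_encode_args("retrieval", 128)    # passes
# check_encode_args("classification")  # would raise: not in config.task_names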
custom_lora_module.py CHANGED
@@ -2,31 +2,35 @@ from __future__ import annotations
2
 
3
  import math
4
  import warnings
5
- from typing import Any, Optional, Union
6
 
7
  import torch
8
  import torch.nn as nn
9
- import torch.nn.functional as F
10
- from accelerate.utils.imports import is_xpu_available
11
- from torch import svd_lowrank
12
- from transformers.pytorch_utils import Conv1D
13
 
14
- from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
15
- from peft.utils.integrations import (
16
- dequantize_module_weight,
17
- gather_params_ctx,
18
- get_bnb_param_type,
19
- skip_init_on_device,
20
- )
21
- from peft.utils.other import transpose
22
  from peft.tuners.lora import LoraLayer
23
 
24
- class Linear(nn.Module, LoraLayer):
25
- # Lora implemented in a dense layer
 
26
  def __init__(
27
  self,
28
  base_layer,
29
  adapter_name: str,
 
30
  r: int = 0,
31
  lora_alpha: int = 1,
32
  lora_dropout: float = 0.0,
@@ -40,8 +44,9 @@ class Linear(nn.Module, LoraLayer):
40
  ) -> None:
41
  super().__init__()
42
  LoraLayer.__init__(self, base_layer, **kwargs)
43
- self.fan_in_fan_out = fan_in_fan_out
44
 
 
 
45
  self._active_adapter = adapter_name
46
  self.update_layer(
47
  adapter_name,
@@ -55,160 +60,14 @@ class Linear(nn.Module, LoraLayer):
55
  )
56
  self.is_target_conv_1d_layer = is_target_conv_1d_layer
57
 
58
- def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
59
- """
60
- Merge the active adapter weights into the base weights
61
-
62
- Args:
63
- safe_merge (`bool`, *optional*):
64
- If True, the merge operation will be performed in a copy of the original weights and check for NaNs
65
- before merging the weights. This is useful if you want to check if the merge operation will produce
66
- NaNs. Defaults to `False`.
67
- adapter_names (`list[str]`, *optional*):
68
- The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
69
- to `None`.
70
- """
71
- adapter_names = check_adapters_to_merge(self, adapter_names)
72
- if not adapter_names:
73
- # no adapter to merge
74
- return
75
-
76
- for active_adapter in adapter_names:
77
- if active_adapter in self.lora_A.keys():
78
- base_layer = self.get_base_layer()
79
- if safe_merge:
80
- # Note that safe_merge will be slower than the normal merge
81
- # because of the copy operation.
82
- orig_weights = base_layer.weight.data.clone()
83
- delta_weight = self.get_delta_weight(active_adapter)
84
- if not self.use_dora[active_adapter]:
85
- orig_weights += delta_weight
86
- else:
87
- # handle dora
88
- # since delta_weight already includes scaling, set it to 1 here
89
- weight_norm = (
90
- self.lora_magnitude_vector[active_adapter]
91
- .get_weight_norm(orig_weights, transpose(delta_weight, self.fan_in_fan_out), scaling=1)
92
- .detach()
93
- )
94
- # We need to cache weight_norm because it has to be based on the original weights. We
95
- # cannot calculate it on the fly based on the merged weights when unmerging because its a
96
- # different value
97
- self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
98
- dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
99
- dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out)
100
- orig_weights = dora_factor * (orig_weights + delta_weight)
101
-
102
- if not torch.isfinite(orig_weights).all():
103
- raise ValueError(
104
- f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
105
- )
106
-
107
- base_layer.weight.data = orig_weights
108
-
109
- if self.lora_bias[active_adapter]:
110
- new_bias = base_layer.bias + self.lora_B[active_adapter].bias
111
- if not torch.isfinite(new_bias).all():
112
- raise ValueError(
113
- f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
114
- )
115
- base_layer.bias.data = new_bias
116
-
117
- else:
118
- delta_weight = self.get_delta_weight(active_adapter)
119
- if not self.use_dora[active_adapter]:
120
- base_layer.weight.data += delta_weight
121
- else:
122
- # handle dora
123
- # since delta_weight already includes scaling, set it to 1 here
124
- weight_norm = (
125
- self.lora_magnitude_vector[active_adapter]
126
- .get_weight_norm(
127
- base_layer.weight, transpose(delta_weight, self.fan_in_fan_out), scaling=1
128
- )
129
- .detach()
130
- )
131
- # We need to cache weight_norm because it has to be based on the original weights. We
132
- # cannot calculate it on the fly based on the merged weights when unmerging because its a
133
- # different value
134
- self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
135
- dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
136
- dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out)
137
- new_weight = dora_factor * (base_layer.weight.data + delta_weight)
138
- base_layer.weight.data = new_weight
139
-
140
- if self.lora_bias[active_adapter]:
141
- base_layer.bias.data += self.lora_B[active_adapter].bias
142
-
143
- self.merged_adapters.append(active_adapter)
144
-
145
- def unmerge(self) -> None:
146
- """
147
- This method unmerges all merged adapter layers from the base weights.
148
- """
149
- if not self.merged:
150
- warnings.warn("Already unmerged. Nothing to do.")
151
- return
152
- while len(self.merged_adapters) > 0:
153
- active_adapter = self.merged_adapters.pop()
154
- if active_adapter in self.lora_A.keys():
155
- weight = self.get_base_layer().weight
156
- delta_weight = self.get_delta_weight(active_adapter)
157
- if not self.use_dora[active_adapter]:
158
- weight.data -= delta_weight
159
- else:
160
- weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
161
- dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
162
- weight_orig = weight.data / dora_factor.view(-1, 1) - delta_weight
163
- weight.data = weight_orig
164
-
165
- if self.lora_bias[active_adapter]:
166
- self.get_base_layer().bias.data -= self.lora_B[active_adapter].bias
167
 
168
- def get_delta_weight(self, adapter) -> torch.Tensor:
169
- """
170
- Compute the delta weight for the given adapter.
171
-
172
- Args:
173
- adapter (str):
174
- The name of the adapter for which the delta weight should be computed.
175
- """
176
- device = self.lora_B[adapter].weight.device
177
- dtype = self.lora_B[adapter].weight.dtype
178
-
179
- # In case users wants to merge the adapter weights that are in
180
- # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
181
- # (b)float16 because some CPUs have slow bf16/fp16 matmuls.
182
- cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
183
-
184
- weight_A = self.lora_A[adapter].weight
185
- weight_B = self.lora_B[adapter].weight
186
-
187
- if cast_to_fp32:
188
- weight_A = weight_A.float()
189
- weight_B = weight_B.float()
190
-
191
- output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
192
-
193
- if cast_to_fp32:
194
- output_tensor = output_tensor.to(dtype=dtype)
195
-
196
- # cast back the weights
197
- self.lora_A[adapter].weight.data = weight_A.to(dtype)
198
- self.lora_B[adapter].weight.data = weight_B.to(dtype)
199
-
200
- return output_tensor
201
-
202
- def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
203
  self._check_forward_args(x, *args, **kwargs)
204
- adapter_names = kwargs.pop("adapter_names", None)
205
 
206
  if self.disable_adapters:
207
  if self.merged:
208
  self.unmerge()
209
  result = self.base_layer(x, *args, **kwargs)
210
- elif adapter_names is not None:
211
- result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
212
  elif self.merged:
213
  result = self.base_layer(x, *args, **kwargs)
214
  else:
@@ -219,30 +78,34 @@ class Linear(nn.Module, LoraLayer):
219
  for active_adapter in self.active_adapters:
220
  if active_adapter not in lora_A_keys:
221
  continue
222
-
223
- lora_A = self.lora_A[active_adapter]['default']
224
- lora_B = self.lora_B[active_adapter]['default']
225
- dropout = self.lora_dropout[active_adapter]
226
- scaling = self.scaling[active_adapter]
227
- x = self._cast_input_dtype(x, lora_A.weight.dtype)
228
-
229
- if not self.use_dora[active_adapter]:
230
  result = result + lora_B(lora_A(dropout(x))) * scaling
231
  else:
232
- if isinstance(dropout, nn.Identity) or not self.training:
233
- base_result = result
234
- else:
235
- x = dropout(x)
236
- base_result = None
237
-
238
- result = result + self.lora_magnitude_vector[active_adapter](
239
- x,
240
- lora_A=lora_A,
241
- lora_B=lora_B,
242
- scaling=scaling,
243
- base_layer=self.get_base_layer(),
244
- base_result=base_result,
245
- )
 
 
 
 
 
246
 
247
  result = result.to(torch_result_dtype)
248
 
@@ -278,12 +141,12 @@ class Linear(nn.Module, LoraLayer):
278
  self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
279
  # Actual trainable parameters
280
  self.lora_A[adapter_name] = nn.ModuleDict({
281
- "default": nn.Linear(self.in_features, r, bias=False),
282
- "second_adapter": nn.Linear(self.in_features, r, bias=False)
283
  })
284
  self.lora_B[adapter_name] = nn.ModuleDict({
285
- "default": nn.Linear(r, self.out_features, bias=lora_bias),
286
- "second_adapter": nn.Linear(r, self.out_features, bias=lora_bias)
287
  })
288
  self.lora_bias[adapter_name] = lora_bias
289
 
@@ -303,15 +166,28 @@ class Linear(nn.Module, LoraLayer):
303
  if init_lora_weights is True:
304
  # initialize A the same way as the default for nn.Linear and B to zero
305
  # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124
306
- nn.init.kaiming_uniform_(self.lora_A[adapter_name]['default'].weight, a=math.sqrt(5))
307
- nn.init.kaiming_uniform_(self.lora_A[adapter_name]['second_adapter'].weight, a=math.sqrt(5))
308
  elif init_lora_weights.lower() == "gaussian":
309
- nn.init.normal_(self.lora_A[adapter_name]['default'].weight, std=1 / self.r[adapter_name])
310
- nn.init.normal_(self.lora_A[adapter_name]['second_adapter'].weight, std=1 / self.r[adapter_name])
311
  else:
312
  raise ValueError(f"Unknown initialization {init_lora_weights=}")
313
- nn.init.zeros_(self.lora_B[adapter_name]['default'].weight)
314
- nn.init.zeros_(self.lora_B[adapter_name]['second_adapter'].weight)
315
  if self.lora_bias[adapter_name]:
316
- nn.init.zeros_(self.lora_B[adapter_name]['default'].bias)
317
- nn.init.zeros_(self.lora_B[adapter_name]['second_adapter'].bias)
 
 
2
 
3
  import math
4
  import warnings
5
+ from typing import Any, Optional, Union, List
6
 
7
  import torch
8
  import torch.nn as nn
 
 
 
 
9
 
 
 
 
 
 
 
 
 
10
  from peft.tuners.lora import LoraLayer
11
 
12
+ class MultiAdapterLinear(nn.Module, LoraLayer):
13
+ """
14
+ Custom LoRA module supporting multiple adapters for a linear layer.
15
+
16
+ This module extends the standard LoRA implementation to support multiple task-specific
17
+ adapters that can be dynamically selected during the forward pass. The task_label
18
+ parameter passed to the forward function determines which LoRA adapter(s) to use:
19
+ - If task_label is a string, all examples in the batch use the same adapter
20
+ - If task_label is a list of strings, each example can use a different adapter
21
+
22
+ This enables efficient multi-task inference where all task-specific LoRA adapters
23
+ are loaded in memory simultaneously and dynamically selected per example, eliminating
24
+ the need to switch adapter states between tasks and allowing optimal throughput
25
+ for mixed-task batches.
26
+
27
+ Derived from peft.tuners.lora.Linear.
28
+ """
29
  def __init__(
30
  self,
31
  base_layer,
32
  adapter_name: str,
33
+ task_names: List[str],
34
  r: int = 0,
35
  lora_alpha: int = 1,
36
  lora_dropout: float = 0.0,
 
44
  ) -> None:
45
  super().__init__()
46
  LoraLayer.__init__(self, base_layer, **kwargs)
 
47
 
48
+ self.fan_in_fan_out = fan_in_fan_out
49
+ self.task_names = task_names
50
  self._active_adapter = adapter_name
51
  self.update_layer(
52
  adapter_name,
 
60
  )
61
  self.is_target_conv_1d_layer = is_target_conv_1d_layer
62
 
 
63
 
64
+ def forward(self, x: torch.Tensor, task_label: Union[str, List[str]], *args: Any, **kwargs: Any) -> torch.Tensor:
 
 
65
  self._check_forward_args(x, *args, **kwargs)
 
66
 
67
  if self.disable_adapters:
68
  if self.merged:
69
  self.unmerge()
70
  result = self.base_layer(x, *args, **kwargs)
 
 
71
  elif self.merged:
72
  result = self.base_layer(x, *args, **kwargs)
73
  else:
 
78
  for active_adapter in self.active_adapters:
79
  if active_adapter not in lora_A_keys:
80
  continue
81
+
82
+ if isinstance(task_label, str):
83
+ lora_A = self.lora_A[active_adapter][task_label]
84
+ lora_B = self.lora_B[active_adapter][task_label]
85
+ dropout = self.lora_dropout[active_adapter]
86
+ scaling = self.scaling[active_adapter]
87
+ x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
88
  result = result + lora_B(lora_A(dropout(x))) * scaling
89
  else:
90
+ unique_tasks = list(set(task_label))
91
+ lora_output = torch.zeros_like(result)
92
+
93
+ for task in unique_tasks:
94
+ task_indices = [i for i, t in enumerate(task_label) if t == task]
95
+ task_x = x[task_indices]
96
+
97
+ lora_A = self.lora_A[active_adapter][task]
98
+ lora_B = self.lora_B[active_adapter][task]
99
+ dropout = self.lora_dropout[active_adapter]
100
+ scaling = self.scaling[active_adapter]
101
+
102
+ task_x = self._cast_input_dtype(task_x, lora_A.weight.dtype)
103
+ task_lora_value = lora_B(lora_A(dropout(task_x))) * scaling
104
+
105
+ for i, idx in enumerate(task_indices):
106
+ lora_output[idx] = task_lora_value[i]
107
+
108
+ result = result + lora_output
109
 
110
  result = result.to(torch_result_dtype)
111
 
 
141
  self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
142
  # Actual trainable parameters
143
  self.lora_A[adapter_name] = nn.ModuleDict({
144
+ task_name: nn.Linear(self.in_features, r, bias=False)
145
+ for task_name in self.task_names
146
  })
147
  self.lora_B[adapter_name] = nn.ModuleDict({
148
+ task_name: nn.Linear(r, self.out_features, bias=lora_bias)
149
+ for task_name in self.task_names
150
  })
151
  self.lora_bias[adapter_name] = lora_bias
152
 
 
166
  if init_lora_weights is True:
167
  # initialize A the same way as the default for nn.Linear and B to zero
168
  # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124
169
+ for task_name in self.task_names:
170
+ nn.init.kaiming_uniform_(self.lora_A[adapter_name][task_name].weight, a=math.sqrt(5))
171
  elif init_lora_weights.lower() == "gaussian":
172
+ for task_name in self.task_names:
173
+ nn.init.normal_(self.lora_A[adapter_name][task_name].weight, std=1 / self.r[adapter_name])
174
  else:
175
  raise ValueError(f"Unknown initialization {init_lora_weights=}")
176
+ for task_name in self.task_names:
177
+ nn.init.zeros_(self.lora_B[adapter_name][task_name].weight)
178
  if self.lora_bias[adapter_name]:
179
+ for task_name in self.task_names:
180
+ nn.init.zeros_(self.lora_B[adapter_name][task_name].bias)
181
+
182
+
183
+ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
184
+ """
185
+ Merge the active adapter weights into the base weights
186
+ """
187
+ raise NotImplementedError("Merge operation is not supported")
188
+
189
+ def unmerge(self) -> None:
190
+ """
191
+ This method unmerges all merged adapter layers from the base weights.
192
+ """
193
+ raise NotImplementedError("Unmerge operation is not supported")
modeling_jina_embeddings_v4.py CHANGED
@@ -20,22 +20,15 @@ from transformers import BatchFeature
20
  from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
21
  from .configuration_jina_embeddings_v4 import JinaEmbeddingsV4Config
22
  import peft
23
- from .custom_lora_module import Linear
 
24
 
25
  class PromptType(str, Enum):
26
  query = "query"
27
  passage = "passage"
28
 
29
 
30
- class TaskType(str, Enum):
31
- retrieval = "retrieval"
32
- code = "code"
33
- text_matching = "text-matching"
34
- test = "test"
35
-
36
-
37
  PREFIX_DICT = {"query": "Query", "passage": "Passage"}
38
- TRUNCATE_DIMS = [128, 256, 512, 1024]
39
  VECTOR_TYPES = ["single_vector", "multi_vector"]
40
 
41
 
@@ -153,9 +146,28 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
153
  )
154
  self.single_vector_projector_dim = config.single_vector_projector_dim
155
  self.multi_vector_projector_dim = config.multi_vector_projector_dim
 
 
156
 
157
  def get_last_hidden_states(
158
  self,
 
159
  input_ids: torch.LongTensor,
160
  attention_mask: torch.Tensor,
161
  **kwargs,
@@ -174,8 +186,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
174
 
175
  kwargs["output_hidden_states"] = True
176
  outputs = super().forward(
177
- input_ids,
178
- attention_mask,
 
179
  **kwargs,
180
  position_ids=position_ids,
181
  rope_deltas=rope_deltas,
@@ -207,6 +220,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
207
 
208
  def project_to_single_vector_embeddings(
209
  self,
 
210
  hidden_states: torch.Tensor,
211
  attention_mask: torch.Tensor,
212
  input_ids: Optional[torch.LongTensor] = None,
@@ -215,33 +229,48 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
215
  Project the hidden states to single-vector embeddings.
216
  """
217
  if self._input_has_image(input_ids[0]): # got document image
218
- img_start_positions = torch.where(input_ids == self.config.vision_start_token_id)[1]
219
- img_end_positions = torch.where(input_ids == self.config.vision_end_token_id)[1]
220
-
 
 
 
 
221
  batch_size, seq_len = input_ids.shape
222
- position_indices = torch.arange(seq_len, device=input_ids.device).expand(batch_size, -1)
223
- image_mask = (position_indices >= img_start_positions.unsqueeze(1)) & (position_indices <= img_end_positions.unsqueeze(1))
224
-
 
 
 
 
225
  masked_hidden_states = hidden_states * image_mask.unsqueeze(-1)
226
- pooled_output = masked_hidden_states.sum(dim=1) / image_mask.sum(dim=1, keepdim=True)
 
 
227
 
228
  else: # got query text
229
  pooled_output = torch.sum(
230
  hidden_states * attention_mask.unsqueeze(-1), dim=1
231
  ) / torch.sum(attention_mask, dim=1, keepdim=True)
232
 
233
- single_vec_emb = self.single_vector_projector(pooled_output)
 
 
234
  return torch.nn.functional.normalize(single_vec_emb, dim=-1)
235
 
236
  def project_to_multi_vector_embeddings(
237
  self,
 
238
  hidden_states: torch.Tensor,
239
  attention_mask: torch.Tensor,
240
  ) -> torch.Tensor:
241
  """
242
  Project the hidden states to multi-vector embeddings.
243
  """
244
- multi_vec_emb = self.multi_vector_projector(hidden_states)
 
 
245
  multi_vec_emb = torch.nn.functional.normalize(multi_vec_emb, dim=-1)
246
  return multi_vec_emb * attention_mask.unsqueeze(-1)
247
 
@@ -250,6 +279,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
250
 
251
  def forward(
252
  self,
 
253
  input_ids: torch.LongTensor,
254
  attention_mask: torch.Tensor,
255
  output_vlm_last_hidden_states: bool = False,
@@ -267,14 +297,22 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
267
  """
268
  # Forward pass through the VLM
269
  hidden_states = self.get_last_hidden_states(
270
- input_ids=input_ids, attention_mask=attention_mask, **kwargs
 
 
 
271
  ) # (batch_size, seq_length, hidden_size)
272
  # Compute the embeddings
273
  single_vec_emb = self.project_to_single_vector_embeddings(
274
- hidden_states, attention_mask, input_ids=input_ids
 
 
 
275
  )
276
  multi_vec_emb = self.project_to_multi_vector_embeddings(
277
- hidden_states, attention_mask
 
 
278
  )
279
 
280
  return JinaEmbeddingsV4ModelOutput(
@@ -288,6 +326,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
288
  def _process_batches(
289
  self,
290
  data: List[Union[str, Image.Image]],
 
291
  processor_fn: Callable,
292
  desc: str,
293
  vector_type: str = "single_vector",
@@ -307,7 +346,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
307
  with torch.no_grad():
308
  batch = {k: v.to(self.device) for k, v in batch.items()}
309
  with torch.autocast(device_type=torch.device(self.device).type):
310
- embeddings = self(**batch)
311
  if vector_type == "single_vector":
312
  embeddings = embeddings.single_vec_emb
313
  if truncate_dim is not None:
@@ -338,7 +377,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
338
  else:
339
  encode_kwargs["prefix"] = (
340
  PREFIX_DICT[prompt_name]
341
- if self.task != TaskType.text_matching
342
  else PREFIX_DICT["query"]
343
  )
344
 
@@ -351,18 +390,32 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
351
  encode_kwargs["vector_type"] = vector_type
352
 
353
  truncate_dim = truncate_dim or self.config.truncate_dim
354
- if truncate_dim is not None and truncate_dim not in TRUNCATE_DIMS:
355
  raise ValueError(
356
- f"Invalid truncate_dim: {truncate_dim}. Must be one of {TRUNCATE_DIMS}."
357
  )
358
  else:
359
  encode_kwargs["truncate_dim"] = truncate_dim
360
 
361
  return encode_kwargs
 
 
362
 
363
  def encode_texts(
364
  self,
365
  texts: List[str],
 
366
  max_length: int = 8192,
367
  batch_size: int = 8,
368
  vector_type: Optional[str] = None,
@@ -390,6 +443,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
390
  vector_type, truncate_dim, prompt_name
391
  )
392
 
 
 
393
  processor_fn = partial(
394
  self.processor.process_texts,
395
  max_length=max_length,
@@ -400,6 +455,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
400
  data=texts,
401
  processor_fn=processor_fn,
402
  desc="Encoding texts...",
 
403
  return_numpy=return_numpy,
404
  batch_size=batch_size,
405
  **encode_kwargs,
@@ -410,6 +466,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
410
  def encode_images(
411
  self,
412
  images: List[Image.Image],
 
413
  batch_size: int = 8,
414
  vector_type: Optional[str] = None,
415
  return_numpy: bool = False,
@@ -432,14 +489,17 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
432
  """
433
  if max_pixels:
434
  default_max_pixels = self.processor.image_processor.max_pixels
435
- self.processor.image_processor.max_pixels = max_pixels # change during encoding
 
 
436
 
437
  encode_kwargs = self._validate_encoding_params(vector_type, truncate_dim)
438
-
439
  embeddings = self._process_batches(
440
  data=images,
441
  processor_fn=self.processor.process_images,
442
  desc="Encoding images...",
 
443
  batch_size=batch_size,
444
  return_numpy=return_numpy,
445
  **encode_kwargs,
@@ -463,15 +523,6 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
463
  if "torch_dtype" not in kwargs:
464
  kwargs["torch_dtype"] = "auto"
465
 
466
- task_value = kwargs.pop("task", "test")
467
- try:
468
- task = TaskType(task_value)
469
- except ValueError:
470
- valid_tasks = [t.value for t in TaskType]
471
- raise ValueError(
472
- f"Invalid task: {task_value}. Must be one of {valid_tasks}."
473
- )
474
-
475
  base_model = super().from_pretrained(
476
  pretrained_model_name_or_path, *args, **kwargs
477
  )
@@ -485,46 +536,31 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
485
  )
486
  adapter_dir = os.path.join(adapter_cache_path, "adapters")
487
 
488
- base_model.adapter_dir = adapter_dir
489
- base_model.task = task
490
-
491
- lora_config = LoraConfig.from_pretrained(os.path.join(adapter_dir, task.value))
492
- lora_config._custom_modules = {torch.nn.modules.linear.Linear: Linear}
493
- # Create the PEFT model with the requested task adapter
 
494
  peft_model = PeftModel.from_pretrained(
495
- model=base_model, model_id=os.path.join(adapter_dir, task.value), config=lora_config
 
 
496
  )
497
 
498
- # Add set_task method to the PEFT model instance
499
- def set_task_method(self, task: Union[str, TaskType]):
500
- """
501
- Set the task adapter for the model.
502
-
503
- Args:
504
- task (Union[str, TaskType]): The task name. Must be one of TaskType values or
505
- one of ['retrieval', 'text-matching', 'code']
506
- """
507
- if isinstance(task, str):
508
- try:
509
- task = TaskType(task)
510
- except ValueError:
511
- valid_tasks = [t.value for t in TaskType]
512
- raise ValueError(
513
- f"Invalid task: {task}. Must be one of {valid_tasks}"
514
- )
515
- if self.model.task != task:
516
- adapter_path = os.path.join(self.adapter_dir, task.value)
517
- hotswap_adapter(self, adapter_path, adapter_name="default")
518
- self.model.task = task
519
-
520
- def get_task_method(self):
521
- """
522
- Get the task adapter for the model.
523
- """
524
- return self.model.task.value
525
-
526
- # Bind the methods to the instance
527
- peft_model.set_task = set_task_method.__get__(peft_model, type(peft_model))
528
- peft_model.get_task = get_task_method.__get__(peft_model, type(peft_model))
529
 
530
  return peft_model
 
20
  from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
21
  from .configuration_jina_embeddings_v4 import JinaEmbeddingsV4Config
22
  import peft
23
+ from .custom_lora_module import MultiAdapterLinear
24
+
25
 
26
  class PromptType(str, Enum):
27
  query = "query"
28
  passage = "passage"
29
 
30
 
 
 
 
 
 
 
 
31
  PREFIX_DICT = {"query": "Query", "passage": "Passage"}
 
32
  VECTOR_TYPES = ["single_vector", "multi_vector"]
33
 
34
 
 
146
  )
147
  self.single_vector_projector_dim = config.single_vector_projector_dim
148
  self.multi_vector_projector_dim = config.multi_vector_projector_dim
149
+ self._task = None
150
+
151
+ @property
152
+ def task(self) -> Optional[str]:
153
+ """Get the current task set for the model."""
154
+ return self._task
155
+
156
+ @task.setter
157
+ def task(self, task: str):
158
+ """
159
+ Set the task for the model.
160
+
161
+ Args:
162
+ task (str): The task name. Must be one of ['retrieval', 'text-matching', 'code']
163
+ """
164
+ if task not in self.config.task_names:
165
+ raise ValueError(f"Invalid task: {task}. Must be one of {self.config.task_names}.")
166
+ self._task = task
167
 
168
  def get_last_hidden_states(
169
  self,
170
+ task_label: Union[str, List[str]],
171
  input_ids: torch.LongTensor,
172
  attention_mask: torch.Tensor,
173
  **kwargs,
 
186
 
187
  kwargs["output_hidden_states"] = True
188
  outputs = super().forward(
189
+ task_label=task_label,
190
+ input_ids=input_ids,
191
+ attention_mask=attention_mask,
192
  **kwargs,
193
  position_ids=position_ids,
194
  rope_deltas=rope_deltas,
 
220
 
221
  def project_to_single_vector_embeddings(
222
  self,
223
+ task_label: Union[str, List[str]],
224
  hidden_states: torch.Tensor,
225
  attention_mask: torch.Tensor,
226
  input_ids: Optional[torch.LongTensor] = None,
 
229
  Project the hidden states to single-vector embeddings.
230
  """
231
  if self._input_has_image(input_ids[0]): # got document image
232
+ img_start_positions = torch.where(
233
+ input_ids == self.config.vision_start_token_id
234
+ )[1]
235
+ img_end_positions = torch.where(
236
+ input_ids == self.config.vision_end_token_id
237
+ )[1]
238
+
239
  batch_size, seq_len = input_ids.shape
240
+ position_indices = torch.arange(seq_len, device=input_ids.device).expand(
241
+ batch_size, -1
242
+ )
243
+ image_mask = (position_indices >= img_start_positions.unsqueeze(1)) & (
244
+ position_indices <= img_end_positions.unsqueeze(1)
245
+ )
246
+
247
  masked_hidden_states = hidden_states * image_mask.unsqueeze(-1)
248
+ pooled_output = masked_hidden_states.sum(dim=1) / image_mask.sum(
249
+ dim=1, keepdim=True
250
+ )
251
 
252
  else: # got query text
253
  pooled_output = torch.sum(
254
  hidden_states * attention_mask.unsqueeze(-1), dim=1
255
  ) / torch.sum(attention_mask, dim=1, keepdim=True)
256
 
257
+ single_vec_emb = self.single_vector_projector(
258
+ pooled_output, task_label=task_label
259
+ )
260
  return torch.nn.functional.normalize(single_vec_emb, dim=-1)
261
 
262
  def project_to_multi_vector_embeddings(
263
  self,
264
+ task_label: Union[str, List[str]],
265
  hidden_states: torch.Tensor,
266
  attention_mask: torch.Tensor,
267
  ) -> torch.Tensor:
268
  """
269
  Project the hidden states to multi-vector embeddings.
270
  """
271
+ multi_vec_emb = self.multi_vector_projector(
272
+ hidden_states, task_label=task_label
273
+ )
274
  multi_vec_emb = torch.nn.functional.normalize(multi_vec_emb, dim=-1)
275
  return multi_vec_emb * attention_mask.unsqueeze(-1)
276
 
 
279
 
280
  def forward(
281
  self,
282
+ task_label: Union[str, List[str]],
283
  input_ids: torch.LongTensor,
284
  attention_mask: torch.Tensor,
285
  output_vlm_last_hidden_states: bool = False,
 
297
  """
298
  # Forward pass through the VLM
299
  hidden_states = self.get_last_hidden_states(
300
+ input_ids=input_ids,
301
+ attention_mask=attention_mask,
302
+ task_label=task_label,
303
+ **kwargs,
304
  ) # (batch_size, seq_length, hidden_size)
305
  # Compute the embeddings
306
  single_vec_emb = self.project_to_single_vector_embeddings(
307
+ hidden_states=hidden_states,
308
+ attention_mask=attention_mask,
309
+ input_ids=input_ids,
310
+ task_label=task_label,
311
  )
312
  multi_vec_emb = self.project_to_multi_vector_embeddings(
313
+ hidden_states=hidden_states,
314
+ attention_mask=attention_mask,
315
+ task_label=task_label,
316
  )
317
 
318
  return JinaEmbeddingsV4ModelOutput(
 
326
  def _process_batches(
327
  self,
328
  data: List[Union[str, Image.Image]],
329
+ task_label: Union[str, List[str]],
330
  processor_fn: Callable,
331
  desc: str,
332
  vector_type: str = "single_vector",
 
346
  with torch.no_grad():
347
  batch = {k: v.to(self.device) for k, v in batch.items()}
348
  with torch.autocast(device_type=torch.device(self.device).type):
349
+ embeddings = self(**batch, task_label=task_label)
350
  if vector_type == "single_vector":
351
  embeddings = embeddings.single_vec_emb
352
  if truncate_dim is not None:
 
377
  else:
378
  encode_kwargs["prefix"] = (
379
  PREFIX_DICT[prompt_name]
380
+ if self.task != "text-matching"
381
  else PREFIX_DICT["query"]
382
  )
383
 
 
390
  encode_kwargs["vector_type"] = vector_type
391
 
392
  truncate_dim = truncate_dim or self.config.truncate_dim
393
+ if truncate_dim is not None and truncate_dim not in self.config.matryoshka_dims:
394
  raise ValueError(
395
+ f"Invalid truncate_dim: {truncate_dim}. Must be one of {self.config.matryoshka_dims}."
396
  )
397
  else:
398
  encode_kwargs["truncate_dim"] = truncate_dim
399
 
400
  return encode_kwargs
401
+
402
+ def _validate_task(self, task: Optional[str] = None) -> str:
403
+ if task is None:
404
+ if self.task is None:
405
+ raise ValueError(
406
+ "Task must be specified before encoding data. You can set it either as a model property "
407
+ "(e.g., model.task = 'retrieval') or pass it as an argument to the encode method."
408
+ )
409
+ task = self.task
410
+ else:
411
+ if task not in self.config.task_names:
412
+ raise ValueError(f"Invalid task: {task}. Must be one of {self.config.task_names}.")
413
+ return task
414
 
415
  def encode_texts(
416
  self,
417
  texts: List[str],
418
+ task: Optional[str] = None,
419
  max_length: int = 8192,
420
  batch_size: int = 8,
421
  vector_type: Optional[str] = None,
 
443
  vector_type, truncate_dim, prompt_name
444
  )
445
 
446
+ task = self._validate_task(task)
447
+
448
  processor_fn = partial(
449
  self.processor.process_texts,
450
  max_length=max_length,
 
455
  data=texts,
456
  processor_fn=processor_fn,
457
  desc="Encoding texts...",
458
+ task_label=task,
459
  return_numpy=return_numpy,
460
  batch_size=batch_size,
461
  **encode_kwargs,
 
466
  def encode_images(
467
  self,
468
  images: List[Image.Image],
469
+ task: Optional[str] = None,
470
  batch_size: int = 8,
471
  vector_type: Optional[str] = None,
472
  return_numpy: bool = False,
 
489
  """
490
  if max_pixels:
491
  default_max_pixels = self.processor.image_processor.max_pixels
492
+ self.processor.image_processor.max_pixels = (
493
+ max_pixels # change during encoding
494
+ )
495
 
496
  encode_kwargs = self._validate_encoding_params(vector_type, truncate_dim)
497
+ task = self._validate_task(task)
498
  embeddings = self._process_batches(
499
  data=images,
500
  processor_fn=self.processor.process_images,
501
  desc="Encoding images...",
502
+ task_label=task,
503
  batch_size=batch_size,
504
  return_numpy=return_numpy,
505
  **encode_kwargs,
 
523
  if "torch_dtype" not in kwargs:
524
  kwargs["torch_dtype"] = "auto"
525
 
 
 
 
 
 
 
 
 
 
526
  base_model = super().from_pretrained(
527
  pretrained_model_name_or_path, *args, **kwargs
528
  )
 
536
  )
537
  adapter_dir = os.path.join(adapter_cache_path, "adapters")
538
 
539
+ lora_config = LoraConfig.from_pretrained(os.path.join(adapter_dir, "test"))
540
+ lora_config._custom_modules = {
541
+ torch.nn.modules.linear.Linear: partial(
542
+ MultiAdapterLinear,
543
+ task_names=base_model.config.task_names,
544
+ )
545
+ }
546
  peft_model = PeftModel.from_pretrained(
547
+ model=base_model,
548
+ model_id=os.path.join(adapter_dir, "test"),
549
+ config=lora_config,
550
  )
551
 
552
+ @property
553
+ def task(self):
554
+ return self.model.task
555
+
556
+ @task.setter
557
+ def task(self, value):
558
+ self.model.task = value
559
+
560
+ peft_model.task = property(task.fget, task.fset)
561
+ peft_model.__class__.task = property(
562
+ lambda self: self.model.task,
563
+ lambda self, value: setattr(self.model, 'task', value)
564
+ )
 
 
565
 
566
  return peft_model
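Taken together, the modeling changes replace the old load-time task argument and the set_task/hotswap path with a config-validated task property plus an optional per-call task argument on the encode methods. A hedged usage sketch follows; the checkpoint id and the AutoModel/trust_remote_code routing are assumptions, while the property and method names match the code above.

from transformers import AutoModel

# Assumed checkpoint id; any repo shipping these custom modeling files behaves the same way.
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)

# Option 1: set the task once as a property (validated against config.task_names).
model.task = "retrieval"
query_embs = model.encode_texts(["how does matryoshka truncation work?"], truncate_dim=256)

# Option 2: pass the task per call; unknown task names raise a ValueError.
code_embs = model.encode_texts(["def add(a, b): return a + b"], task="code")

# One level down, forward() accepts task_label either as one string or as a list with one
# entry per example, which is what allows a single batch to mix tasks.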
qwen2_5_vl.py CHANGED
@@ -1,28 +1,6 @@
1
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
- # This file was automatically generated from src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py.
3
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
- # the file from the modular. If any change should be done, please apply the change to the
5
- # modular_qwen2_5_vl.py file directly. One of our CI enforces this.
6
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
- # coding=utf-8
8
- # Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
9
- #
10
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
11
- # and OPT implementations in this library. It has been modified from its
12
- # original forms to accommodate minor architectural differences compared
13
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
14
- #
15
- # Licensed under the Apache License, Version 2.0 (the "License");
16
- # you may not use this file except in compliance with the License.
17
- # You may obtain a copy of the License at
18
- #
19
- # http://www.apache.org/licenses/LICENSE-2.0
20
- #
21
- # Unless required by applicable law or agreed to in writing, software
22
- # distributed under the License is distributed on an "AS IS" BASIS,
23
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24
- # See the License for the specific language governing permissions and
25
- # limitations under the License.
26
  from transformers.configuration_utils import PretrainedConfig
27
  from transformers.modeling_rope_utils import rope_config_validation
28
 
@@ -256,32 +234,6 @@ class Qwen2_5_VLConfig(PretrainedConfig):
256
 
257
 
258
 
259
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
260
- # This file was automatically generated from src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py.
261
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
262
- # the file from the modular. If any change should be done, please apply the change to the
263
- # modular_qwen2_5_vl.py file directly. One of our CI enforces this.
264
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
265
- # coding=utf-8
266
- # Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
267
- #
268
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
269
- # and OPT implementations in this library. It has been modified from its
270
- # original forms to accommodate minor architectural differences compared
271
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
272
- #
273
- # Licensed under the Apache License, Version 2.0 (the "License");
274
- # you may not use this file except in compliance with the License.
275
- # You may obtain a copy of the License at
276
- #
277
- # http://www.apache.org/licenses/LICENSE-2.0
278
- #
279
- # Unless required by applicable law or agreed to in writing, software
280
- # distributed under the License is distributed on an "AS IS" BASIS,
281
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
282
- # See the License for the specific language governing permissions and
283
- # limitations under the License.
284
-
285
  import math
286
  from dataclasses import dataclass
287
  from typing import Any, Dict, List, Optional, Tuple, Union
@@ -891,8 +843,8 @@ class Qwen2MLP(nn.Module):
891
  self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
892
  self.act_fn = ACT2FN[config.hidden_act]
893
 
894
- def forward(self, x):
895
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
896
  return down_proj
897
 
898
 
@@ -1179,6 +1131,7 @@ class Qwen2_5_VLSdpaAttention(Qwen2_5_VLAttention):
1179
  # Adapted from Qwen2Attention.forward
1180
  def forward(
1181
  self,
 
1182
  hidden_states: torch.Tensor,
1183
  attention_mask: Optional[torch.Tensor] = None,
1184
  position_ids: Optional[torch.LongTensor] = None,
@@ -1207,9 +1160,9 @@ class Qwen2_5_VLSdpaAttention(Qwen2_5_VLAttention):
1207
 
1208
  bsz, q_len, _ = hidden_states.size()
1209
 
1210
- query_states = self.q_proj(hidden_states)
1211
- key_states = self.k_proj(hidden_states)
1212
- value_states = self.v_proj(hidden_states)
1213
 
1214
  query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
1215
  key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
@@ -1255,7 +1208,7 @@ class Qwen2_5_VLSdpaAttention(Qwen2_5_VLAttention):
1255
  attn_output = attn_output.transpose(1, 2).contiguous()
1256
  attn_output = attn_output.view(bsz, q_len, self.hidden_size)
1257
 
1258
- attn_output = self.o_proj(attn_output)
1259
 
1260
  return attn_output, None, past_key_value
1261
 
@@ -1285,6 +1238,7 @@ class Qwen2_5_VLDecoderLayer(nn.Module):
1285
 
1286
  def forward(
1287
  self,
 
1288
  hidden_states: torch.Tensor,
1289
  attention_mask: Optional[torch.Tensor] = None,
1290
  position_ids: Optional[torch.LongTensor] = None,
@@ -1323,6 +1277,7 @@ class Qwen2_5_VLDecoderLayer(nn.Module):
1323
 
1324
  # Self Attention
1325
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
 
1326
  hidden_states=hidden_states,
1327
  attention_mask=attention_mask,
1328
  position_ids=position_ids,
@@ -1337,7 +1292,7 @@ class Qwen2_5_VLDecoderLayer(nn.Module):
1337
  # Fully Connected
1338
  residual = hidden_states
1339
  hidden_states = self.post_attention_layernorm(hidden_states)
1340
- hidden_states = self.mlp(hidden_states)
1341
  hidden_states = residual + hidden_states
1342
 
1343
  outputs = (hidden_states,)
@@ -1381,6 +1336,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
1381
 
1382
  def forward(
1383
  self,
 
1384
  input_ids: torch.LongTensor = None,
1385
  attention_mask: Optional[torch.Tensor] = None,
1386
  position_ids: Optional[torch.LongTensor] = None,
@@ -1461,7 +1417,8 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
1461
  )
1462
  else:
1463
  layer_outputs = decoder_layer(
1464
- hidden_states,
 
1465
  attention_mask=causal_mask,
1466
  position_ids=position_ids,
1467
  past_key_value=past_key_values,
@@ -1979,6 +1936,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
1979
  @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1980
  def forward(
1981
  self,
 
1982
  input_ids: torch.LongTensor = None,
1983
  attention_mask: Optional[torch.Tensor] = None,
1984
  position_ids: Optional[torch.LongTensor] = None,
@@ -2115,6 +2073,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
2115
  position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
2116
 
2117
  outputs = self.model(
 
2118
  input_ids=None,
2119
  position_ids=position_ids,
2120
  attention_mask=attention_mask,
@@ -2324,32 +2283,6 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
2324
  return input_ids, model_kwargs
2325
 
2326
 
2327
-
2328
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2329
- # This file was automatically generated from src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py.
2330
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
2331
- # the file from the modular. If any change should be done, please apply the change to the
2332
- # modular_qwen2_5_vl.py file directly. One of our CI enforces this.
2333
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2334
- # coding=utf-8
2335
- # Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
2336
- #
2337
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
2338
- # and OPT implementations in this library. It has been modified from its
2339
- # original forms to accommodate minor architectural differences compared
2340
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
2341
- #
2342
- # Licensed under the Apache License, Version 2.0 (the "License");
2343
- # you may not use this file except in compliance with the License.
2344
- # You may obtain a copy of the License at
2345
- #
2346
- # http://www.apache.org/licenses/LICENSE-2.0
2347
- #
2348
- # Unless required by applicable law or agreed to in writing, software
2349
- # distributed under the License is distributed on an "AS IS" BASIS,
2350
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2351
- # See the License for the specific language governing permissions and
2352
- # limitations under the License.
2353
  from typing import List, Union
2354
 
2355
  from transformers.feature_extraction_utils import BatchFeature
 
1
+ # This file is a modified version of the Qwen2_5_VL model from the transformers library
2
+ # that implements task-specific LoRA layers for multi-task embeddings.
3
+
 
 
 
4
  from transformers.configuration_utils import PretrainedConfig
5
  from transformers.modeling_rope_utils import rope_config_validation
6
 
 
234
 
235
 
236
 
 
 
 
237
  import math
238
  from dataclasses import dataclass
239
  from typing import Any, Dict, List, Optional, Tuple, Union
 
843
  self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
844
  self.act_fn = ACT2FN[config.hidden_act]
845
 
846
+ def forward(self, x, task_label: Union[str, List[str]]):
847
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x, task_label=task_label)) * self.up_proj(x, task_label=task_label), task_label=task_label)
848
  return down_proj
849
 
850
 
 
1131
  # Adapted from Qwen2Attention.forward
1132
  def forward(
1133
  self,
1134
+ task_label: Union[str, List[str]],
1135
  hidden_states: torch.Tensor,
1136
  attention_mask: Optional[torch.Tensor] = None,
1137
  position_ids: Optional[torch.LongTensor] = None,
 
1160
 
1161
  bsz, q_len, _ = hidden_states.size()
1162
 
1163
+ query_states = self.q_proj(hidden_states, task_label=task_label)
1164
+ key_states = self.k_proj(hidden_states, task_label=task_label)
1165
+ value_states = self.v_proj(hidden_states, task_label=task_label)
1166
 
1167
  query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
1168
  key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
1208
  attn_output = attn_output.transpose(1, 2).contiguous()
1209
  attn_output = attn_output.view(bsz, q_len, self.hidden_size)
1210
 
1211
+ attn_output = self.o_proj(attn_output, task_label=task_label)
1212
 
1213
  return attn_output, None, past_key_value
1214
 
 
1238
 
1239
  def forward(
1240
  self,
1241
+ task_label: Union[str, List[str]],
1242
  hidden_states: torch.Tensor,
1243
  attention_mask: Optional[torch.Tensor] = None,
1244
  position_ids: Optional[torch.LongTensor] = None,
 
1277
 
1278
  # Self Attention
1279
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
1280
+ task_label=task_label,
1281
  hidden_states=hidden_states,
1282
  attention_mask=attention_mask,
1283
  position_ids=position_ids,
 
1292
  # Fully Connected
1293
  residual = hidden_states
1294
  hidden_states = self.post_attention_layernorm(hidden_states)
1295
+ hidden_states = self.mlp(hidden_states, task_label=task_label)
1296
  hidden_states = residual + hidden_states
1297
 
1298
  outputs = (hidden_states,)
 
1336
 
1337
  def forward(
1338
  self,
1339
+ task_label: Union[str, List[str]],
1340
  input_ids: torch.LongTensor = None,
1341
  attention_mask: Optional[torch.Tensor] = None,
1342
  position_ids: Optional[torch.LongTensor] = None,
 
1417
  )
1418
  else:
1419
  layer_outputs = decoder_layer(
1420
+ task_label=task_label,
1421
+ hidden_states=hidden_states,
1422
  attention_mask=causal_mask,
1423
  position_ids=position_ids,
1424
  past_key_value=past_key_values,
 
1936
  @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1937
  def forward(
1938
  self,
1939
+ task_label: Union[str, List[str]],
1940
  input_ids: torch.LongTensor = None,
1941
  attention_mask: Optional[torch.Tensor] = None,
1942
  position_ids: Optional[torch.LongTensor] = None,
 
2073
  position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
2074
 
2075
  outputs = self.model(
2076
+ task_label=task_label,
2077
  input_ids=None,
2078
  position_ids=position_ids,
2079
  attention_mask=attention_mask,
 
2283
  return input_ids, model_kwargs
2284
 
2285
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2286
  from typing import List, Union
2287
 
2288
  from transformers.feature_extraction_utils import BatchFeature
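The pattern in this file is mechanical: every forward signature between the top-level model and the projection layers gains a task_label parameter, which is passed unchanged down to the q/k/v/o and MLP projections that MultiAdapterLinear wraps, and nothing else about the computation changes. A minimal sketch of the same threading with stand-in classes (not the real Qwen2.5-VL modules):

from typing import List, Union
import torch
import torch.nn as nn

class TaskAwareLinear(nn.Linear):
    # Stand-in for MultiAdapterLinear: accepts (and here simply ignores) task_label.
    def forward(self, x, task_label: Union[str, List[str]]):
        return super().forward(x)

class TaskAwareMLP(nn.Module):
    # Same shape as the patched Qwen2MLP.forward: task_label goes to every projection.
    def __init__(self, hidden=32, inter=64):
        super().__init__()
        self.gate_proj = TaskAwareLinear(hidden, inter, bias=False)
        self.up_proj = TaskAwareLinear(hidden, inter, bias=False)
        self.down_proj = TaskAwareLinear(inter, hidden, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x, task_label: Union[str, List[str]]):
        return self.down_proj(
            self.act_fn(self.gate_proj(x, task_label=task_label))
            * self.up_proj(x, task_label=task_label),
            task_label=task_label,
        )

mlp = TaskAwareMLP()
y = mlp(torch.randn(2, 5, 32), task_label="retrieval")  # threads through unchanged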