|
|
|
import os
import logging
|
import math |
|
import numpy as np |
|
|
|
from abc import ABC, abstractmethod |
|
from dataclasses import dataclass |
|
from typing import Any, Callable, ClassVar, Dict, List, Optional, Union, cast |
|
from typing_extensions import Unpack |
|
|
|
import torch |
|
from torch import nn |
|
from torch.utils.data import DataLoader |
|
|
|
from functools import partial |
|
from PIL import Image |
|
from tqdm import tqdm |
|
from enum import Enum |
|
|
|
from transformers import BatchEncoding, BatchFeature |
|
from transformers.modeling_utils import PreTrainedModel |
|
|
|
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast |
|
|
|
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration |
|
|
|
from transformers.processing_utils import ( |
|
AllKwargsForChatTemplate, |
|
ImageInput, |
|
PreTokenizedInput, |
|
TextInput, |
|
VideoInput, |
|
) |
|
|
|
from huggingface_hub import snapshot_download |
|
|
|
from .configuration_colqwen_duo import ColQwen25DuoConfig

logger = logging.getLogger(__name__)
|
|
|
|
|
def get_torch_device(device: str = "auto") -> str: |
|
""" |
|
Returns the device (string) to be used by PyTorch. |
|
|
|
`device` arg defaults to "auto" which will use: |
|
- "cuda:0" if available |
|
- else "mps" if available |
|
- else "cpu". |
|
""" |
|
|
|
if device == "auto": |
|
if torch.cuda.is_available(): |
|
device = "cuda:0" |
|
elif torch.backends.mps.is_available(): |
|
device = "mps" |
|
else: |
|
device = "cpu" |
|
logger.info(f"Using device: {device}") |
|
|
|
return device |
|
|
|
|
|
class PromptType(str, Enum): |
|
query = "query" |
|
passage = "passage" |
|
|
|
|
|
|
|
|
|
class BaseVisualRetrieverProcessor(ABC): |
|
""" |
|
Base class for visual retriever processors. |
|
""" |
|
|
|
@abstractmethod |
|
def process_images( |
|
self, |
|
images: List[Image.Image], |
|
) -> Union[BatchFeature, BatchEncoding]: |
|
pass |
|
|
|
@abstractmethod |
|
def process_texts( |
|
self, |
|
texts: List[str], |
|
max_length: int = 50, |
|
suffix: Optional[str] = None, |
|
prefix: Optional[str] = None, |
|
) -> Union[BatchFeature, BatchEncoding]: |
|
pass |
|
|
|
@abstractmethod |
|
def score( |
|
self, |
|
qs: List[torch.Tensor], |
|
ps: List[torch.Tensor], |
|
device: Optional[Union[str, torch.device]] = None, |
|
**kwargs, |
|
) -> torch.Tensor: |
|
pass |
|
|
|
@staticmethod |
|
def score_single_vector( |
|
qs: List[torch.Tensor], |
|
ps: List[torch.Tensor], |
|
device: Optional[Union[str, torch.device]] = None, |
|
) -> torch.Tensor: |
|
""" |
|
Compute the dot product score for the given single-vector query and passage embeddings. |
|
""" |
|
device = device or get_torch_device("auto") |
|
|
|
if len(qs) == 0: |
|
raise ValueError("No queries provided") |
|
if len(ps) == 0: |
|
raise ValueError("No passages provided") |
|
|
|
qs_stacked = torch.stack(qs).to(device) |
|
ps_stacked = torch.stack(ps).to(device) |
|
|
|
scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked) |
|
assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}" |
|
|
|
scores = scores.to(torch.float32) |
|
return scores |
|
|
|
@staticmethod |
|
def score_multi_vector( |
|
qs: List[torch.Tensor], |
|
ps: List[torch.Tensor], |
|
batch_size: int = 128, |
|
device: Optional[Union[str, torch.device]] = None, |
|
) -> torch.Tensor: |
|
""" |
|
Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings. |
|
""" |
|
device = device or get_torch_device("auto") |
|
|
|
if len(qs) == 0: |
|
raise ValueError("No queries provided") |
|
if len(ps) == 0: |
|
raise ValueError("No passages provided") |
|
|
|
scores_list: List[torch.Tensor] = [] |
|
|
|
for i in range(0, len(qs), batch_size): |
|
scores_batch = [] |
|
qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to( |
|
device |
|
) |
|
for j in range(0, len(ps), batch_size): |
|
ps_batch = torch.nn.utils.rnn.pad_sequence( |
|
ps[j : j + batch_size], batch_first=True, padding_value=0 |
|
).to(device) |
|
scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2)) |
|
scores_batch = torch.cat(scores_batch, dim=1).cpu() |
|
scores_list.append(scores_batch) |
|
|
|
scores = torch.cat(scores_list, dim=0) |
|
assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}" |
|
|
|
scores = scores.to(torch.float32) |
|
return scores |
|
|
|
|
|
class QwenVLProcessor(ABC): |
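    # Thin mixin over the concrete Hugging Face processor (resolved through the MRO,
    # e.g. `Qwen2_5_VLProcessor`); it only pins down the signatures used in this module.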
|
|
|
def __call__( |
|
self, |
|
images: Optional[ImageInput] = None, |
|
text: Optional[Union[TextInput, PreTokenizedInput, List[PreTokenizedInput]]] = None, |
|
videos: Optional[VideoInput] = None, |
|
**kwargs, |
|
) -> BatchFeature: |
|
return super().__call__(images=images, text=text, videos=videos, **kwargs) |
|
|
|
def apply_chat_template( |
|
self, |
|
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], |
|
chat_template: Optional[str] = None, |
|
**kwargs: Unpack[AllKwargsForChatTemplate], |
|
) -> str: |
|
return super().apply_chat_template(conversation=conversation, chat_template=chat_template, **kwargs) |
|
|
|
|
|
class QwenVLEmbeddingProcessorBase(BaseVisualRetrieverProcessor, QwenVLProcessor): |
|
|
|
    # Number of leading characters stripped from the rendered chat template in
    # `process_images` when documents contain multiple images.
    assistant_prefix_len: int = 58
|
|
|
|
|
@staticmethod |
|
def round_by_factor(number: float, factor: int) -> int: |
|
"""Returns the closest integer to 'number' that is divisible by 'factor'.""" |
|
return round(number / factor) * factor |
|
|
|
@staticmethod |
|
def ceil_by_factor(number: float, factor: int) -> int: |
|
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" |
|
return math.ceil(number / factor) * factor |
|
|
|
@staticmethod |
|
def floor_by_factor(number: float, factor: int) -> int: |
|
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" |
|
return math.floor(number / factor) * factor |
|
|
|
def process_images( |
|
self, |
|
images: Union[List[Image.Image], List[List[Image.Image]]], |
|
) -> BatchFeature: |
|
|
|
if isinstance(images[0], list): |
|
images = cast(List[List[Image.Image]], images) |
|
text_doc = [] |
|
for i in range(len(images)): |
|
conversation = [{"role": "user", "content": [{"type": "image"}] * len(images[i])}] |
|
template = self.apply_chat_template(conversation, add_generation_prompt=False) |
|
text_doc.append(template[self.assistant_prefix_len :]) |
|
|
|
else: |
|
images = cast(List[Image.Image], images) |
|
text_doc = [ |
|
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n" |
|
] * len(images) |
|
|
|
|
|
batch_doc = self(text=text_doc, images=images, padding="longest", return_tensors="pt") |
|
|
|
        # Number of visual patches per image (grid height x grid width), used to split
        # the flat `pixel_values` tensor back into per-image chunks.
        offsets = batch_doc["image_grid_thw"][:, 1] * batch_doc["image_grid_thw"][:, 2]
|
|
|
pixel_values = torch.split(batch_doc["pixel_values"], offsets.tolist()) |
|
|
|
        # Pad every image's patch sequence to the longest in the batch so that
        # `pixel_values` can be stacked into a single (batch, max_patches, patch_dim) tensor.
        max_length = max([len(pv) for pv in pixel_values])
|
|
|
pixel_values = [ |
|
torch.cat([pv, torch.zeros((max_length - len(pv), pv.shape[1]), dtype=pv.dtype, device=pv.device)]) |
|
for pv in pixel_values |
|
] |
|
|
|
batch_doc["pixel_values"] = torch.stack(pixel_values) |
|
return batch_doc |
|
|
|
def process_texts( |
|
self, |
|
texts: List[str], |
|
max_length: int = 8192, |
|
suffix: Optional[str] = None, |
|
prefix: Optional[str] = None, |
|
padding: Optional[str] = None, |
|
) -> BatchFeature: |
|
|
|
        if suffix is None:
            # Default suffix: 10 `<pad>` tokens appended as query-augmentation tokens
            # (mirroring the ColQwen2 processor).
            suffix = "<pad>" * 10
|
|
|
padded_texts: List[str] = [] |
|
|
|
for text in texts: |
|
if prefix: |
|
text = f"{prefix}: {text}" |
|
text += suffix |
|
padded_texts.append(text) |
|
|
|
text_batch = self( |
|
text=padded_texts, |
|
return_tensors="pt", |
|
padding=padding or "longest", |
|
max_length=max_length, |
|
truncation=True, |
|
) |
|
|
|
return text_batch |
|
|
|
|
|
class ColQwenDuoProcessorBase(QwenVLEmbeddingProcessorBase): |
|
""" |
|
Processor for ColQwenDuo. Mirrors the `ColQwen2Processor` class. |
|
""" |
|
|
|
def score( |
|
self, |
|
qs: List[torch.Tensor], |
|
ps: List[torch.Tensor], |
|
vector_type: str, |
|
device: Optional[Union[str, torch.device]] = None, |
|
truncate: Optional[int] = None, |
|
**kwargs, |
|
) -> torch.Tensor: |
|
""" |
|
Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings. |
|
""" |
|
if truncate: |
|
qs = [q[..., :truncate] for q in qs] |
|
ps = [p[..., :truncate] for p in ps] |
|
|
|
if vector_type == "single_vector": |
|
return self.score_single_vector(qs, ps, device=device) |
|
elif vector_type == "multi_vector": |
|
return self.score_multi_vector(qs, ps, device=device, **kwargs) |
|
else: |
|
raise ValueError('vector_type must be one of the following: [`single_vector`, `multi_vector`]') |
|
|
|
|
|
class ColQwen25DuoProcessor(ColQwenDuoProcessorBase, Qwen2_5_VLProcessor): |
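    """
    Concrete processor combining the ColQwenDuo processing and scoring logic with the
    Hugging Face `Qwen2_5_VLProcessor`.
    """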
|
def __init__(self, *args, **kwargs) -> None: |
|
Qwen2_5_VLProcessor.__init__(self, *args, **kwargs) |
|
|
|
|
|
@dataclass |
|
class HybridModelOutput: |
|
""" |
|
Base class for the Hybrid Model outputs. |
|
Args: |
|
vlm_last_hidden_states (torch.Tensor, optional): Last hidden states of the VLM. |
|
single_vec_emb (torch.Tensor, optional): Single-vector embeddings. |
|
multi_vec_emb (torch.Tensor, optional): Multi-vector embeddings. |
|
""" |
|
|
|
vlm_last_hidden_states: Optional[torch.Tensor] = None |
|
single_vec_emb: Optional[torch.Tensor] = None |
|
multi_vec_emb: Optional[torch.Tensor] = None |
|
|
|
class EncodeMixin: |
|
""" |
|
Interface to encode data for MTEB and ViDoRe evaluations. |
|
""" |
|
|
|
def _process_batches( |
|
self, |
|
data: List[Union[str, Image.Image]], |
|
processor_fn: Callable, |
|
desc: str, |
|
vector_type: Optional[str] = None, |
|
return_numpy: bool = False, |
|
**kwargs, |
|
) -> Union[np.ndarray, List[torch.Tensor]]: |
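        # The processor function is used as the DataLoader collate_fn, so each batch
        # arrives already tokenized / feature-extracted and ready for the model.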
|
dataloader = DataLoader( |
|
dataset=data, |
|
batch_size=kwargs.get("batch_size", 32), |
|
shuffle=False, |
|
collate_fn=processor_fn, |
|
) |
|
results = [] |
|
self.eval() |
|
for batch in tqdm(dataloader, desc=desc): |
|
with torch.no_grad(): |
|
batch = {k: v.to(self.device) for k, v in batch.items()} |
|
with torch.autocast(device_type=torch.device(self.device).type): |
|
embeddings = self(**batch) |
|
if isinstance(embeddings, HybridModelOutput) and (vector_type == "single_vector"): |
|
embeddings = embeddings.single_vec_emb |
|
elif isinstance(embeddings, HybridModelOutput) and (vector_type == "multi_vector"): |
|
embeddings = embeddings.multi_vec_emb |
|
elif not vector_type and isinstance(embeddings, HybridModelOutput): |
|
embeddings = embeddings.single_vec_emb |
|
results.append(embeddings.cpu() if return_numpy else list(torch.unbind(embeddings))) |
|
if return_numpy: |
|
return np.concatenate([result.numpy() for result in results], axis=0) |
|
return [item for sublist in results for item in sublist] |
|
|
|
def encode( |
|
self, |
|
sentences: List[str], |
|
max_length: int = 8192, |
|
batch_size: int = 8, |
|
prefixes: Optional[List[str]] = None, |
|
desc: Optional[str] = None, |
|
vector_type: Optional[str] = None, |
|
padding: Optional[str] = None, |
|
prompt_type: Optional[PromptType] = None, |
|
**kwargs, |
|
    ) -> Union[np.ndarray, List[torch.Tensor]]:
|
prefix = None |
|
if isinstance(prefixes, list) and len(prefixes) > 0: |
|
if prompt_type: |
|
desc = f"MTEB: Encode {prompt_type.value}..." |
|
prefix = prefixes[0] if prompt_type.value == "query" else prefixes[1] |
|
else: |
|
prefix = prefixes[0] |
|
processor_fn = partial(self.processor.process_texts, max_length=max_length, prefix=prefix, padding=padding) |
|
desc = desc or "MTEB: Encode texts..." |
|
return self._process_batches( |
|
data=sentences, |
|
processor_fn=processor_fn, |
|
desc=desc, |
|
vector_type=vector_type, |
|
batch_size=batch_size, |
|
**kwargs, |
|
) |
|
|
|
def encode_texts( |
|
self, |
|
queries: List[str], |
|
max_length: int = 8192, |
|
batch_size: int = 8, |
|
vector_type: Optional[str] = None, |
|
desc: Optional[str] = None, |
|
**kwargs, |
|
) -> List[torch.Tensor]: |
|
processor_fn = partial(self.processor.process_texts, max_length=max_length, prefix="Query") |
|
return self._process_batches( |
|
data=queries, |
|
processor_fn=processor_fn, |
|
desc=desc or "Encode queries...", |
|
vector_type=vector_type, |
|
batch_size=batch_size, |
|
**kwargs, |
|
) |
|
|
|
def encode_images( |
|
self, |
|
documents: List[Image.Image], |
|
batch_size: int = 8, |
|
vector_type: Optional[str] = None, |
|
desc: Optional[str] = None, |
|
**kwargs, |
|
) -> List[torch.Tensor]: |
|
return self._process_batches( |
|
data=documents, |
|
processor_fn=self.processor.process_images, |
|
desc=desc or "Encode documents...", |
|
vector_type=vector_type, |
|
batch_size=batch_size, |
|
**kwargs, |
|
) |
|
|
|
class QwenVLModel(ABC): |
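    # Thin mixin over the concrete Qwen-VL model (resolved through the MRO, e.g.
    # `Qwen2_5_VLForConditionalGeneration`); it only pins down the signatures used below.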
|
|
|
def get_rope_index( |
|
self, |
|
input_ids: torch.LongTensor, |
|
image_grid_thw: Union[torch.LongTensor, None], |
|
attention_mask: torch.Tensor, |
|
) -> tuple[torch.LongTensor, torch.Tensor]: |
|
return super().get_rope_index( |
|
input_ids=input_ids, |
|
image_grid_thw=image_grid_thw, |
|
attention_mask=attention_mask, |
|
) |
|
|
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor, |
|
attention_mask: torch.Tensor, |
|
position_ids: torch.LongTensor, |
|
rope_deltas: torch.Tensor, |
|
output_hidden_states: bool, |
|
use_cache: bool, |
|
**kwargs, |
|
) -> Qwen2VLCausalLMOutputWithPast: |
|
return super().forward( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
rope_deltas=rope_deltas, |
|
output_hidden_states=output_hidden_states, |
|
use_cache=use_cache, |
|
**kwargs, |
|
) |
|
|
|
|
|
class QwenVLEmbeddingBase(EncodeMixin, QwenVLModel): |
|
main_input_name: ClassVar[str] = "doc_input_ids" |
|
|
|
def get_vlm_last_hidden_states( |
|
self, |
|
input_ids: torch.LongTensor, |
|
attention_mask: torch.Tensor, |
|
**kwargs, |
|
) -> torch.Tensor: |
|
if "pixel_values" in kwargs: |
|
offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2] |
|
kwargs["pixel_values"] = torch.cat([pv[:o] for pv, o in zip(kwargs["pixel_values"], offsets)], dim=0) |
|
|
|
position_ids, rope_deltas = self.get_rope_index( |
|
input_ids=input_ids, |
|
image_grid_thw=kwargs.get("image_grid_thw", None), |
|
attention_mask=attention_mask, |
|
) |
|
|
|
outputs = super().forward( |
|
input_ids, |
|
attention_mask, |
|
**kwargs, |
|
position_ids=position_ids, |
|
rope_deltas=rope_deltas, |
|
output_hidden_states=True, |
|
use_cache=False, |
|
) |
|
|
|
hidden_states = outputs.hidden_states |
|
if not hidden_states: |
|
raise ValueError("Hidden states not found in model output") |
|
|
|
return hidden_states[-1] |
|
|
|
|
|
class AbstractHybridModel(ABC): |
|
""" |
|
Abstract class for a hybrid model (single-vector and multi-vector embeddings). |
|
""" |
|
|
|
@property |
|
def single_vector_projector_dim(self) -> int: |
|
return self.config.single_vector_projector_dim |
|
|
|
@property |
|
def multi_vector_projector_dim(self) -> int: |
|
return self.config.multi_vector_projector_dim |
|
|
|
@abstractmethod |
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor, |
|
attention_mask: torch.Tensor, |
|
output_vlm_last_hidden_states: bool = False, |
|
*args, |
|
**kwargs, |
|
) -> HybridModelOutput: |
|
""" |
|
Forward pass through the model. Returns both single-vector and multi-vector embeddings. |
|
Must be implemented by subclasses. |
|
""" |
|
pass |
|
|
|
def _init_projection_layers(self, config) -> None: |
|
""" |
|
Initializes projection layers. |
|
""" |
|
self.config.single_vector_projector_dim = config.single_vector_projector_dim |
|
self.config.multi_vector_projector_dim = config.multi_vector_projector_dim |
|
|
|
self.single_vector_projector = nn.Linear( |
|
in_features=self.config.hidden_size, |
|
out_features=self.config.single_vector_projector_dim, |
|
) |
|
|
|
self.multi_vector_projector = nn.Linear( |
|
in_features=self.config.hidden_size, |
|
out_features=self.config.multi_vector_projector_dim, |
|
) |
|
|
|
@staticmethod |
|
def _delete_redundant_forward_kwargs(kwargs: Dict[str, Any]) -> None: |
|
""" |
|
Delete redundant kwargs before passing them to the forward method. In-place operation. |
|
""" |
|
for key in ["input_ids", "attention_mask", "output_hidden_states"]: |
|
kwargs.pop(key, None) |
|
|
|
def project_to_single_vector_embeddings( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: torch.Tensor, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
) -> torch.Tensor: |
|
""" |
|
Project the hidden states to single-vector embeddings. |
|
""" |
|
|
|
pooling_method = self.config.single_vector_pool_strategy |
|
|
|
        if pooling_method == "mean" and input_ids is None:
            logger.warning("`input_ids` is None. Falling back to the `legacy-mean` pooling strategy.")
            pooling_method = "legacy-mean"
|
|
|
if pooling_method == "last-token": |
|
pooled_output = hidden_states[:, -1, :] |
|
elif pooling_method == "mean": |
|
            if self._input_has_image(input_ids[0]):
                # Mean-pool over the image span of each sequence: average the hidden states
                # between the vision start and end tokens (assumes one image per sequence).
                input_seq_idx, img_start_pos = torch.where(input_ids == self.config.vision_start_token_id)
                _, img_end_pos = torch.where(input_ids == self.config.vision_end_token_id)
                means = []
                for i in range(input_seq_idx.shape[0]):
                    vector_pos = input_seq_idx[i]
                    start = img_start_pos[i]
                    end = img_end_pos[i]
                    mean_value = hidden_states[vector_pos][start : end + 1].mean(dim=0)
                    means.append(mean_value)
                pooled_output = torch.stack(means)
|
|
|
else: |
|
pooled_output = torch.sum(hidden_states * attention_mask.unsqueeze(-1), dim=1) / torch.sum( |
|
attention_mask, dim=1, keepdim=True |
|
) |
|
|
|
elif pooling_method == "legacy-mean": |
|
pooled_output = torch.sum(hidden_states * attention_mask.unsqueeze(-1), dim=1) / torch.sum( |
|
attention_mask, dim=1, keepdim=True |
|
) |
|
else: |
|
raise ValueError(f"Invalid pooling strategy: {pooling_method}") |
|
single_vec_emb = self.single_vector_projector(pooled_output) |
|
return torch.nn.functional.normalize(single_vec_emb, dim=-1) |
|
|
|
def project_to_multi_vector_embeddings( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: torch.Tensor, |
|
) -> torch.Tensor: |
|
""" |
|
Project the hidden states to multi-vector embeddings. |
|
""" |
|
multi_vec_emb = self.multi_vector_projector(hidden_states) |
|
multi_vec_emb = torch.nn.functional.normalize(multi_vec_emb, dim=-1) |
|
        # Zero out the embeddings of padding tokens.
        return multi_vec_emb * attention_mask.unsqueeze(-1)
|
|
|
def _input_has_image(self, input_ids): |
|
return self.config.vision_start_token_id in input_ids |
|
|
|
class ColQwenDuoBase(AbstractHybridModel, QwenVLEmbeddingBase): |
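    """
    Hybrid retriever base: couples the projection heads from `AbstractHybridModel`
    with the Qwen-VL embedding backbone from `QwenVLEmbeddingBase`.
    """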
|
|
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor, |
|
attention_mask: torch.Tensor, |
|
output_vlm_last_hidden_states: bool = False, |
|
**kwargs, |
|
) -> HybridModelOutput: |
|
""" |
|
Forward pass through ColQwenDuo. Returns both single-vector and multi-vector embeddings. |
|
Args: |
|
input_ids (torch.LongTensor): The input tokens tensor. |
|
attention_mask (torch.LongTensor): The attention mask tensor. |
|
Returns: |
|
HybridModelOutput: |
|
single_vector (torch.Tensor): Single-vector embeddings of shape (batch_size, dim). |
|
multi_vector (torch.Tensor): Multi-vector embeddings of shape (batch_size, num_tokens, dim). |
|
""" |
|
|
|
self._delete_redundant_forward_kwargs(kwargs) |
|
|
|
|
|
hidden_states = self.get_vlm_last_hidden_states( |
|
input_ids=input_ids, attention_mask=attention_mask, **kwargs |
|
) |
|
|
|
|
|
single_vec_emb = self.project_to_single_vector_embeddings(hidden_states, attention_mask, input_ids=input_ids) |
|
multi_vec_emb = self.project_to_multi_vector_embeddings(hidden_states, attention_mask) |
|
|
|
return HybridModelOutput( |
|
vlm_last_hidden_states=hidden_states if output_vlm_last_hidden_states else None, |
|
single_vec_emb=single_vec_emb, |
|
multi_vec_emb=multi_vec_emb, |
|
) |
|
|
|
|
|
class ColQwen25Duo(ColQwenDuoBase, Qwen2_5_VLForConditionalGeneration): |
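    """
    Qwen2.5-VL based hybrid retriever producing both single-vector and multi-vector
    embeddings (ColQwenDuo).

    Minimal usage sketch (the repo id below is a placeholder, not a confirmed checkpoint name):

        model = ColQwen25Duo.from_pretrained("org/colqwen25-duo", trust_remote_code=True)
        queries = model.encode_texts(["what is shown in the chart?"], vector_type="multi_vector")
        docs = model.encode_images([Image.open("page.png")], vector_type="multi_vector")
        scores = model.processor.score(queries, docs, vector_type="multi_vector")
    """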
|
config_class = ColQwen25DuoConfig |
|
def __init__(self, config: ColQwen25DuoConfig): |
|
Qwen2_5_VLForConditionalGeneration.__init__(self, config) |
|
self._init_projection_layers(config) |
|
self.post_init() |
|
self.processor = ColQwen25DuoProcessor.from_pretrained(self.name_or_path, trust_remote_code=True) |
|
|
|
@classmethod |
|
def from_pretrained( |
|
cls, |
|
*args, |
|
**kwargs, |
|
): |
|
if not "torch_dtype" in kwargs: |
|
kwargs["torch_dtype"] = "auto" |
|
model = super().from_pretrained(*args, **kwargs) |
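        # If the checkpoint ships a PEFT adapter, load it from the local checkpoint
        # directory or download only the adapter folder from the Hub.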
|
if model.config.pretrained_peft_model_name_or_path: |
|
if os.path.isdir(model.name_or_path): |
|
model.load_adapter(f'{model.name_or_path}/{model.config.pretrained_peft_model_name_or_path}') |
|
else: |
|
adapter_cache_path = snapshot_download( |
|
repo_id=model.name_or_path, |
|
allow_patterns=[os.path.join(model.config.pretrained_peft_model_name_or_path, '*')] |
|
) |
|
adapter_path = os.path.join(adapter_cache_path, model.config.pretrained_peft_model_name_or_path) |
|
model.load_adapter(adapter_path) |
|
return model |
|
|
|
|