import os

import torch
import torch.nn.functional as F
from PIL import Image
from safetensors.torch import load_file
from tqdm import tqdm
from transformers import AutoConfig, AutoProcessor
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VisionTransformerPretrainedModel,
)
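
# Standalone embedding extractor built around the Qwen2.5-VL vision tower only:
# the ViT weights are loaded from a fine-tuned checkpoint, each image's merged
# visual tokens are mean-pooled, and the result is L2-normalized into a single
# embedding vector per image.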

class Qwen2_5_VL_ImageEncoder:
    def __init__(self, model_path: str, device: str = "cuda", dtype=torch.bfloat16):
        self.device = device
        self.dtype = dtype

        print(f"Loading model weights from {model_path}...")
        # The processor (image preprocessing + chat template) comes from a fixed local
        # directory; only the vision-tower weights are taken from `model_path`.
        self.processor = AutoProcessor.from_pretrained(
            "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/Qwen2.5-VL-ViT-Only",
            trust_remote_code=True,
        )

        # Build the vision config from the base Qwen2.5-VL-7B-Instruct checkpoint.
        config = AutoConfig.from_pretrained("/mnt/workspace/workgroup/chx/Qwen2.5-VL-7B-Instruct")
        config = config.vision_config

        # Instantiate the vision transformer only (no language model) and load the
        # fine-tuned weights from the checkpoint directory.
        self.model = Qwen2_5_VisionTransformerPretrainedModel(config)

        safe_path = os.path.join(model_path, "model.safetensors")
        state_dict = load_file(safe_path)
        self.model.load_state_dict(state_dict, strict=True)

        self.model.to(device=self.device, dtype=self.dtype)
        self.model.eval()
        print("Model loaded successfully.")

    def _process_batch_forward(self, images):
        """Internal helper to run the forward pass on a single batch."""

        # One single-turn conversation per image, in the role/content format that the
        # Qwen2.5-VL chat template expects.
        messages_list = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": img},
                        {"type": "text", "text": "Describe this image."},
                    ],
                }
            ]
            for img in images
        ]
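
        # The templated text is only there so the processor call below has a prompt to
        # pair with each image; the vision tower itself consumes just the pixel values
        # and the patch grid.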
        text_inputs = [
            self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
            for msg in messages_list
        ]

        inputs = self.processor(
            images=images,
            text=text_inputs,
            return_tensors="pt",
            padding=True,
        )
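
        # The processor returns the image patches for the whole batch flattened into one
        # tensor (`pixel_values`) plus `image_grid_thw`, the per-image patch grid as
        # (temporal, height, width) counts, which is what the vision tower needs.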
        pixel_values = inputs["pixel_values"].to(self.device, dtype=self.dtype)
        grid_thw = inputs["image_grid_thw"].to(self.device)
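
        # The vision tower returns a single flat sequence of merged tokens for all images
        # in the batch, shaped [total_merged_tokens, hidden_dim], so the per-image token
        # counts have to be reconstructed from `grid_thw` below.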
        hidden_states = self.model(hidden_states=pixel_values, grid_thw=grid_thw)

        # Normalize grid_thw to shape [num_images, 3] in case it arrives with an extra
        # singleton dimension.
        if grid_thw.dim() == 3 and grid_thw.size(1) == 1:
            grid_thw = grid_thw.squeeze(1)

        batch_size = grid_thw.shape[0]

        # After the 2x2 spatial merge inside the vision tower, each image contributes
        # T * (H // 2) * (W // 2) tokens to the flat output sequence (T is 1 for images).
        T, H, W = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
        sizes = (T * (H // 2) * (W // 2)).long()

        # Defensive check: if the per-image counts ever disagree with the number of
        # returned tokens, absorb the difference into the last image.
        total_tokens = hidden_states.shape[0]
        if sizes.sum().item() != total_tokens:
            sizes[-1] += total_tokens - sizes.sum().item()
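
        # Mean-pool per image without building a padded [batch, max_tokens, hidden] tensor:
        # assign every token its image index, scatter-add the token features into one row
        # per image, then divide by the per-image token counts.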
        batch_indices = torch.repeat_interleave(
            torch.arange(batch_size, device=self.device),
            sizes,
        )

        pooled_sum = torch.zeros(
            (batch_size, hidden_states.shape[-1]),
            dtype=self.dtype,
            device=self.device,
        )
        pooled_sum.index_add_(0, batch_indices, hidden_states)

        counts = sizes.unsqueeze(1).to(dtype=self.dtype).clamp(min=1.0)
        embeds = pooled_sum / counts

        embeds = F.normalize(embeds, p=2, dim=-1)

        return embeds.cpu()

    @torch.no_grad()
    def encode_batch(self, images: list, batch_size: int = 32, show_progress: bool = True):
        """
        Encode a list of images into L2-normalized embeddings.

        Args:
            images: List of PIL Images.
            batch_size: Number of images to process at once.
            show_progress: Whether to display a tqdm progress bar.
        Returns:
            torch.Tensor: Concatenated embeddings of shape [total_images, hidden_dim].
        """
        all_embeddings = []

        iterator = range(0, len(images), batch_size)
        if show_progress:
            iterator = tqdm(iterator, desc="Encoding Batches", unit="batch")

        for i in iterator:
            batch_images = images[i : i + batch_size]

            # Ensure a consistent 3-channel format before preprocessing.
            batch_images = [img.convert("RGB") for img in batch_images]

            try:
                batch_embeds = self._process_batch_forward(batch_images)
                all_embeddings.append(batch_embeds)
            except Exception as e:
                print(f"Error processing batch starting at index {i}: {e}")
                raise

        if not all_embeddings:
            return torch.empty(0)

        return torch.cat(all_embeddings, dim=0)
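
# Minimal usage sketch (illustrative only; "example.jpg" and MODEL_PATH are placeholders
# for an existing image file and a checkpoint directory containing model.safetensors):
#
#     encoder = Qwen2_5_VL_ImageEncoder(MODEL_PATH)
#     embedding = encoder.encode_batch([Image.open("example.jpg")], batch_size=1)
#     print(embedding.shape)  # -> [1, hidden_dim]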

if __name__ == "__main__":
    from datasets import load_dataset
    from scipy.stats import spearmanr
    from sklearn.metrics.pairwise import paired_cosine_distances

    MODEL_PATHS = [
        "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/checkpoints/final/checkpoint-500",
        "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/checkpoints/final/checkpoint-550",
    ]

    for MODEL_PATH in MODEL_PATHS:
        encoder = Qwen2_5_VL_ImageEncoder(MODEL_PATH)

        spearmans = []
        for lang in ["en", "de", "es", "fr", "it", "nl", "pl", "pt", "ru", "zh"]:
            dataset = load_dataset(
                "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/a_eval/stsb",
                lang,
            )["test"]
            anchors = dataset["sentence1"]
            positive = dataset["sentence2"]

            embeddings1 = encoder.encode_batch(anchors, batch_size=32)
            embeddings2 = encoder.encode_batch(positive, batch_size=32)
            groundtruth = dataset["score"]

            embeddings1 = embeddings1.cpu().float().numpy()
            embeddings2 = embeddings2.cpu().float().numpy()

            # STS protocol: cosine similarity of each pair against the gold similarity
            # score, summarized with Spearman rank correlation.
            cos_sim = 1 - paired_cosine_distances(embeddings1, embeddings2)
            spearman_corr, _ = spearmanr(cos_sim, groundtruth)
            spearmans.append(round(spearman_corr, 2))

        print("Spearman correlation:", spearmans)