import os

import torch
import torch.nn.functional as F
from PIL import Image
from safetensors.torch import load_file
from tqdm import tqdm
from transformers import AutoConfig, AutoProcessor
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VisionTransformerPretrainedModel,
)

# Default locations from the original experiment setup; override via the constructor.
DEFAULT_PROCESSOR_PATH = (
    "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/Qwen2.5-VL-ViT-Only"
)
DEFAULT_CONFIG_PATH = "/mnt/workspace/workgroup/chx/Qwen2.5-VL-7B-Instruct"


def _mean_pool_per_image(hidden_states, grid_thw, merge_size=2):
    """Mean-pool a flat token sequence into one vector per image.

    Replicates the training-time pooling. Dim 0 of ``hidden_states`` is treated as the
    concatenation of every image's tokens in batch order. Each image contributes
    ``T * (H // merge_size) * (W // merge_size)`` tokens, as given by its ``grid_thw`` row.

    Args:
        hidden_states: Tensor of shape (total_tokens, hidden_dim).
        grid_thw: (batch, 3) tensor of (T, H, W) grid sizes; (batch, 1, 3) is also accepted.
        merge_size: Spatial merge factor applied by the vision tower's patch merger.

    Returns:
        Un-normalized tensor of shape (batch, hidden_dim) in ``hidden_states.dtype``.

    Raises:
        ValueError: if the vision tower produced too few tokens to cover the grid sizes
            of all but the last image.
    """
    if grid_thw.dim() == 3 and grid_thw.size(1) == 1:
        grid_thw = grid_thw.squeeze(1)

    device = hidden_states.device
    dtype = hidden_states.dtype
    batch_size = grid_thw.shape[0]

    # Tokens per image after spatial merging. T is expected to be 1 for still images,
    # which reduces this to the original (H // 2) * (W // 2) when merge_size == 2.
    t, h, w = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
    sizes = (t * (h // merge_size) * (w // merge_size)).long().to(device)

    total_tokens = hidden_states.shape[0]
    mismatch = total_tokens - int(sizes.sum().item())
    if mismatch:
        # Training-time safeguard kept as-is: leftover (or missing) tokens are attributed
        # to the last image.
        sizes[-1] += mismatch
        if sizes[-1].item() < 0:
            raise ValueError(
                f"Vision tower produced {total_tokens} tokens, fewer than the "
                f"{total_tokens - mismatch} expected from grid_thw."
            )

    # Map every token to its image: [0, 0, 0, 1, 1, 2, 2, 2, ...].
    batch_indices = torch.repeat_interleave(torch.arange(batch_size, device=device), sizes)

    # NOTE(review): the sum is accumulated in hidden_states.dtype (bfloat16 by default),
    # as in the training code path. float32 accumulation would be more precise but would
    # no longer be an exact replica.
    pooled_sum = torch.zeros((batch_size, hidden_states.shape[-1]), dtype=dtype, device=device)
    pooled_sum.index_add_(0, batch_indices, hidden_states)

    counts = sizes.unsqueeze(1).to(dtype=dtype).clamp(min=1.0)
    return pooled_sum / counts


class Qwen2_5_VL_ImageEncoder:
    """Image embedding model built from a (fine-tuned) Qwen2.5-VL vision tower.

    Loads the vision transformer weights from ``<model_path>/model.safetensors``. It runs
    images through the tower, mean-pools the output tokens per image, and L2-normalizes
    the pooled vectors.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        processor_path: str = DEFAULT_PROCESSOR_PATH,
        config_path: str = DEFAULT_CONFIG_PATH,
    ):
        """Build the vision tower and load its weights.

        Args:
            model_path: Directory containing ``model.safetensors`` with the ViT weights.
            device: Device the model and inputs are moved to.
            dtype: Dtype for the model weights and pixel values.
            processor_path: Where to load the (image) processor from.
            config_path: Model whose ``vision_config`` defines the ViT architecture.
        """
        self.device = device
        self.dtype = dtype

        print(f"Loading processor and model from {model_path}...")
        self.processor = AutoProcessor.from_pretrained(processor_path, trust_remote_code=True)

        # Only the vision sub-config is needed to instantiate the ViT.
        vision_config = AutoConfig.from_pretrained(config_path).vision_config
        self.merge_size = getattr(vision_config, "spatial_merge_size", 2)
        self.model = Qwen2_5_VisionTransformerPretrainedModel(vision_config)

        state_dict = load_file(os.path.join(model_path, "model.safetensors"))
        # strict=True: fail loudly if the checkpoint keys do not exactly match the ViT.
        self.model.load_state_dict(state_dict, strict=True)
        self.model.to(device=self.device, dtype=self.dtype)
        self.model.eval()
        print("Model loaded successfully.")

    def _process_batch_forward(self, images):
        """Embed one batch of RGB PIL images.

        Returns:
            CPU tensor of shape (len(images), hidden_dim) with L2-normalized rows.
        """
        # 1. Preprocess. Only pixel_values and image_grid_thw are fed to the vision tower,
        #    so images go straight through the image processor; no chat-template text needed.
        image_processor = getattr(self.processor, "image_processor", self.processor)
        inputs = image_processor(images=images, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(self.device, dtype=self.dtype)
        grid_thw = inputs["image_grid_thw"].to(self.device)

        # 2. Vision tower forward.
        hidden_states = self.model(hidden_states=pixel_values, grid_thw=grid_thw)
        if not isinstance(hidden_states, torch.Tensor):
            raise TypeError(
                "Expected the vision tower to return a token tensor, got "
                f"{type(hidden_states).__name__}; check the installed transformers version."
            )

        # 3. Per-image mean pooling (training-time logic).
        embeds = _mean_pool_per_image(hidden_states, grid_thw, merge_size=self.merge_size)

        # 4. L2-normalize; move to CPU so accumulated results don't occupy GPU memory.
        embeds = F.normalize(embeds, p=2, dim=-1)
        return embeds.cpu()

    @torch.no_grad()
    def encode_batch(self, images: list, batch_size: int = 32, show_progress: bool = True):
        """Embed a list of images in mini-batches.

        Args:
            images: Sequence of PIL Images (each item must support ``.convert("RGB")``).
            batch_size: Number of images per forward pass; must be >= 1.
            show_progress: Show a tqdm progress bar over batches.

        Returns:
            torch.Tensor: Concatenated CPU embeddings [Total_Images, Hidden_Dim], or an
            empty 1-D tensor when ``images`` is empty.

        Raises:
            ValueError: if ``batch_size`` < 1.
            Any exception raised while processing a batch is re-raised after the batch's
            start index is printed.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        all_embeddings = []
        starts = range(0, len(images), batch_size)
        iterator = tqdm(starts, desc="Encoding Batches", unit="batch") if show_progress else starts

        for i in iterator:
            batch_images = [img.convert("RGB") for img in images[i : i + batch_size]]
            try:
                batch_embeds = self._process_batch_forward(batch_images)
            except Exception as e:
                print(f"Error processing batch starting at index {i}: {e}")
                raise  # bare raise keeps the original traceback
            all_embeddings.append(batch_embeds)

        if not all_embeddings:
            return torch.empty(0)
        return torch.cat(all_embeddings, dim=0)


# --- Usage Example ---
if __name__ == "__main__":
    # Evaluation-only dependencies.
    from datasets import load_dataset
    from scipy.stats import spearmanr
    from sklearn.metrics.pairwise import paired_cosine_distances

    MODEL_PATHS = [
        "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/checkpoints/final/checkpoint-500",
        "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/checkpoints/final/checkpoint-550",
    ]
    STSB_PATH = "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/a_eval/stsb"
    LANGS = ["en", "de", "es", "fr", "it", "nl", "pl", "pt", "ru", "zh"]

    # Load each language's test split once and reuse it for every checkpoint.
    test_sets = {lang: load_dataset(STSB_PATH, lang)["test"] for lang in LANGS}

    for model_path in MODEL_PATHS:
        encoder = Qwen2_5_VL_ImageEncoder(model_path)
        spearmans = []
        for lang in LANGS:
            dataset = test_sets[lang]
            # NOTE(review): encode_batch calls .convert("RGB") on every item, so this assumes
            # the local STS-B copy stores sentence1/sentence2 as PIL images (e.g. rendered
            # text) rather than raw strings — TODO confirm.
            embeddings1 = encoder.encode_batch(dataset["sentence1"], batch_size=32).float().numpy()
            embeddings2 = encoder.encode_batch(dataset["sentence2"], batch_size=32).float().numpy()

            cos_sim = 1 - paired_cosine_distances(embeddings1, embeddings2)
            spearman_corr, _ = spearmanr(cos_sim, dataset["score"])
            # float() so the printed list shows plain numbers, not NumPy scalar reprs.
            spearmans.append(round(float(spearman_corr), 2))

        print("Spearman correlation:", spearmans)

        # Release this checkpoint's model before building the next one; otherwise the old
        # model stays resident while the new one is constructed.
        del encoder
        if torch.cuda.is_available():
            torch.cuda.empty_cache()