import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoProcessor, AutoConfig
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
from tqdm import tqdm
from safetensors.torch import load_file
import os
class Qwen2_5_VL_ImageEncoder:
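    """Standalone image encoder built on the Qwen2.5-VL vision tower.

    Loads ViT-only weights from `model_path` and returns L2-normalized,
    mean-pooled embeddings (one vector per input image).
    """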
def __init__(self, model_path: str, device: str = "cuda", dtype=torch.bfloat16):
self.device = device
self.dtype = dtype
print(f"Loading processor and model from {model_path}...")
self.processor = AutoProcessor.from_pretrained("/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/Qwen2.5-VL-ViT-Only", trust_remote_code=True)
config = AutoConfig.from_pretrained('/mnt/workspace/workgroup/chx/Qwen2.5-VL-7B-Instruct')
config = config.vision_config
self.model = Qwen2_5_VisionTransformerPretrainedModel(config)
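        # Only the vision transformer is instantiated, so the checkpoint must
        # contain ViT-only weights; strict=True fails loudly on any key mismatch.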
safe_path = os.path.join(model_path, "model.safetensors")
state_dict = load_file(safe_path)
self.model.load_state_dict(state_dict, strict=True)
self.model.to(device=self.device, dtype=self.dtype)
self.model.eval()
print("Model loaded successfully.")
def _process_batch_forward(self, images):
"""Internal helper to run forward pass on a single batch."""
# 1. Prepare Inputs
        messages_list = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": img},
                        {"type": "text", "text": "Describe this image."},
                    ],
                }
            ]
            for img in images
        ]
# Apply template for each item in the batch
text_inputs = [
self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
for msg in messages_list
]
# Processor handles the batching of pixel values and grids
inputs = self.processor(
images=images,
text=text_inputs,
return_tensors="pt",
padding=True
)
# Move to device
pixel_values = inputs["pixel_values"].to(self.device, dtype=self.dtype)
grid_thw = inputs["image_grid_thw"].to(self.device)
        # 2. Model forward: the vision tower returns merged patch embeddings
        #    with shape [total_tokens, hidden_dim] for the whole batch.
        hidden_states = self.model(hidden_states=pixel_values, grid_thw=grid_thw)
# 3. Pooling Logic (Exact replica of training logic)
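        # grid_thw can arrive as [B, 1, 3]; flatten it to [B, 3] holding (t, h, w) per image.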
if grid_thw.dim() == 3 and grid_thw.size(1) == 1:
grid_thw = grid_thw.squeeze(1)
batch_size = grid_thw.shape[0]
        # Tokens per image after spatial merging: T is 1 for still images and the
        # ViT merges 2x2 patches, so each image contributes (H // 2) * (W // 2) tokens.
        H, W = grid_thw[:, 1], grid_thw[:, 2]
        sizes = ((H // 2) * (W // 2)).long()
        # Safety fix: if the computed sizes do not sum to the number of returned
        # tokens, assign the remainder to the last image.
total_tokens = hidden_states.shape[0]
if sizes.sum().item() != total_tokens:
sizes[-1] += (total_tokens - sizes.sum().item())
# Create batch indices [0,0,0, 1,1, 2,2,2...]
batch_indices = torch.repeat_interleave(
torch.arange(batch_size, device=self.device),
sizes
)
# Sum Pooling
pooled_sum = torch.zeros(
(batch_size, hidden_states.shape[-1]),
dtype=self.dtype,
device=self.device
)
pooled_sum.index_add_(0, batch_indices, hidden_states)
# Mean Pooling
counts = sizes.unsqueeze(1).to(dtype=self.dtype).clamp(min=1.0)
embeds = pooled_sum / counts
# 4. Normalize
embeds = F.normalize(embeds, p=2, dim=-1)
return embeds.cpu() # Move to CPU to save GPU memory during accumulation
@torch.no_grad()
def encode_batch(self, images: list, batch_size: int = 32, show_progress: bool = True):
"""
Args:
images: List of PIL Images.
batch_size: Number of images to process at once.
Returns:
torch.Tensor: Concatenated embeddings [Total_Images, Hidden_Dim]
"""
all_embeddings = []
iterator = range(0, len(images), batch_size)
if show_progress:
iterator = tqdm(iterator, desc="Encoding Batches", unit="batch")
for i in iterator:
batch_images = images[i : i + batch_size]
# Ensure all are RGB
batch_images = [img.convert("RGB") for img in batch_images]
try:
batch_embeds = self._process_batch_forward(batch_images)
all_embeddings.append(batch_embeds)
            except Exception as e:
                print(f"Error processing batch starting at index {i}: {e}")
                # Re-raise with the original traceback; switch to `continue` to keep partial results.
                raise
if not all_embeddings:
return torch.empty(0)
# Concatenate all batches into one large tensor
return torch.cat(all_embeddings, dim=0)
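
# A minimal usage sketch (paths and filenames below are placeholder assumptions,
# not part of this repo): encode a few local images and inspect the result.
#
#     from PIL import Image
#     encoder = Qwen2_5_VL_ImageEncoder("/path/to/vit_checkpoint")
#     imgs = [Image.open("cat.jpg"), Image.open("dog.jpg")]
#     embs = encoder.encode_batch(imgs, batch_size=8)  # -> [2, hidden_dim], L2-normalized rows
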
# --- Usage Example ---
if __name__ == "__main__":
    from datasets import load_dataset
    from scipy.stats import spearmanr
    from sklearn.metrics.pairwise import paired_cosine_distances

    MODEL_PATHS = [
        "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/checkpoints/final/checkpoint-500",
        "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/checkpoints/final/checkpoint-550",
    ]
    for MODEL_PATH in MODEL_PATHS:
        encoder = Qwen2_5_VL_ImageEncoder(MODEL_PATH)
        spearmans = []
        for lang in ["en", "de", "es", "fr", "it", "nl", "pl", "pt", "ru", "zh"]:
            dataset = load_dataset(
                "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/a_eval/stsb", lang
            )["test"]
            anchors = dataset["sentence1"]
            positives = dataset["sentence2"]
            groundtruth = dataset["score"]

            embeddings1 = encoder.encode_batch(anchors, batch_size=32).cpu().float().numpy()
            embeddings2 = encoder.encode_batch(positives, batch_size=32).cpu().float().numpy()

            # Spearman correlation between paired cosine similarities and gold scores
            cos_sim = 1 - paired_cosine_distances(embeddings1, embeddings2)
            spearman_corr, _ = spearmanr(cos_sim, groundtruth)
            spearmans.append(round(spearman_corr, 2))
        print(f"{MODEL_PATH} Spearman correlations:", spearmans)