# pixel_vit_test/example_inference.py
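"""Example inference for the pixel ViT checkpoint.

Loads the Qwen2.5-VL vision tower from a fine-tuned checkpoint, mean-pools its
patch embeddings into one L2-normalized vector per image, and evaluates the
resulting embeddings on multilingual STS-B via Spearman correlation.
"""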
import os

import torch
import torch.nn.functional as F
from PIL import Image
from safetensors.torch import load_file
from tqdm import tqdm
from transformers import AutoProcessor, AutoConfig
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel

class Qwen2_5_VL_ImageEncoder:
    def __init__(self, model_path: str, device: str = "cuda", dtype=torch.bfloat16):
        self.device = device
        self.dtype = dtype
        print(f"Loading processor and model from {model_path}...")
        # The processor and vision config come from fixed base-model paths rather
        # than model_path: the checkpoint folder is expected to hold only the
        # fine-tuned ViT weights in model.safetensors.
        self.processor = AutoProcessor.from_pretrained(
            "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/Qwen2.5-VL-ViT-Only",
            trust_remote_code=True,
        )
        config = AutoConfig.from_pretrained("/mnt/workspace/workgroup/chx/Qwen2.5-VL-7B-Instruct")
        config = config.vision_config
        # Instantiate the vision tower alone and load the checkpoint weights into it
        self.model = Qwen2_5_VisionTransformerPretrainedModel(config)
        safe_path = os.path.join(model_path, "model.safetensors")
        state_dict = load_file(safe_path)
        self.model.load_state_dict(state_dict, strict=True)
        self.model.to(device=self.device, dtype=self.dtype)
        self.model.eval()
        print("Model loaded successfully.")
    def _process_batch_forward(self, images):
        """Internal helper to run the forward pass on a single batch."""
        # 1. Prepare inputs: wrap each image in a user chat message so the chat
        #    template can insert the image placeholder tokens for it.
        messages_list = [
            [{
                "role": "user",
                "content": [
                    {"type": "image", "image": img},
                    {"type": "text", "text": "Describe this image."},
                ],
            }]
            for img in images
        ]
        # Apply the template for each item in the batch
        text_inputs = [
            self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
            for msg in messages_list
        ]
        # The processor handles batching of the pixel values and grids
        inputs = self.processor(
            images=images,
            text=text_inputs,
            return_tensors="pt",
            padding=True,
        )
        # Move to device
        pixel_values = inputs["pixel_values"].to(self.device, dtype=self.dtype)
        grid_thw = inputs["image_grid_thw"].to(self.device)
        # 2. Model forward: the vision tower takes the flattened patches plus the
        #    per-image (t, h, w) grid and returns merged patch embeddings of shape
        #    [total_tokens, hidden_dim].
        hidden_states = self.model(hidden_states=pixel_values, grid_thw=grid_thw)
        # 3. Pooling logic (exact replica of the training logic)
        if grid_thw.dim() == 3 and grid_thw.size(1) == 1:
            grid_thw = grid_thw.squeeze(1)
        batch_size = grid_thw.shape[0]
        # Tokens per image after the 2x2 spatial merge: (H // 2) * (W // 2)
        H, W = grid_thw[:, 1], grid_thw[:, 2]
        sizes = ((H // 2) * (W // 2)).long()
        # Safety fix for a token-count mismatch: fold any difference into the last image
        total_tokens = hidden_states.shape[0]
        if sizes.sum().item() != total_tokens:
            sizes[-1] += total_tokens - sizes.sum().item()
        # Create batch indices, e.g. [0, 0, 0, 1, 1, 2, 2, 2, ...]
        batch_indices = torch.repeat_interleave(
            torch.arange(batch_size, device=self.device),
            sizes,
        )
        # Sum pooling: scatter-add each token row into its image's slot
        pooled_sum = torch.zeros(
            (batch_size, hidden_states.shape[-1]),
            dtype=self.dtype,
            device=self.device,
        )
        pooled_sum.index_add_(0, batch_indices, hidden_states)
        # Mean pooling: divide by the per-image token counts
        counts = sizes.unsqueeze(1).to(dtype=self.dtype).clamp(min=1.0)
        embeds = pooled_sum / counts
        # 4. L2-normalize so cosine similarity reduces to a dot product
        embeds = F.normalize(embeds, p=2, dim=-1)
        return embeds.cpu()  # Move to CPU to save GPU memory during accumulation
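
    # Toy illustration of the scatter-mean above (an aside, not executed):
    #   sizes = tensor([3, 2]) -> batch_indices = tensor([0, 0, 0, 1, 1])
    #   index_add_ sums token rows 0-2 into slot 0 and rows 3-4 into slot 1;
    #   dividing by the counts then yields one mean-pooled vector per image.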
    @torch.no_grad()
    def encode_batch(self, images: list, batch_size: int = 32, show_progress: bool = True):
        """
        Args:
            images: List of PIL Images.
            batch_size: Number of images to process at once.
        Returns:
            torch.Tensor: Concatenated embeddings [Total_Images, Hidden_Dim]
        """
        all_embeddings = []
        iterator = range(0, len(images), batch_size)
        if show_progress:
            iterator = tqdm(iterator, desc="Encoding Batches", unit="batch")
        for i in iterator:
            batch_images = images[i : i + batch_size]
            # Ensure every image is RGB
            batch_images = [img.convert("RGB") for img in batch_images]
            try:
                batch_embeds = self._process_batch_forward(batch_images)
                all_embeddings.append(batch_embeds)
            except Exception as e:
                print(f"Error processing batch starting at index {i}: {e}")
                # Re-raise; alternatively, partial results could be returned here
                raise
        if not all_embeddings:
            return torch.empty(0)
        # Concatenate all batches into one large tensor
        return torch.cat(all_embeddings, dim=0)
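

# The STS-B evaluation below feeds raw sentences to an image encoder, so each
# sentence must be rasterized first. The renderer used during training is not
# shipped with this checkpoint, so the helper below is a minimal, assumed
# stand-in (white canvas, default PIL font, fixed wrap width) that makes the
# example runnable end to end; swap in the real renderer if you have it.
import textwrap

from PIL import ImageDraw, ImageFont


def render_text_to_image(text: str, width: int = 448, height: int = 448) -> Image.Image:
    """Rasterize a sentence onto a white canvas (assumed renderer, see note above)."""
    img = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(img)
    font = ImageFont.load_default()
    wrapped = "\n".join(textwrap.wrap(text, width=60))
    draw.multiline_text((10, 10), wrapped, fill="black", font=font)
    return img
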
# --- Usage Example ---
if __name__ == "__main__":
    import numpy as np
    from datasets import load_dataset
    from scipy.stats import spearmanr
    from sklearn.metrics.pairwise import paired_cosine_distances

    MODEL_PATHS = [
        "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/checkpoints/final/checkpoint-500",
        "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/checkpoints/final/checkpoint-550",
    ]
    for MODEL_PATH in MODEL_PATHS:
        encoder = Qwen2_5_VL_ImageEncoder(MODEL_PATH)
        spearmans = []
        for lang in ["en", "de", "es", "fr", "it", "nl", "pl", "pt", "ru", "zh"]:
            dataset = load_dataset(
                "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/a_eval/stsb",
                lang,
            )["test"]
            # Rasterize both sides of each sentence pair (see the rendering note above)
            anchors = [render_text_to_image(s) for s in dataset["sentence1"]]
            positive = [render_text_to_image(s) for s in dataset["sentence2"]]
            embeddings1 = encoder.encode_batch(anchors, batch_size=32)
            embeddings2 = encoder.encode_batch(positive, batch_size=32)
            groundtruth = dataset["score"]
            embeddings1 = embeddings1.cpu().float().numpy()
            embeddings2 = embeddings2.cpu().float().numpy()
            # Spearman correlation between paired cosine similarities and gold scores
            cos_sim = 1 - paired_cosine_distances(embeddings1, embeddings2)
            spearman_corr, _ = spearmanr(cos_sim, groundtruth)
            spearmans.append(round(spearman_corr, 2))
        print("Spearman correlation:", spearmans)