import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoProcessor, AutoConfig
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
from tqdm import tqdm 
from safetensors.torch import load_file
import os

class Qwen2_5_VL_ImageEncoder:
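    """Standalone image encoder built from the Qwen2.5-VL vision tower (ViT only).

    Produces one L2-normalized, mean-pooled embedding per input image.
    """
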
    def __init__(self, model_path: str, device: str = "cuda", dtype=torch.bfloat16):
        self.device = device
        self.dtype = dtype
        
        print(f"Loading processor and model from {model_path}...")
        self.processor = AutoProcessor.from_pretrained("/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/Qwen2.5-VL-ViT-Only", trust_remote_code=True)
        
        config = AutoConfig.from_pretrained('/mnt/workspace/workgroup/chx/Qwen2.5-VL-7B-Instruct')
        config = config.vision_config
        
        self.model = Qwen2_5_VisionTransformerPretrainedModel(config)
        
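        # Load the ViT-only weights exported as a single safetensors file.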
        safe_path = os.path.join(model_path, "model.safetensors")
        state_dict = load_file(safe_path)
        self.model.load_state_dict(state_dict, strict=True)

        self.model.to(device=self.device, dtype=self.dtype)
        self.model.eval()
        print("Model loaded successfully.")
        
    def _process_batch_forward(self, images):
        """Internal helper to run forward pass on a single batch."""
        # 1. Prepare Inputs
        messages_list = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": img},
                        {"type": "text", "text": "Describe this image."},
                    ],
                }
            ]
            for img in images
        ]
        
        # Apply template for each item in the batch
        text_inputs = [
            self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) 
            for msg in messages_list
        ]
        
        # Processor handles the batching of pixel values and grids
        inputs = self.processor(
            images=images,
            text=text_inputs,
            return_tensors="pt",
            padding=True
        )
        
        # Move to device
        pixel_values = inputs["pixel_values"].to(self.device, dtype=self.dtype)
        grid_thw = inputs["image_grid_thw"].to(self.device)

        # 2. Model Forward
        outputs = self.model(hidden_states=pixel_values, grid_thw=grid_thw)
        hidden_states = outputs

        # 3. Pooling Logic (Exact replica of training logic)
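        # The processor may return grid_thw as [B, 1, 3]; flatten it to [B, 3].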
        if grid_thw.dim() == 3 and grid_thw.size(1) == 1:
            grid_thw = grid_thw.squeeze(1)

        batch_size = grid_thw.shape[0]
        
        # Tokens per image after the 2x2 spatial merge: (H // 2) * (W // 2);
        # the temporal dim T in grid_thw is 1 for still images.
        H, W = grid_thw[:, 1], grid_thw[:, 2]
        sizes = ((H // 2) * (W // 2)).long()
        
        # Safety fix: if the computed sizes disagree with the actual number of
        # output tokens, absorb the difference into the last image.
        total_tokens = hidden_states.shape[0]
        if sizes.sum().item() != total_tokens:
            sizes[-1] += (total_tokens - sizes.sum().item())

        # Create batch indices [0,0,0, 1,1, 2,2,2...]
        batch_indices = torch.repeat_interleave(
            torch.arange(batch_size, device=self.device), 
            sizes
        )

        # Sum Pooling
        pooled_sum = torch.zeros(
            (batch_size, hidden_states.shape[-1]), 
            dtype=self.dtype, 
            device=self.device
        )
        pooled_sum.index_add_(0, batch_indices, hidden_states)

        # Mean Pooling
        counts = sizes.unsqueeze(1).to(dtype=self.dtype).clamp(min=1.0)
        embeds = pooled_sum / counts

        # 4. Normalize
        embeds = F.normalize(embeds, p=2, dim=-1)
        
        return embeds.cpu() # Move to CPU to save GPU memory during accumulation

    @torch.no_grad()
    def encode_batch(self, images: list, batch_size: int = 32, show_progress: bool = True):
        """
        Args:
            images: List of PIL Images.
            batch_size: Number of images to process at once.
            show_progress: Whether to display a tqdm progress bar.
        Returns:
            torch.Tensor: Concatenated embeddings [Total_Images, Hidden_Dim]
        """
        all_embeddings = []
        
        iterator = range(0, len(images), batch_size)
        if show_progress:
            iterator = tqdm(iterator, desc="Encoding Batches", unit="batch")

        for i in iterator:
            batch_images = images[i : i + batch_size]
            
            # Ensure all are RGB
            batch_images = [img.convert("RGB") for img in batch_images]
            
            try:
                batch_embeds = self._process_batch_forward(batch_images)
                all_embeddings.append(batch_embeds)
            except Exception as e:
                print(f"Error processing batch starting at index {i}: {e}")
                # Re-raise so the failure is not silently swallowed; returning
                # partial results instead would be the alternative.
                raise

        if not all_embeddings:
            return torch.empty(0)

        # Concatenate all batches into one large tensor
        return torch.cat(all_embeddings, dim=0)

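# A minimal sketch of direct image encoding with this class (checkpoint path and
# image file names below are hypothetical):
#
#   encoder = Qwen2_5_VL_ImageEncoder("/path/to/vit_checkpoint")
#   imgs = [Image.open(p) for p in ("a.jpg", "b.jpg")]
#   embeds = encoder.encode_batch(imgs, batch_size=16)  # shape [2, hidden_dim]
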
# --- Usage Example ---
if __name__ == "__main__":

    from datasets import load_dataset
    from scipy.stats import spearmanr
    from sklearn.metrics.pairwise import paired_cosine_distances

    MODEL_PATHS = [
        "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/checkpoints/final/checkpoint-500",
        "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/checkpoints/final/checkpoint-550",
    ]
    for MODEL_PATH in MODEL_PATHS:
        encoder = Qwen2_5_VL_ImageEncoder(MODEL_PATH)

        spearmans = []
        for lang in ["en","de","es","fr","it","nl","pl","pt","ru","zh"]:
            dataset = load_dataset("/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/a_eval/stsb",lang)["test"]
            anchors = dataset["sentence1"]
            positive = dataset["sentence2"]

            embeddings1 = encoder.encode_batch(anchors, batch_size=32)
            embeddings2 = encoder.encode_batch(positive, batch_size=32)
            groundtruth = dataset["score"]

            embeddings1 = embeddings1.cpu().float().numpy()
            embeddings2 = embeddings2.cpu().float().numpy()

            cos_sim = 1 - paired_cosine_distances(embeddings1, embeddings2)
            spearman_corr, _ = spearmanr(cos_sim, groundtruth)
            spearmans.append(round(spearman_corr,2))
        print("Spearman correlation:", spearmans)