File size: 4,133 Bytes
851fb8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# This file is part of a work licensed under the Apache License, Version 2.0.
# See the LICENSE file in the root of the original repository:
# https://github.com/LLaVA-VL/LLaVA-NeXT?tab=readme-ov-file
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
# ----------------------------- Modification Notice -----------------------------
# This file was originally obtained from:
# https://github.com/LLaVA-VL/LLaVA-NeXT/blob/4e0ee2d98576210e5a5d122451318d5ef7551fc1/llava/model/multimodal_encoder/siglip_encoder.py#L538-L620
#
# Modification by Yusuke Kanebako on 2025-07-22:
# - Define the Vision Tower for Qwen2-VL based on the Vision Tower definition of SigLip.

import torch
import torch.utils.checkpoint
from torch import nn

from .qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig, Qwen2VLConfig
from .qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
from .qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor

from llava.utils import rank0_print


class Qwen2VLVisionTower(nn.Module):
    """Vision tower backed by the visual encoder of a Qwen2-VL checkpoint.

    The encoder is the ``visual`` attribute of a
    ``Qwen2VLForConditionalGeneration`` model loaded from ``vision_tower``.
    Loading may be deferred with ``delay_load``; until ``load_model`` runs the
    ``vision_tower`` attribute does not exist, so ``forward``, ``dtype`` and
    ``device`` cannot be used.
    """

    def __init__(self, vision_tower, vision_tower_cfg, delay_load=False):
        """Set up the config and image processor and optionally load weights.

        Args:
            vision_tower: Identifier handed to ``from_pretrained`` when the
                encoder is loaded.
            vision_tower_cfg: Object inspected for the optional attributes
                ``unfreeze_mm_vision_tower`` and ``mm_tunable_parts``; either
                one can force loading even when ``delay_load`` is true.
            delay_load: When true, skip loading unless ``vision_tower_cfg``
                indicates the checkpoint carries vision tower weights.
        """
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower

        # Both objects are default-constructed; they are not read from the
        # checkpoint named by ``vision_tower``.
        # NOTE(review): confirm the defaults match the weights that get loaded.
        self.config = Qwen2VLVisionConfig()
        self.image_processor = Qwen2VLImageProcessor()

        if not delay_load:
            rank0_print(f"Loading vision tower: {vision_tower}")
            self.load_model()
        elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
            # TODO: better detector is needed.
            rank0_print("The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
            self.load_model()
        elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
            rank0_print("The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
            self.load_model()
        else:
            # Loading is deferred: expose the config only.
            self.cfg_only = self.config

    def load_model(self, device_map=None):
        """Load the visual encoder and freeze its parameters.

        Does nothing except log a message when the encoder is already loaded.

        Args:
            device_map: Forwarded unchanged to ``from_pretrained``.
        """
        if self.is_loaded:
            rank0_print(f"{self.vision_tower_name} is already loaded, `load_model` called again, skipping.")
            return

        # The full conditional-generation model is instantiated and only its
        # ``visual`` submodule is kept on this object.
        # NOTE(review): nothing is removed from that submodule here (an earlier
        # revision considered deleting its ``merger``) - confirm that is intended.
        full_model = Qwen2VLForConditionalGeneration.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower = full_model.visual

        # Freeze every encoder parameter.
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def forward(self, images, image_grid_thw=None):
        """Encode one tensor, or each tensor of a list, with the vision tower.

        Args:
            images: Either a tensor, passed to the encoder as-is, or a list of
                tensors, each given a leading dimension of size 1 and encoded
                separately.
            image_grid_thw: Forwarded unchanged to the encoder as ``grid_thw``.

        Returns:
            The encoder output for a tensor input, or a list with one encoder
            output per element for a list input.
        """
        # ``isinstance`` (rather than an exact type check) so that list
        # subclasses also take the per-image path.
        if isinstance(images, list):
            # NOTE(review): every element receives the same ``image_grid_thw``;
            # confirm whether callers expect per-image grid rows instead.
            return [
                self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    grid_thw=image_grid_thw,
                )
                for image in images
            ]

        return self.vision_tower(images.to(device=self.device, dtype=self.dtype), grid_thw=image_grid_thw)

    @property
    def dummy_feature(self):
        """Return a zero tensor of shape ``(1, hidden_size)`` on the tower's device and dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Return the dtype of the encoder's first parameter, or ``None`` if it has none."""
        first_param = next(self.vision_tower.parameters(), None)
        return None if first_param is None else first_param.dtype

    @property
    def device(self):
        """Return the device of the encoder's first parameter, or ``None`` if it has none."""
        first_param = next(self.vision_tower.parameters(), None)
        return None if first_param is None else first_param.device

    @property
    def hidden_size(self):
        """Return ``config.hidden_size``."""
        return self.config.hidden_size

    @property
    def num_patches(self):
        """Return the square of ``image_size // patch_size`` from the config."""
        return (self.config.image_size // self.config.patch_size) ** 2

    @property
    def num_patches_per_side(self):
        """Return ``image_size // patch_size`` from the config."""
        return self.config.image_size // self.config.patch_size

    @property
    def image_size(self):
        """Return ``config.image_size``.

        NOTE(review): relies on the local ``Qwen2VLVisionConfig`` exposing an
        ``image_size`` attribute - confirm.
        """
        return self.config.image_size