# This file is part of a work licensed under the Apache License, Version 2.0.
# See the LICENSE file in the root of the original repository:
# https://github.com/LLaVA-VL/LLaVA-NeXT?tab=readme-ov-file
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# ----------------------------- Modification Notice -----------------------------
# This file was originally obtained from:
# https://github.com/LLaVA-VL/LLaVA-NeXT/blob/4e0ee2d98576210e5a5d122451318d5ef7551fc1/llava/model/multimodal_encoder/siglip_encoder.py#L538-L620
#
# Modification by Yusuke Kanebako on 2025-07-22:
# - Define the vision tower for Qwen2-VL, based on the SigLIP vision tower definition.

import torch
import torch.utils.checkpoint
from torch import nn

from .qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig, Qwen2VLConfig
from .qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
from .qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor

from llava.utils import rank0_print


class Qwen2VLVisionTower(nn.Module):
    """Vision tower that wraps the Qwen2-VL visual encoder.

    Mirrors the interface of the SigLIP vision tower in LLaVA-NeXT so that it
    can be used as a drop-in `mm_vision_tower` backbone.
    """

    def __init__(self, vision_tower, vision_tower_cfg, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.config = Qwen2VLVisionConfig()
        self.vision_tower_name = vision_tower
        self.image_processor = Qwen2VLImageProcessor()

        if not delay_load:
            rank0_print(f"Loading vision tower: {vision_tower}")
            self.load_model()
        elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
            # TODO: a better detector is needed.
            rank0_print("The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
            self.load_model()
        elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
            rank0_print("The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
            self.load_model()
        else:
            # Delay loading: expose only the static config until `load_model` is called.
            self.cfg_only = self.config

    def load_model(self, device_map=None):
        if self.is_loaded:
            rank0_print(f"{self.vision_tower_name} is already loaded, `load_model` called again, skipping.")
            return

        # Load the full Qwen2-VL model and keep only its visual encoder.
        self.vision_tower = Qwen2VLForConditionalGeneration.from_pretrained(self.vision_tower_name, device_map=device_map).visual
        # del self.vision_tower.merger
        self.vision_tower.requires_grad_(False)  # freeze the tower by default
        self.is_loaded = True

    def forward(self, images, image_grid_thw=None):
        # `image_grid_thw` gives the (t, h, w) patch-grid size per image and is
        # required by the Qwen2-VL visual encoder.
        if isinstance(images, list):
            image_features = []
            for image in images:
                image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), grid_thw=image_grid_thw)
                image_features.append(image_feature)
        else:
            image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype), grid_thw=image_grid_thw)
        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # All tower parameters share one dtype; report the first parameter's.
        return next(self.vision_tower.parameters()).dtype

    @property
    def device(self):
        return next(self.vision_tower.parameters()).device

    @property
    def hidden_size(self):
        return self.config.hidden_size

    # The following properties are kept for interface parity with the
    # SigLIP-based tower; Qwen2-VL itself supports dynamic resolution, so the
    # values are derived from the static vision config.
    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def image_size(self):
        return self.config.image_size
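

# Minimal usage sketch (illustrative only): the checkpoint name and image path
# below are assumptions, not values fixed by this repository. Because this
# module uses relative imports, run it with `python -m`-style module execution
# from the package root.
if __name__ == "__main__":
    from PIL import Image

    tower = Qwen2VLVisionTower(
        "Qwen/Qwen2-VL-7B-Instruct",  # assumed checkpoint; substitute your own
        vision_tower_cfg=None,  # unused here because delay_load defaults to False
    )
    inputs = tower.image_processor(images=Image.open("example.jpg"), return_tensors="pt")
    # The processor returns flattened patches (`pixel_values`) and per-image
    # patch-grid sizes (`image_grid_thw`), which `forward` passes to the encoder.
    features = tower(inputs["pixel_values"], image_grid_thw=inputs["image_grid_thw"])
    rank0_print(f"Vision feature shape: {tuple(features.shape)}")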