Update modeling_dots_ocr_vllm.py (#7)
Browse files- Update modeling_dots_ocr_vllm.py (67cbf202a8f58ec138797afa24090a392d608c1e)
- Update modeling_dots_ocr_vllm.py (4bc8bca53381124acd4bf63f4567b515539b3c7f)
Co-authored-by: Renjie Wu <[email protected]>
- modeling_dots_ocr_vllm.py +22 -0
modeling_dots_ocr_vllm.py
CHANGED
@@ -91,6 +91,17 @@ class DotsOCRProcessingInfo(Qwen2_5_VLProcessingInfo):
|
|
91 |
|
92 |
return config
|
93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
def get_hf_processor(
|
95 |
self,
|
96 |
*,
|
@@ -99,6 +110,7 @@ class DotsOCRProcessingInfo(Qwen2_5_VLProcessingInfo):
|
|
99 |
size: Optional[dict[str, int]] = None,
|
100 |
**kwargs: object,
|
101 |
) -> Qwen2VLProcessor:
|
|
|
102 |
processor = self.ctx.get_hf_processor(
|
103 |
Qwen2VLProcessor,
|
104 |
image_processor=self.get_image_processor(min_pixels=min_pixels, max_pixels=max_pixels, size=size),
|
@@ -166,6 +178,11 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal):
|
|
166 |
)
|
167 |
_tp_plan = {}
|
168 |
|
|
|
|
|
|
|
|
|
|
|
169 |
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
170 |
super().__init__()
|
171 |
|
@@ -409,6 +426,10 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal):
|
|
409 |
|
410 |
|
411 |
def patch_vllm_chat_placeholder():
|
|
|
|
|
|
|
|
|
412 |
from vllm.entrypoints.chat_utils import BaseMultiModalItemTracker
|
413 |
|
414 |
ori = BaseMultiModalItemTracker._placeholder_str
|
@@ -426,4 +447,5 @@ ModelRegistry.register_model(
|
|
426 |
"DotsOCRForCausalLM", DotsOCRForCausalLM,
|
427 |
)
|
428 |
|
|
|
429 |
patch_vllm_chat_placeholder()
|
|
|
91 |
|
92 |
return config
|
93 |
|
94 |
+
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
95 |
+
return {"image": None, "video": 0}
|
96 |
+
|
97 |
+
def get_mm_max_tokens_per_item(
|
98 |
+
self,
|
99 |
+
seq_len: int,
|
100 |
+
mm_counts: Mapping[str, int],
|
101 |
+
) -> Mapping[str, int]:
|
102 |
+
max_image_tokens = self.get_max_image_tokens()
|
103 |
+
return {"image": max_image_tokens, "video": 0}
|
104 |
+
|
105 |
def get_hf_processor(
|
106 |
self,
|
107 |
*,
|
|
|
110 |
size: Optional[dict[str, int]] = None,
|
111 |
**kwargs: object,
|
112 |
) -> Qwen2VLProcessor:
|
113 |
+
self.get_tokenizer().image_token = "<|imgpad|>" # Ensure image token is set
|
114 |
processor = self.ctx.get_hf_processor(
|
115 |
Qwen2VLProcessor,
|
116 |
image_processor=self.get_image_processor(min_pixels=min_pixels, max_pixels=max_pixels, size=size),
|
|
|
178 |
)
|
179 |
_tp_plan = {}
|
180 |
|
181 |
+
@classmethod
|
182 |
+
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
183 |
+
if modality in ("image",):
|
184 |
+
return "<|img|><|imgpad|><|endofimg|>"
|
185 |
+
|
186 |
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
187 |
super().__init__()
|
188 |
|
|
|
426 |
|
427 |
|
428 |
def patch_vllm_chat_placeholder():
|
429 |
+
import vllm
|
430 |
+
# return when vllm version > 0.9.1
|
431 |
+
if not (vllm.__version_tuple__[0]==0 and vllm.__version_tuple__[1] <= 9 and vllm.__version_tuple__[2] <= 1):
|
432 |
+
return
|
433 |
from vllm.entrypoints.chat_utils import BaseMultiModalItemTracker
|
434 |
|
435 |
ori = BaseMultiModalItemTracker._placeholder_str
|
|
|
447 |
"DotsOCRForCausalLM", DotsOCRForCausalLM,
|
448 |
)
|
449 |
|
450 |
+
|
451 |
patch_vllm_chat_placeholder()
|