| | |
| | import torch |
| | from typing import List, Dict, Any, Union |
| | from PIL import Image |
| | from transformers.processing_utils import ProcessorMixin, BatchFeature |
| | from transformers import AutoTokenizer, AutoImageProcessor |
| |
|
# Marker string that `OpenCUAProcessor.prepare_vllm_inputs` searches for in the
# rendered chat text and repeats according to each image's grid size.
PLACEHOLDER = "<|media_placeholder|>"
| |
|
class OpenCUAProcessor(ProcessorMixin):
    """Bundle an image processor and a tokenizer for OpenCUA prompts.

    The class keeps references to both components and offers a helper,
    `prepare_vllm_inputs`, that renders a chat template and expands each
    `PLACEHOLDER` marker to the number of merged image patches reported by
    the image processor.
    """

    # NOTE(review): `image_token_id` and `merge_size` are plain ints rather
    # than sub-processors; `ProcessorMixin` helpers that iterate `attributes`
    # may not expect that -- confirm before relying on e.g. `save_pretrained`.
    attributes = ["image_processor", "tokenizer", "image_token_id", "merge_size"]

    def __init__(self, image_processor, tokenizer, image_token_id: int = 151664, merge_size: int = 2, **kwargs):
        """Store the components without invoking `ProcessorMixin.__init__`.

        Args:
            image_processor: Object called as
                `image_processor(images=..., return_tensors="pt")`.
            tokenizer: Object exposing `apply_chat_template`.
            image_token_id: Stored as-is on the instance (presumably the
                vocabulary id of the image token -- TODO confirm).
            merge_size: Fallback used only when `image_processor` has no
                `merge_size` attribute of its own.
            **kwargs: Accepted and ignored, so that `from_pretrained` can
                forward its keyword arguments unchanged.
        """
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.image_token_id = image_token_id
        # The image processor's own setting wins over the argument.
        self.merge_size = getattr(image_processor, "merge_size", merge_size)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load the tokenizer and image processor and build a processor.

        Args:
            pretrained_model_name_or_path: Identifier forwarded to the
                tokenizer and image-processor loaders.
            **kwargs: `trust_remote_code` (default `True`) is forwarded to the
                loaders; all keyword arguments are also passed to `__init__`.

        Returns:
            A new instance of `cls`.
        """
        # NOTE(review): defaulting to True executes repository-provided code
        # unless the caller opts out -- confirm this is intended.
        trust = kwargs.get("trust_remote_code", True)

        try:
            from tokenization_opencua import TikTokenV3
            tok = TikTokenV3.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust)
        except Exception:
            # Deliberate best-effort: any failure of the custom tokenizer
            # (missing module or load error) falls back to AutoTokenizer.
            tok = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust)
        imgproc = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust)
        return cls(imgproc, tok, **kwargs)

    def apply_chat_template(self, messages: List[Dict[str, Any]], **kwargs) -> Union[str, List[int]]:
        """Delegate to `self.tokenizer.apply_chat_template` unchanged."""
        return self.tokenizer.apply_chat_template(messages, **kwargs)

    def __call__(self, *args, **kwargs) -> BatchFeature:
        """Return a fixed dummy `BatchFeature`, ignoring all arguments.

        The result always holds a single `input_ids` tensor of shape (1, 1)
        filled with zeros; no tokenization or image processing happens here.
        NOTE(review): presumably a stub for callers that only need the
        processor to be callable -- confirm.
        """
        data = {"input_ids": torch.zeros(1, 1, dtype=torch.long)}
        return BatchFeature(data=data)

    @staticmethod
    def _expand_placeholders(text: str, counts: List[int]) -> str:
        """Repeat the i-th `PLACEHOLDER` in `text` `counts[i]` times.

        Placeholders beyond `len(counts)` are kept as a single marker, and
        surplus counts are ignored.
        """
        segments = text.split(PLACEHOLDER)
        pieces = [segments[0]]
        for index, tail in enumerate(segments[1:]):
            repeat = counts[index] if index < len(counts) else 1
            pieces.append(PLACEHOLDER * repeat)
            pieces.append(tail)
        return "".join(pieces)

    def prepare_vllm_inputs(self, messages, images, add_generation_prompt=True):
        """Render the chat text and expand one placeholder per image.

        Args:
            messages: Chat messages passed to `apply_chat_template`.
            images: Image input passed to `self.image_processor`; returned
                unchanged as the second element of the result.
            add_generation_prompt: Forwarded to `apply_chat_template`.

        Returns:
            A `(text, images)` tuple in which the i-th `PLACEHOLDER` of the
            rendered text is repeated `t * h * w // merge_size ** 2` times,
            using the i-th row of the processor's `image_grid_thw` output.
        """
        text = self.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt)

        # Nothing to expand when no images were supplied; skip the processor.
        if images is None or (isinstance(images, (list, tuple)) and len(images) == 0):
            return text, images

        proc = self.image_processor(images=images, return_tensors="pt")
        grid = torch.as_tensor(proc["image_grid_thw"])
        merge = getattr(self, "merge_size", 2)
        counts = [int((thw[0] * thw[1] * thw[2]) // (merge ** 2)) for thw in grid]

        # Splitting once pairs each placeholder with its own image. Repeated
        # `str.replace(..., 1)` calls would keep hitting the first marker,
        # which after the first expansion lies inside the expanded run.
        return self._expand_placeholders(text, counts), images
| |
|
| |
|
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|