from typing import List, Optional, Tuple, Union

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen2 import Qwen2ForCausalLM

from .configuration_dots import DotsVisionConfig, DotsOCRConfig
from .modeling_dots_vision import DotsVisionTransformer


DOTS_VLM_MAX_IMAGES = 200  # soft cap on images per batch; exceeding it only triggers the FSDP-hang warning below


class DotsOCRForCausalLM(Qwen2ForCausalLM):
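    """Qwen2-based causal language model with a DotsVisionTransformer vision
    tower for image-conditioned (OCR-style) generation."""
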
    config_class = DotsOCRConfig

    def __init__(self, config: DotsOCRConfig):
        super().__init__(config)

        if isinstance(self.config.vision_config, dict):
            vision_config = DotsVisionConfig(**self.config.vision_config)
            self.config.vision_config = vision_config
        else:
            vision_config = self.config.vision_config

        self.vision_tower = DotsVisionTransformer(vision_config)

    def prepare_inputs_embeds(
        self,
        input_ids: torch.LongTensor,
        pixel_values: Optional[torch.FloatTensor] = None,
        grid_thw: Optional[torch.FloatTensor] = None,
        img_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.Tensor:
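        """Embed `input_ids` and, when `pixel_values` is provided, run the vision
        tower and scatter the resulting vision embeddings into the positions
        flagged by `img_mask`."""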
        inputs_embeds = self.get_input_embeddings()(input_ids)

        if pixel_values is not None:
            assert img_mask is not None
            if grid_thw.shape[0] > DOTS_VLM_MAX_IMAGES:
                print(
                    f"Number of images exceeded: {grid_thw.shape[0]} > {DOTS_VLM_MAX_IMAGES}, which may cause an FSDP hang"
                )

            vision_embeddings = self.vision_tower(pixel_values, grid_thw)

            # Positions of image placeholder tokens in the (batch, seq) mask.
            true_indices = torch.nonzero(img_mask).squeeze()
            if len(true_indices) > vision_embeddings.size(0):
                print(
                    f"img_mask has more positions than vision embeddings and will be truncated: mask.sum()={len(true_indices)}, {vision_embeddings.size(0)=}"
                )
                # Keep only as many placeholder positions as there are vision
                # embeddings and rebuild the mask from them.
                true_indices = true_indices[: vision_embeddings.size(0)]
                new_img_mask = torch.zeros_like(img_mask, device=img_mask.device)
                new_img_mask[true_indices[:, 0], true_indices[:, 1]] = True
            else:
                new_img_mask = img_mask

            assert (
                vision_embeddings.size(0) == new_img_mask.sum()
            ), f"{vision_embeddings.size(0)=}, {new_img_mask.sum()=}"

            inputs_embeds = inputs_embeds.masked_scatter(
                new_img_mask.to(inputs_embeds.device).unsqueeze(-1).expand_as(inputs_embeds),
                vision_embeddings.to(inputs_embeds.device).type(inputs_embeds.dtype),
            )

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.LongTensor,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
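        """Build multimodal input embeddings when `inputs_embeds` is not given,
        then delegate the language-model forward pass to `Qwen2ForCausalLM`."""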
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        assert len(input_ids) >= 1, f"empty input_ids {input_ids.shape=} will cause gradnorm nan"
        if inputs_embeds is None:
            # Mark the positions of image placeholder tokens so that vision
            # embeddings can be scattered into the matching slots.
            img_mask = input_ids == self.config.image_token_id
            inputs_embeds = self.prepare_inputs_embeds(input_ids, pixel_values, image_grid_thw, img_mask)

        outputs = super().forward(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            labels=labels,
            use_cache=use_cache if use_cache is not None else self.config.use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # return_dict=return_dict,
            logits_to_keep=logits_to_keep,
            **loss_kwargs,
        )

        return outputs

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        attention_mask=None,
        cache_position=None,
        num_logits_to_keep=None,
        **kwargs,
    ):
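        """Reuse the parent implementation, then re-attach `pixel_values` so they
        are only consumed on the prefill step."""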
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
            **kwargs,
        )

        # Pixel values are only needed on the prefill step (cache_position starts
        # at 0); afterwards the image features are already baked into the KV cache.
        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values

        return model_inputs
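

if __name__ == "__main__":
    # Minimal generation sketch (illustrative only, not part of the original
    # module). Assumptions: the checkpoint directory registers this class via
    # trust_remote_code and ships a processor that returns `input_ids`,
    # `pixel_values`, and `image_grid_thw`; the checkpoint path, image path, and
    # prompt below are placeholders to adapt to the actual chat/image template.
    from PIL import Image
    from transformers import AutoModelForCausalLM, AutoProcessor

    checkpoint = "path/to/dots-ocr-checkpoint"  # placeholder
    processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)

    image = Image.open("page.png")  # placeholder document image
    inputs = processor(text="Read all text in the image.", images=[image], return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=256)
    print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])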