Update modeling_internvl_chat.py
Browse files
modeling_internvl_chat.py
CHANGED
@@ -38,7 +38,7 @@ class InternVLChatModel(PreTrainedModel):
|
|
38 |
_supports_flash_attn_2 = True
|
39 |
_no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'Qwen2DecoderLayer', 'MistralDecoderLayer']
|
40 |
|
41 |
-
def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True):
|
42 |
super().__init__(config)
|
43 |
|
44 |
assert version_cmp(transformers.__version__, '4.37.0', 'ge')
|
@@ -81,7 +81,7 @@ class InternVLChatModel(PreTrainedModel):
|
|
81 |
nn.Linear(llm_hidden_size, llm_hidden_size)
|
82 |
)
|
83 |
|
84 |
-
self.img_context_token_id = None
|
85 |
self.mr_prompt = MRPromptV3()
|
86 |
|
87 |
def forward(
|
|
|
38 |
_supports_flash_attn_2 = True
|
39 |
_no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'Qwen2DecoderLayer', 'MistralDecoderLayer']
|
40 |
|
41 |
+
def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True, img_context_token_id=None):
|
42 |
super().__init__(config)
|
43 |
|
44 |
assert version_cmp(transformers.__version__, '4.37.0', 'ge')
|
|
|
81 |
nn.Linear(llm_hidden_size, llm_hidden_size)
|
82 |
)
|
83 |
|
84 |
+
self.img_context_token_id = img_context_token_id
|
85 |
self.mr_prompt = MRPromptV3()
|
86 |
|
87 |
def forward(
|