YC-Chen commited on
Commit
bfffb35
·
verified ·
1 Parent(s): d0b8862

Update modeling_internvl_chat.py

Browse files
Files changed (1) hide show
  1. modeling_internvl_chat.py +2 -2
modeling_internvl_chat.py CHANGED
@@ -38,7 +38,7 @@ class InternVLChatModel(PreTrainedModel):
38
  _supports_flash_attn_2 = True
39
  _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'Qwen2DecoderLayer', 'MistralDecoderLayer']
40
 
41
- def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True):
42
  super().__init__(config)
43
 
44
  assert version_cmp(transformers.__version__, '4.37.0', 'ge')
@@ -81,7 +81,7 @@ class InternVLChatModel(PreTrainedModel):
81
  nn.Linear(llm_hidden_size, llm_hidden_size)
82
  )
83
 
84
- self.img_context_token_id = None
85
  self.mr_prompt = MRPromptV3()
86
 
87
  def forward(
 
38
  _supports_flash_attn_2 = True
39
  _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'Qwen2DecoderLayer', 'MistralDecoderLayer']
40
 
41
+ def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True, img_context_token_id=None):
42
  super().__init__(config)
43
 
44
  assert version_cmp(transformers.__version__, '4.37.0', 'ge')
 
81
  nn.Linear(llm_hidden_size, llm_hidden_size)
82
  )
83
 
84
+ self.img_context_token_id = img_context_token_id
85
  self.mr_prompt = MRPromptV3()
86
 
87
  def forward(