tc-mb bialykostek commited on
Commit
343b490
·
verified ·
1 Parent(s): 0e2e644

Fix data parsing in forward() method (#18)

Browse files

- Fix data parsing in forward() method (482d9ffc5660ee228f84613f759e6cf29e053d3d)


Co-authored-by: Jasiek Kostecki <[email protected]>

Files changed (1) hide show
  1. modeling_minicpmv.py +24 -0
modeling_minicpmv.py CHANGED
@@ -203,6 +203,30 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
203
 
204
 
205
  def forward(self, data, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
207
 
208
  position_ids = data["position_ids"]
 
203
 
204
 
205
  def forward(self, data, **kwargs):
206
+ if isinstance(data, torch.Tensor):
207
+ attention_mask = torch.ones_like(data, dtype=torch.bool)
208
+ kwargs = {'attention_mask': attention_mask}
209
+ return self.llm(
210
+ input_ids=data,
211
+ **kwargs
212
+ )
213
+
214
+ if data is None:
215
+ data = {
216
+ "input_ids": kwargs.pop("input_ids", None),
217
+ "pixel_values": kwargs.pop("pixel_values", None),
218
+ "image_bound": kwargs.pop("image_bound", None),
219
+ "tgt_sizes": kwargs.pop("tgt_sizes", None),
220
+ "position_ids": kwargs.pop("position_ids", None),
221
+ }
222
+ else:
223
+ kwargs.pop("input_ids", None)
224
+ kwargs.pop("pixel_values", None)
225
+ kwargs.pop("image_bound", None)
226
+ kwargs.pop("tgt_sizes", None)
227
+ kwargs.pop("position_ids", None)
228
+ kwargs.pop("inputs_embeds", None)
229
+
230
  vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
231
 
232
  position_ids = data["position_ids"]