Fix data parsing in forward() method (#18)
Browse files- Fix data parsing in forward() method (482d9ffc5660ee228f84613f759e6cf29e053d3d)
Co-authored-by: Jasiek Kostecki <[email protected]>
- modeling_minicpmv.py +24 -0
modeling_minicpmv.py
CHANGED
|
@@ -203,6 +203,30 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
| 203 |
|
| 204 |
|
| 205 |
def forward(self, data, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
|
| 207 |
|
| 208 |
position_ids = data["position_ids"]
|
|
|
|
| 203 |
|
| 204 |
|
| 205 |
def forward(self, data, **kwargs):
|
| 206 |
+
if isinstance(data, torch.Tensor):
|
| 207 |
+
attention_mask = torch.ones_like(data, dtype=torch.bool)
|
| 208 |
+
kwargs = {'attention_mask': attention_mask}
|
| 209 |
+
return self.llm(
|
| 210 |
+
input_ids=data,
|
| 211 |
+
**kwargs
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
if data is None:
|
| 215 |
+
data = {
|
| 216 |
+
"input_ids": kwargs.pop("input_ids", None),
|
| 217 |
+
"pixel_values": kwargs.pop("pixel_values", None),
|
| 218 |
+
"image_bound": kwargs.pop("image_bound", None),
|
| 219 |
+
"tgt_sizes": kwargs.pop("tgt_sizes", None),
|
| 220 |
+
"position_ids": kwargs.pop("position_ids", None),
|
| 221 |
+
}
|
| 222 |
+
else:
|
| 223 |
+
kwargs.pop("input_ids", None)
|
| 224 |
+
kwargs.pop("pixel_values", None)
|
| 225 |
+
kwargs.pop("image_bound", None)
|
| 226 |
+
kwargs.pop("tgt_sizes", None)
|
| 227 |
+
kwargs.pop("position_ids", None)
|
| 228 |
+
kwargs.pop("inputs_embeds", None)
|
| 229 |
+
|
| 230 |
vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
|
| 231 |
|
| 232 |
position_ids = data["position_ids"]
|