Upload model
modeling.py CHANGED (+66 −6)
@@ -155,7 +155,9 @@ class EntityFusionLayer(nn.Module):
 
 
 class KPRMixin:
-    def _forward(
+    def _forward(
+        self, **inputs: Tensor | bool | None | dict[str, Tensor | bool | None]
+    ) -> tuple[Tensor] | tuple[Tensor, Tensor] | ModelOutput:
         return_dict = inputs.pop("return_dict", True)
 
         if self.training:
@@ -185,7 +187,7 @@ class KPRMixin:
         else:
             return (sentence_embeddings,)
 
-    def encode(self, **inputs:
+    def encode(self, **inputs: Tensor | bool | None) -> Tensor:
         entity_ids = inputs.pop("entity_ids", None)
         entity_position_ids = inputs.pop("entity_position_ids", None)
         entity_embeds = inputs.pop("entity_embeds", None)
@@ -231,8 +233,37 @@ class KPRModelForBert(BertPreTrainedModel, KPRMixin):
 
         self.post_init()
 
-    def forward(
-
+    def forward(
+        self,
+        input_ids: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        token_type_ids: torch.Tensor | None = None,
+        position_ids: torch.Tensor | None = None,
+        head_mask: torch.Tensor | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+        entity_ids: torch.Tensor | None = None,
+        entity_position_ids: torch.Tensor | None = None,
+        entity_embeds: torch.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        **kwargs
+    ):
+        return self._forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            entity_ids=entity_ids,
+            entity_position_ids=entity_position_ids,
+            entity_embeds=entity_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs
+        )
 
 
 class KPRModelForXLMRoberta(XLMRobertaPreTrainedModel, KPRMixin):
@@ -247,5 +278,34 @@ class KPRModelForXLMRoberta(XLMRobertaPreTrainedModel, KPRMixin):
 
         self.post_init()
 
-    def forward(
-
+    def forward(
+        self,
+        input_ids: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        token_type_ids: torch.Tensor | None = None,
+        position_ids: torch.Tensor | None = None,
+        head_mask: torch.Tensor | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+        entity_ids: torch.Tensor | None = None,
+        entity_position_ids: torch.Tensor | None = None,
+        entity_embeds: torch.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        **kwargs
+    ):
+        return self._forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            entity_ids=entity_ids,
+            entity_position_ids=entity_position_ids,
+            entity_embeds=entity_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs
+        )
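
For context, a minimal usage sketch (not part of this commit). It assumes the checkpoint is loaded through transformers with trust_remote_code=True, the repository id shown is a placeholder, and the non-return_dict path returns a one-element tuple of sentence embeddings, as the else branch in the diff suggests; only the keyword argument names come from the new forward() signature.

# Hypothetical usage sketch; the repo id and output handling are assumptions,
# while the keyword arguments mirror the explicit forward() signature added above.
import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "your-org/kpr-bert-base"  # placeholder, not a real repository id

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

batch = tokenizer(["Tokyo is the capital of Japan."], return_tensors="pt")
with torch.no_grad():
    # The new forward() lists input_ids, attention_mask, entity_ids, ... as
    # explicit keywords and delegates to _forward(), so a plain **batch call
    # (and signature-inspecting tooling) works as with other HF models.
    outputs = model(**batch, return_dict=False)

sentence_embeddings = outputs[0]  # assumed: (sentence_embeddings,) tuple
print(sentence_embeddings.shape)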