ikuyamada committed
Commit 3f94a29 · verified · 1 Parent(s): cf413c2

Upload model

Files changed (1)
  1. modeling.py +66 -6
modeling.py CHANGED
@@ -155,7 +155,9 @@ class EntityFusionLayer(nn.Module):
 
 
 class KPRMixin:
-    def _forward(self, **inputs: dict[str, Tensor]) -> tuple[Tensor] | tuple[Tensor, Tensor] | ModelOutput:
+    def _forward(
+        self, **inputs: Tensor | bool | None | dict[str, Tensor | bool | None]
+    ) -> tuple[Tensor] | tuple[Tensor, Tensor] | ModelOutput:
         return_dict = inputs.pop("return_dict", True)
 
         if self.training:
@@ -185,7 +187,7 @@ class KPRMixin:
         else:
             return (sentence_embeddings,)
 
-    def encode(self, **inputs: dict[str, Tensor]) -> Tensor:
+    def encode(self, **inputs: Tensor | bool | None) -> Tensor:
         entity_ids = inputs.pop("entity_ids", None)
         entity_position_ids = inputs.pop("entity_position_ids", None)
         entity_embeds = inputs.pop("entity_embeds", None)
@@ -231,8 +233,37 @@ class KPRModelForBert(BertPreTrainedModel, KPRMixin):
 
         self.post_init()
 
-    def forward(self, *args, **kwargs):
-        return self._forward(*args, **kwargs)
+    def forward(
+        self,
+        input_ids: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        token_type_ids: torch.Tensor | None = None,
+        position_ids: torch.Tensor | None = None,
+        head_mask: torch.Tensor | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+        entity_ids: torch.Tensor | None = None,
+        entity_position_ids: torch.Tensor | None = None,
+        entity_embeds: torch.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        **kwargs
+    ):
+        return self._forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            entity_ids=entity_ids,
+            entity_position_ids=entity_position_ids,
+            entity_embeds=entity_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs
+        )
 
 
 class KPRModelForXLMRoberta(XLMRobertaPreTrainedModel, KPRMixin):
@@ -247,5 +278,34 @@ class KPRModelForXLMRoberta(XLMRobertaPreTrainedModel, KPRMixin):
 
         self.post_init()
 
-    def forward(self, *args, **kwargs):
-        return self._forward(*args, **kwargs)
+    def forward(
+        self,
+        input_ids: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        token_type_ids: torch.Tensor | None = None,
+        position_ids: torch.Tensor | None = None,
+        head_mask: torch.Tensor | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+        entity_ids: torch.Tensor | None = None,
+        entity_position_ids: torch.Tensor | None = None,
+        entity_embeds: torch.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        **kwargs
+    ):
+        return self._forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            entity_ids=entity_ids,
+            entity_position_ids=entity_position_ids,
+            entity_embeds=entity_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs
+        )
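
For reference, a minimal usage sketch of the new explicit forward signature. Everything below is assumed for illustration: the checkpoint path, the tokenizer name, the entity ID values, and the shape of entity_position_ids are placeholders, not taken from this commit; only the keyword-argument names come from the diff above.

import torch
from transformers import AutoTokenizer

from modeling import KPRModelForBert  # assumes this file is importable locally

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder tokenizer
model = KPRModelForBert.from_pretrained("path/to/kpr-checkpoint")  # placeholder path
model.eval()

text_inputs = tokenizer("Tokyo is the capital of Japan.", return_tensors="pt")

# Illustrative entity inputs: one entity whose mention spans two token
# positions; the real ID space and padding convention are model-specific.
entity_ids = torch.tensor([[12345]])
entity_position_ids = torch.tensor([[[1, 2, -1, -1]]])

with torch.no_grad():
    # After this commit, these are named parameters of forward() rather
    # than opaque *args/**kwargs forwarded to _forward().
    outputs = model(
        input_ids=text_inputs["input_ids"],
        attention_mask=text_inputs["attention_mask"],
        entity_ids=entity_ids,
        entity_position_ids=entity_position_ids,
        return_dict=True,
    )

One practical upside of spelling the parameters out: callers and tools that introspect the signature (e.g., inspect.signature(model.forward)) can now see which inputs the model accepts, which a bare *args/**kwargs definition hides.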