ikuyamada commited on
Commit
042da7b
·
verified ·
1 Parent(s): 2713c05

Upload model

Browse files
Files changed (1) hide show
  1. modeling.py +5 -5
modeling.py CHANGED
@@ -155,12 +155,12 @@ class EntityFusionLayer(nn.Module):
155
 
156
 
157
  class KPRMixin:
158
- def _forward(self, **inputs: Tensor | dict[str, Tensor]) -> tuple[Tensor] | tuple[Tensor, Tensor] | ModelOutput:
159
  return_dict = inputs.pop("return_dict", True)
160
 
161
  if self.training:
162
- query_embeddings = self.encode(inputs["queries"])
163
- passage_embeddings = self.encode(inputs["passages"])
164
 
165
  query_embeddings = self._dist_gather_tensor(query_embeddings)
166
  passage_embeddings = self._dist_gather_tensor(passage_embeddings)
@@ -179,13 +179,13 @@ class KPRMixin:
179
  return (loss, scores)
180
 
181
  else:
182
- sentence_embeddings = self.encode(inputs).unsqueeze(1)
183
  if return_dict:
184
  return ModelOutput(sentence_embeddings=sentence_embeddings)
185
  else:
186
  return (sentence_embeddings,)
187
 
188
- def encode(self, inputs: dict[str, Tensor]) -> Tensor:
189
  entity_ids = inputs.pop("entity_ids", None)
190
  entity_position_ids = inputs.pop("entity_position_ids", None)
191
  entity_embeds = inputs.pop("entity_embeds", None)
 
155
 
156
 
157
  class KPRMixin:
158
+ def _forward(self, **inputs: dict[str, Tensor]) -> tuple[Tensor] | tuple[Tensor, Tensor] | ModelOutput:
159
  return_dict = inputs.pop("return_dict", True)
160
 
161
  if self.training:
162
+ query_embeddings = self.encode(**inputs["queries"])
163
+ passage_embeddings = self.encode(**inputs["passages"])
164
 
165
  query_embeddings = self._dist_gather_tensor(query_embeddings)
166
  passage_embeddings = self._dist_gather_tensor(passage_embeddings)
 
179
  return (loss, scores)
180
 
181
  else:
182
+ sentence_embeddings = self.encode(**inputs).unsqueeze(1)
183
  if return_dict:
184
  return ModelOutput(sentence_embeddings=sentence_embeddings)
185
  else:
186
  return (sentence_embeddings,)
187
 
188
+ def encode(self, **inputs: dict[str, Tensor]) -> Tensor:
189
  entity_ids = inputs.pop("entity_ids", None)
190
  entity_position_ids = inputs.pop("entity_position_ids", None)
191
  entity_embeds = inputs.pop("entity_embeds", None)