Upload model
Browse files- modeling.py +5 -5
modeling.py
CHANGED
@@ -155,12 +155,12 @@ class EntityFusionLayer(nn.Module):
|
|
155 |
|
156 |
|
157 |
class KPRMixin:
|
158 |
-
def _forward(self, **inputs:
|
159 |
return_dict = inputs.pop("return_dict", True)
|
160 |
|
161 |
if self.training:
|
162 |
-
query_embeddings = self.encode(inputs["queries"])
|
163 |
-
passage_embeddings = self.encode(inputs["passages"])
|
164 |
|
165 |
query_embeddings = self._dist_gather_tensor(query_embeddings)
|
166 |
passage_embeddings = self._dist_gather_tensor(passage_embeddings)
|
@@ -179,13 +179,13 @@ class KPRMixin:
|
|
179 |
return (loss, scores)
|
180 |
|
181 |
else:
|
182 |
-
sentence_embeddings = self.encode(inputs).unsqueeze(1)
|
183 |
if return_dict:
|
184 |
return ModelOutput(sentence_embeddings=sentence_embeddings)
|
185 |
else:
|
186 |
return (sentence_embeddings,)
|
187 |
|
188 |
-
def encode(self, inputs: dict[str, Tensor]) -> Tensor:
|
189 |
entity_ids = inputs.pop("entity_ids", None)
|
190 |
entity_position_ids = inputs.pop("entity_position_ids", None)
|
191 |
entity_embeds = inputs.pop("entity_embeds", None)
|
|
|
155 |
|
156 |
|
157 |
class KPRMixin:
|
158 |
+
def _forward(self, **inputs: dict[str, Tensor]) -> tuple[Tensor] | tuple[Tensor, Tensor] | ModelOutput:
|
159 |
return_dict = inputs.pop("return_dict", True)
|
160 |
|
161 |
if self.training:
|
162 |
+
query_embeddings = self.encode(**inputs["queries"])
|
163 |
+
passage_embeddings = self.encode(**inputs["passages"])
|
164 |
|
165 |
query_embeddings = self._dist_gather_tensor(query_embeddings)
|
166 |
passage_embeddings = self._dist_gather_tensor(passage_embeddings)
|
|
|
179 |
return (loss, scores)
|
180 |
|
181 |
else:
|
182 |
+
sentence_embeddings = self.encode(**inputs).unsqueeze(1)
|
183 |
if return_dict:
|
184 |
return ModelOutput(sentence_embeddings=sentence_embeddings)
|
185 |
else:
|
186 |
return (sentence_embeddings,)
|
187 |
|
188 |
+
def encode(self, **inputs: dict[str, Tensor]) -> Tensor:
|
189 |
entity_ids = inputs.pop("entity_ids", None)
|
190 |
entity_position_ids = inputs.pop("entity_position_ids", None)
|
191 |
entity_embeds = inputs.pop("entity_embeds", None)
|