Sentence Similarity
Transformers
Safetensors
English
mistral
feature-extraction
text-embedding
embeddings
information-retrieval
beir
text-classification
language-model
text-clustering
text-semantic-similarity
text-evaluation
text-reranking
natural_questions
ms_marco
fever
hotpot_qa
mteb
custom_code
text-generation-inference
text-embeddings-inference
Update modeling_mistral_encoder.py
modeling_mistral_encoder.py (CHANGED, +0 -66)

@@ -13,15 +13,6 @@ from .attn_mask_utils import _prepare_4d_causal_attention_mask
 
 logger = logging.get_logger(__name__)
 
-def batch_to_device(batch, target_device: device):
-    """
-    send a pytorch batch to a device (CPU/GPU)
-    """
-    for key in batch:
-        if isinstance(batch[key], Tensor):
-            batch[key] = batch[key].to(target_device)
-    return batch
-
 class ModifiedMistralAttention(MistralAttention):
 
     def __init__(self, *args, **kwargs):
@@ -218,60 +209,3 @@ class MistralEncoderModel(MistralModel):
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
-
-    def prepare_for_tokenization(self, text):
-
-        text = '[INST] ' + text.strip() + ' [/INST]'
-        # if self.pooling_mode == "eos_token":
-        #     text = text.strip() + ' </s>'
-        return text
-
-    def tokenize(self, texts):
-        # return self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=self.max_length)
-
-        texts_2 = []
-        original_texts = []
-        for text in texts:
-            t = text.split("!@#$%^&*()")
-            texts_2.append(t[1])
-            original_texts.append("".join(t))
-
-        original = self.tokenizer(original_texts, return_tensors='pt', padding=True, truncation=True, max_length=self.max_length)
-        embed_mask = None
-        for t_i, t in enumerate(texts_2):
-            ids = self.tokenizer([t], return_tensors='pt', padding=True, truncation=True, max_length=self.max_length, add_special_tokens=False)
-            if embed_mask is None:
-                e_m = torch.zeros_like(original["attention_mask"][t_i])
-                if len(ids["input_ids"][0]) > 0:
-                    e_m[-len(ids["input_ids"][0]):] = torch.ones(len(ids["input_ids"][0]))
-                embed_mask = e_m.unsqueeze(0)
-            else:
-                e_m = torch.zeros_like(original["attention_mask"][t_i])
-                if len(ids["input_ids"][0]) > 0:
-                    e_m[-len(ids["input_ids"][0]):] = torch.ones(len(ids["input_ids"][0]))
-                embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0)
-
-        original["embed_mask"] = embed_mask
-        return original
-
-    def _skip_instruction(self, sentence_feature):
-        assert sentence_feature["attention_mask"].shape == sentence_feature["embed_mask"].shape
-        sentence_feature["attention_mask"] = sentence_feature["embed_mask"]
-
-    def _encode(self, sentences_batch, device, convert_to_numpy, multiprocessing=False):
-
-        if multiprocessing:
-            rank = mp.current_process()._identity[0]
-            if device is None and torch.cuda.is_available():
-                device = f"cuda:{rank % torch.cuda.device_count()}"
-
-        self.to(device)
-        features = self.tokenize([self.prepare_for_tokenization(sentence) for sentence in sentences_batch])
-        features = batch_to_device(features, device)
-
-        with torch.no_grad():
-            embeddings = self.forward(features)
-            embeddings = embeddings.detach()
-            embeddings = embeddings.cpu()
-
-        return embeddings
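For readers tracking what this commit drops: the removed tokenize() implemented an instruction-aware encoding scheme. prepare_for_tokenization() wrapped the whole input in `[INST] ... [/INST]`, the literal string `!@#$%^&*()` acted as a delimiter between instruction and passage, the tokenizer encoded the full string once, and a second pass over just the passage built an embed_mask flagging the trailing passage tokens. Below is a minimal, self-contained sketch of that logic; the function name, the standalone-tokenizer interface, and the left-padding assumption are mine, not part of this repository.

```python
import torch
from transformers import PreTrainedTokenizerBase

SEPARATOR = "!@#$%^&*()"  # delimiter between instruction and passage in the removed tokenize()

def tokenize_with_embed_mask(tokenizer: PreTrainedTokenizerBase, texts, max_length=512):
    """Sketch of the removed logic: encode "instruction + passage" as one sequence,
    then build an embed_mask that flags only the passage tokens.
    Marking the *last* len(passage_ids) positions mirrors the original code and
    implicitly assumes left padding."""
    passages, full_texts = [], []
    for text in texts:
        instruction, _, passage = text.partition(SEPARATOR)
        passages.append(passage)
        full_texts.append(instruction + passage)

    batch = tokenizer(full_texts, return_tensors="pt", padding=True,
                      truncation=True, max_length=max_length)

    embed_mask = torch.zeros_like(batch["attention_mask"])
    for i, passage in enumerate(passages):
        ids = tokenizer(passage, truncation=True, max_length=max_length,
                        add_special_tokens=False)["input_ids"]
        if len(ids) > 0:
            embed_mask[i, -len(ids):] = 1  # passage tokens sit at the end of the row

    batch["embed_mask"] = embed_mask
    return batch
```

The removed _encode() then moved this batch to the target device via the also-removed batch_to_device() helper before calling forward() under torch.no_grad(); for a tokenizer output like the one above, BatchEncoding.to(device) covers the same need.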
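The companion _skip_instruction() removed here simply swapped attention_mask for embed_mask, so that downstream pooling ignores the instruction and prompt-template tokens. The pooling step itself is not part of this diff; a mean-pooling sketch, assuming that is the mode in use, would look like this:

```python
import torch

def mean_pool(last_hidden_state: torch.Tensor, embed_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings, counting only positions where embed_mask == 1.
    Passing embed_mask here instead of attention_mask is the effect the removed
    _skip_instruction() achieved: instruction tokens do not influence the embedding."""
    mask = embed_mask.unsqueeze(-1).type_as(last_hidden_state)  # (batch, seq, 1)
    summed = (last_hidden_state * mask).sum(dim=1)              # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1e-9)                    # guard against all-zero masks
    return summed / counts
```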