dariakryvosheieva commited on
Commit
8fc3a64
·
0 Parent(s):

upload last-token model

Browse files
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "JinaEmbeddingsC1Model"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "auto_map": {
7
+ "AutoModel": "modeling_jina_embeddings_c1.JinaEmbeddingsC1Model"
8
+ },
9
+ "bos_token_id": 151643,
10
+ "eos_token_id": 151643,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 896,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 4864,
15
+ "layer_types": [
16
+ "full_attention",
17
+ "full_attention",
18
+ "full_attention",
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention"
40
+ ],
41
+ "matryoshka_dims": [
42
+ 64,
43
+ 128,
44
+ 256,
45
+ 512,
46
+ 896
47
+ ],
48
+ "max_position_embeddings": 32768,
49
+ "max_window_layers": 24,
50
+ "model_type": "qwen2",
51
+ "num_attention_heads": 14,
52
+ "num_hidden_layers": 24,
53
+ "num_key_value_heads": 2,
54
+ "prompt_names": [
55
+ "query",
56
+ "passage"
57
+ ],
58
+ "rms_norm_eps": 1e-06,
59
+ "rope_scaling": null,
60
+ "rope_theta": 1000000.0,
61
+ "sliding_window": null,
62
+ "task_names": [
63
+ "nl2code",
64
+ "qa",
65
+ "code2code",
66
+ "code2nl",
67
+ "code2completion"
68
+ ],
69
+ "tie_word_embeddings": true,
70
+ "tokenizer_class": "Qwen2TokenizerFast",
71
+ "torch_dtype": "bfloat16",
72
+ "transformers_version": "4.53.0",
73
+ "use_cache": true,
74
+ "use_sliding_window": false,
75
+ "vocab_size": 151936
76
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4be672e203959409ac76ded4387f324d17539f3456cab628bd66bd5d0439675
3
+ size 988096088
modeling_jina_embeddings_c1.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union
2
+
3
+ import torch
4
+ import numpy as np
5
+
6
+ from transformers.utils import is_flash_attn_2_available
7
+ from transformers.models.qwen2 import Qwen2Model
8
+ from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
9
+ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
10
+
11
+
12
+ INSTRUCTION_CONFIG = {
13
+ "nl2code": {
14
+ "query": "Find the most relevant code snippet given the following query:\n",
15
+ "passage": "Candidate code snippet:\n"
16
+ },
17
+ "qa": {
18
+ "query": "Find the most relevant answer given the following question:\n",
19
+ "passage": "Candidate answer:\n"
20
+ },
21
+ "code2code": {
22
+ "query": "Find an equivalent code snippet given the following code snippet:\n",
23
+ "passage": "Candidate code snippet:\n"
24
+ },
25
+ "code2nl": {
26
+ "query": "Find the most relevant comment given the following code snippet:\n",
27
+ "passage": "Candidate comment:\n"
28
+ },
29
+ "code2completion": {
30
+ "query": "Find the most relevant completion given the following start of code snippet:\n",
31
+ "passage": "Candidate completion:\n"
32
+ }
33
+ }
34
+
35
+
36
+ def batch(iterable, n=1):
37
+ items = len(iterable)
38
+ for ndx in range(0, items, n):
39
+ yield iterable[ndx : min(ndx + n, items)]
40
+
41
+
42
+ def mean_pooling(model_output, attention_mask):
43
+ token_embeddings = model_output[0]
44
+ input_mask_expanded = (
45
+ attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
46
+ )
47
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
48
+ input_mask_expanded.sum(1), min=1e-9
49
+ )
50
+
51
+
52
+ def last_token_pooling(model_output, attention_mask):
53
+ token_embeddings = model_output[0]
54
+ left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
55
+ if left_padding:
56
+ return token_embeddings[:, -1]
57
+ else:
58
+ sequence_lengths = attention_mask.sum(dim=1) - 1
59
+ batch_size = token_embeddings.shape[0]
60
+ return token_embeddings[torch.arange(batch_size, device=token_embeddings.device), sequence_lengths].float()
61
+
62
+
63
+ class JinaEmbeddingsC1Model(Qwen2Model):
64
+ def __init__(self, config: Qwen2Config):
65
+ Qwen2Model.__init__(self, config)
66
+ self.instructions = INSTRUCTION_CONFIG
67
+
68
+
69
+ def forward(
70
+ self,
71
+ input_ids: torch.LongTensor,
72
+ attention_mask: torch.Tensor,
73
+ **kwargs
74
+ ) -> List[torch.Tensor]:
75
+ """
76
+ Forward pass through the model.
77
+ """
78
+ batch_model_output = super().forward(
79
+ input_ids=input_ids,
80
+ attention_mask=attention_mask,
81
+ **kwargs
82
+ )
83
+ batch_sentence_embeddings = last_token_pooling(
84
+ batch_model_output, attention_mask
85
+ )
86
+ return batch_sentence_embeddings
87
+
88
+
89
+ def encode(
90
+ self,
91
+ sentences: List[str],
92
+ batch_size: int = 32,
93
+ max_length: int = 32768,
94
+ task: str = "nl2code",
95
+ prompt_name: str = "query",
96
+ return_numpy: bool = False,
97
+ truncate_dim: int = 896,
98
+ ) -> Union[np.ndarray, List[torch.Tensor]]:
99
+ """
100
+ Encodes a list of texts into embeddings.
101
+ Args:
102
+ sentences: list of text strings to encode
103
+ batch_size: Number of texts to process at once
104
+ max_length: Maximum token length for text processing
105
+ task: Type of retrieval task ('nl2code', 'qa', or 'code2code')
106
+ prompt_name: Type of text being encoded ('query' or 'passage')
107
+ return_numpy: Whether to return numpy arrays instead of torch tensors
108
+ truncate_dim: Dimension to truncate embeddings to (64, 128, 256, 512, or 896)
109
+ Returns:
110
+ List of text embeddings as tensors or numpy arrays
111
+ """
112
+ assert task in self.config.task_names, \
113
+ f"Invalid task: {task}. Must be one of {self.config.task_names}."
114
+ assert prompt_name in self.config.prompt_names, \
115
+ f"Invalid prompt name: {prompt_name}. Must be one of {self.config.prompt_names}."
116
+ assert truncate_dim in self.config.matryoshka_dims, \
117
+ f"Invalid embedding dimension: {truncate_dim}. Must be one of {self.config.matryoshka_dims}."
118
+
119
+ instruction = self.instructions[task][prompt_name]
120
+ sentences = [f'{instruction}{sentence}' for sentence in sentences]
121
+ embeddings = []
122
+
123
+ self.eval()
124
+
125
+ with torch.inference_mode():
126
+ for batch_of_sentences in batch(sentences, n=batch_size):
127
+ batch_encoded_input = self.tokenizer(
128
+ batch_of_sentences,
129
+ padding=True,
130
+ truncation=True,
131
+ return_tensors="pt",
132
+ max_length=max_length
133
+ ).to(self.device)
134
+
135
+ batch_sentence_embeddings = self(
136
+ **batch_encoded_input,
137
+ output_attentions=False,
138
+ return_dict=True,
139
+ max_length=max_length
140
+ )
141
+
142
+ batch_sentence_embeddings = batch_sentence_embeddings[:, :truncate_dim]
143
+ batch_sentence_embeddings = torch.nn.functional.normalize(
144
+ batch_sentence_embeddings, p=2, dim=-1
145
+ ).to("cpu")
146
+
147
+ embeddings.append(batch_sentence_embeddings)
148
+
149
+ if return_numpy:
150
+ return np.concatenate([b.numpy() for b in embeddings], axis=0)
151
+ return [t for b in embeddings for t in torch.unbind(b, dim=0)]
152
+
153
+
154
+ @classmethod
155
+ def from_pretrained(
156
+ cls,
157
+ pretrained_model_name_or_path,
158
+ *args,
159
+ **kwargs,
160
+ ):
161
+ """
162
+ Loads a pretrained model.
163
+ """
164
+ if "torch_dtype" not in kwargs:
165
+ kwargs["torch_dtype"] = "auto"
166
+
167
+ if "attn_implementation" not in kwargs:
168
+ kwargs["attn_implementation"] = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
169
+
170
+ model = super().from_pretrained(
171
+ pretrained_model_name_or_path, *args, **kwargs
172
+ )
173
+
174
+ model.tokenizer = Qwen2TokenizerFast.from_pretrained(
175
+ pretrained_model_name_or_path,
176
+ trust_remote_code=True
177
+ )
178
+
179
+ return model
180
+
vocab.json ADDED
The diff for this file is too large to render. See raw diff