dariakryvosheieva committed · verified
Commit f3d0bb7 · 1 Parent(s): 2fa9a47

Create custom_st.py

Files changed (1):
  1. custom_st.py  +219  -0
custom_st.py ADDED
@@ -0,0 +1,219 @@
from typing import List, Dict, Tuple, Union, Any, Optional

import os
import json
import torch

from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer
from transformers.utils import is_flash_attn_2_available


# Task-specific instruction prefixes for queries and passages, one pair per retrieval task.
INSTRUCTION_CONFIG = {
    "nl2code": {
        "query": "Find the most relevant code snippet given the following query:\n",
        "passage": "Candidate code snippet:\n"
    },
    "qa": {
        "query": "Find the most relevant answer given the following question:\n",
        "passage": "Candidate answer:\n"
    },
    "code2code": {
        "query": "Find an equivalent code snippet given the following code snippet:\n",
        "passage": "Candidate code snippet:\n"
    },
    "code2nl": {
        "query": "Find the most relevant comment given the following code snippet:\n",
        "passage": "Candidate comment:\n"
    },
    "code2completion": {
        "query": "Find the most relevant completion given the following start of code snippet:\n",
        "passage": "Candidate completion:\n"
    }
}
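
# --- Editor's illustration (not part of the committed file): a minimal sketch of how a
# prefix from INSTRUCTION_CONFIG might be prepended to a text before encoding. The exact
# application point is an assumption; this module only defines the prefixes.
#
#     task = "nl2code"
#     query = "parse a JSON string into a dictionary"
#     prefixed_query = INSTRUCTION_CONFIG[task]["query"] + query
#     # -> "Find the most relevant code snippet given the following query:\nparse a JSON string into a dictionary"
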
def batch(iterable, n=1):
    # Yield successive chunks of at most n items from the iterable.
    items = len(iterable)
    for ndx in range(0, items, n):
        yield iterable[ndx : min(ndx + n, items)]


def last_token_pooling(model_output, attention_mask):
    # Pool each sequence to the embedding of its last non-padding token.
    token_embeddings = model_output[0]
    # If every sequence attends to the final position, the batch is left-padded and the
    # last column already holds the last real token of every sequence.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return token_embeddings[:, -1]
    else:
        # Right padding: index the last attended position of each sequence.
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = token_embeddings.shape[0]
        return token_embeddings[torch.arange(batch_size, device=token_embeddings.device), sequence_lengths].float()
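
# --- Editor's illustration (not part of the committed file): how last_token_pooling
# behaves on a right-padded toy batch; shapes and values are made up for clarity.
#
#     hidden = torch.arange(2 * 4 * 3, dtype=torch.float).reshape(2, 4, 3)  # (batch, seq, dim)
#     mask = torch.tensor([[1, 1, 1, 0],    # sequence of length 3
#                          [1, 1, 0, 0]])   # sequence of length 2
#     pooled = last_token_pooling((hidden,), mask)  # picks positions 2 and 1 respectively
#     # pooled.shape == (2, 3)
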
class Transformer(nn.Module):
    def __init__(
        self,
        model_name_or_path: str,
        max_seq_length: Optional[int] = None,
        model_args: Optional[Dict[str, Any]] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        config_args: Optional[Dict[str, Any]] = None,
        cache_dir: Optional[str] = None,
        do_lower_case: bool = False,
        tokenizer_name_or_path: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__()
        self.config_keys = ["max_seq_length", "do_lower_case"]
        self.do_lower_case = do_lower_case
        if model_args is None:
            model_args = {}
        if tokenizer_args is None:
            tokenizer_args = {}
        if config_args is None:
            config_args = {}

        self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)

        self.task_names = self.config.task_names

        self.default_task = model_args.pop('default_task', None)

        # Prefer FlashAttention 2 when it is available, otherwise fall back to SDPA.
        model_args["attn_implementation"] = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"

        self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=self.config, cache_dir=cache_dir, **model_args)

        if max_seq_length is not None and "model_max_length" not in tokenizer_args:
            tokenizer_args["model_max_length"] = max_seq_length
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path,
            cache_dir=cache_dir,
            **tokenizer_args,
        )

        # No max_seq_length set. Try to infer from model
        if max_seq_length is None:
            if (
                hasattr(self.auto_model, "config")
                and hasattr(self.auto_model.config, "max_position_embeddings")
                and hasattr(self.tokenizer, "model_max_length")
            ):
                max_seq_length = min(self.auto_model.config.max_position_embeddings, self.tokenizer.model_max_length)

        self.max_seq_length = max_seq_length

        if tokenizer_name_or_path is not None:
            self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__

    @property
    def default_task(self):
        return self._default_task

    @default_task.setter
    def default_task(self, task: Union[None, str]):
        self._validate_task(task)
        self._default_task = task

    def _validate_task(self, task: Optional[str]):
        if task and task not in self.task_names:
            raise ValueError(
                f"Unsupported task '{task}'. "
                f"Supported tasks are: {', '.join(self.config.task_names)}."
            )

    def forward(
        self,
        features: Dict[str, torch.Tensor],
        task: Optional[str] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the model.
        """
        # Drop the prompt_length entry that Sentence Transformers adds when prompts are
        # used; the underlying model does not accept it as an input.
        features.pop('prompt_length', None)
        output_states = self.auto_model.forward(
            **features,
            output_attentions=False,
            return_dict=True
        )
        output_tokens = output_states[0]
        features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})
        return features

    def get_word_embedding_dimension(self) -> int:
        return self.auto_model.config.hidden_size

    def tokenize(
        self,
        texts: Union[List[str], List[dict], List[Tuple[str, str]]],
        padding: Union[str, bool] = True
    ) -> Dict[str, torch.Tensor]:
        """Tokenizes texts and maps tokens to token ids."""
        output = {}
        if isinstance(texts[0], str):
            to_tokenize = [texts]
        elif isinstance(texts[0], dict):
            to_tokenize = []
            output["text_keys"] = []
            for lookup in texts:
                text_key, text = next(iter(lookup.items()))
                to_tokenize.append(text)
                output["text_keys"].append(text_key)
            to_tokenize = [to_tokenize]
        else:
            batch1, batch2 = [], []
            for text_tuple in texts:
                batch1.append(text_tuple[0])
                batch2.append(text_tuple[1])
            to_tokenize = [batch1, batch2]

        # Strip whitespace
        to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize]

        # Lowercase
        if self.do_lower_case:
            to_tokenize = [[s.lower() for s in col] for col in to_tokenize]

        output.update(
            self.tokenizer(
                *to_tokenize,
                padding=padding,
                truncation=True,
                return_tensors="pt",
                max_length=self.max_seq_length,
            )
        )
        return output

    def get_config_dict(self) -> Dict[str, Any]:
        return {key: self.__dict__[key] for key in self.config_keys}

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.tokenizer.save_pretrained(output_path)

        with open(os.path.join(output_path, "sentence_transformer_config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

    @classmethod
    def load(cls, input_path: str) -> "Transformer":
        config_name = "sentence_transformer_config.json"
        stransformer_config_path = os.path.join(input_path, config_name)
        with open(stransformer_config_path) as fIn:
            config = json.load(fIn)
        # Don't allow configs to set trust_remote_code
        if "model_args" in config and "trust_remote_code" in config["model_args"]:
            config["model_args"].pop("trust_remote_code")
        if "tokenizer_args" in config and "trust_remote_code" in config["tokenizer_args"]:
            config["tokenizer_args"].pop("trust_remote_code")
        if "config_args" in config and "trust_remote_code" in config["config_args"]:
            config["config_args"].pop("trust_remote_code")
        return cls(model_name_or_path=input_path, **config)
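
For context, a custom_st.py module like this is normally consumed through Sentence Transformers' remote-code mechanism rather than imported directly. The sketch below is illustrative only and is not part of the commit: the repository id is a placeholder, the nl2code prefixes are copied from INSTRUCTION_CONFIG above, and it assumes the repository wires this module in (e.g. via its modules.json) so that passing trust_remote_code=True loads it from the Hub.

    # Hypothetical usage sketch; adjust the repo id and prompt handling to the actual model card.
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(
        "org-name/model-with-custom-st",  # placeholder repository id
        trust_remote_code=True,           # required so custom_st.py is loaded from the Hub
    )

    query = "Find the most relevant code snippet given the following query:\nread a CSV file"
    passage = "Candidate code snippet:\nimport pandas as pd\ndf = pd.read_csv('data.csv')"

    embeddings = model.encode([query, passage])
    print(embeddings.shape)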