electroglyph committed (verified)
Commit 6f99c6d · 1 Parent(s): bddaca5

Upload folder using huggingface_hub

Files changed (2):
  1. README.md +127 -5
  2. benchmark.py +292 -0
README.md CHANGED
@@ -87,10 +87,6 @@ language:
87
  - yo
88
  - zh
89
  ---
90
- # Accuracy
91
-
92
- Not sure on accuracy quite yet, will update soon. After I confirm this is working well (preliminary results suggest it's good), I can try a version which combines normalization + quantization for the `token_embeddings` output.
93
-
94
  # snowflake2_m_uint8
95
 
96
  This is a slightly modified version of the uint8 quantized ONNX model from https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v2.0
@@ -113,6 +109,130 @@ Here's what the new graph in this model looks like:
113
 
114
  ![modified model graph](./quant_model.png)
115
 
116
  # Example inference code
117
 
118
  ```python
@@ -120,7 +240,7 @@ import onnxruntime as rt
120
  import transformers
121
 
122
  tokenizer = transformers.AutoTokenizer.from_pretrained(
123
- "snowflake2_m_uint8" # path to the folder for this model goes here
124
  )
125
  session = rt.InferenceSession(
126
  "snowflake2_m_uint8.onnx", providers=["CPUExecutionProvider"]
@@ -131,4 +251,6 @@ embeddings = session.run(
131
  None, {"input_ids": [enc.input_ids], "attention_mask": [enc.attention_mask]}
132
  )
133
  e = embeddings[1][0] # this is the output tensor for sentence_embedding; it is a uint8 array of size 768
134
  ```
112
+ # Benchmark
113
+
114
+ I don't have an NVIDIA GPU, so running some of the MTEB benchmarks is a bit of a chore.
115
+
116
+ Instead, I created this small benchmark.
117
+
118
+ Here's how it works:
119
+
120
+ 1) I generate an embedding for every token in this model's vocabulary, once with the original model and once with my quantized-output model.
121
+
122
+ 2) I upsert those embeddings into Qdrant, one collection per model, with ID == token index (sketched below).
123
+
124
+ 3) I compare the models by querying each token against both collections and measuring how much the two result lists differ.
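+
+ A minimal, single-threaded sketch of steps 1 and 2 for the quantized model (collection name, vector name, dimensions, and paths mirror `benchmark.py`; the real script batches the upserts across threads and builds a second collection for the baseline model the same way):
+
+ ```python
+ import json
+ import onnxruntime as rt
+ import transformers
+ from qdrant_client import QdrantClient, models
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(".")
+ session = rt.InferenceSession("snowflake2_m_uint8.onnx", providers=["CPUExecutionProvider"])
+ client = QdrantClient(url="http://127.0.0.1", port=6333)
+
+ # 768-dim named vector, cosine distance, one collection per model
+ client.create_collection(
+     collection_name="compare",
+     vectors_config={"dense": models.VectorParams(size=768, distance=models.Distance.COSINE)},
+ )
+
+ # token strings in vocab order; the list index becomes the Qdrant point ID
+ with open("tokenizer.json") as f:
+     toks = [x[0] for x in json.load(f)["model"]["vocab"]]
+
+ for token_id, token in enumerate(toks):
+     enc = tokenizer(token)
+     emb = session.run(None, {"input_ids": [enc.input_ids], "attention_mask": [enc.attention_mask]})[1][0]
+     client.upsert(
+         collection_name="compare",
+         points=[models.PointStruct(id=token_id, vector={"dense": emb.tolist()})],
+     )
+ ```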
125
+
126
+ For instance:
127
+
128
+ When I query the embedding for token 0 with limit=10 using `model_uint8.onnx`, I get the first list below.
129
+ The same query against this model returns the second list.
130
+
131
+ [0, 181513, 3309, 97636, 6, 104615, 95353, 124967, 115375, 87124]
132
+ [0, 181513, 3309, 95353, 6, 104615, 97636, 124967, 115375, 87124]
133
+
134
+ The results are close, but in my model the tokens at positions 4 and 7 have swapped places.
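+
+ For reference, here is a minimal sketch of how one such comparison is made for a single token (it assumes Qdrant is running locally and that the `baseline` and `compare` collections have already been populated as above; collection names and the output index follow `benchmark.py`):
+
+ ```python
+ import json
+ import onnxruntime as rt
+ import transformers
+ from qdrant_client import QdrantClient
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(".")
+ orig = rt.InferenceSession("model_uint8.onnx", providers=["CPUExecutionProvider"])
+ mine = rt.InferenceSession("snowflake2_m_uint8.onnx", providers=["CPUExecutionProvider"])
+ client = QdrantClient(url="http://127.0.0.1", port=6333)
+
+ # token with ID 0, looked up the same way benchmark.py builds its ID <-> token mapping
+ with open("tokenizer.json") as f:
+     token = json.load(f)["model"]["vocab"][0][0]
+
+ enc = tokenizer(token)
+ feed = {"input_ids": [enc.input_ids], "attention_mask": [enc.attention_mask]}
+ orig_vec = orig.run(None, feed)[1][0]  # sentence_embedding output
+ mine_vec = mine.run(None, feed)[1][0]
+
+ orig_ids = [p.id for p in client.query_points("baseline", query=orig_vec.tolist(), using="dense", limit=10).points]
+ mine_ids = [p.id for p in client.query_points("compare", query=mine_vec.tolist(), using="dense", limit=10).points]
+ print(orig_ids)
+ print(mine_ids)
+ ```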
135
+
136
+ The benchmark measures how often, and by how much, this happens across the entire token range.
137
+
138
+ The code for reproducing this benchmark is in this repo, in `benchmark.py`.
139
+
140
+ ...
141
+
142
+ Here are the results for [model_uint8.onnx](https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v2.0/blob/main/onnx/model_uint8.onnx) vs my model. 'Exact' means the same token appeared in the same position. 'Off by 1' means the token was still present in the results, but one position away from where the baseline put it. 'Missing' means a token from the baseline results did not appear at all in my model's results.
143
+
144
+ Note that discrepancies here don't necessarily mean *wrong* results, just *different* results. The best way to see differences is to test directly on your own data and see if the results are to your liking.
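+
+ The per-query scoring behind these buckets is roughly the following (a simplified single-query sketch with an illustrative `score_pair` name; the threaded version that produced the tables below is `score_results` in `benchmark.py`). With the token-0 lists from above, the swap at positions 4 and 7 shows up as two "off by 3" hits:
+
+ ```python
+ from collections import Counter
+
+ def score_pair(baseline_ids, compare_ids, top_k, stats):
+     # benchmark.py fetches limit=top_k+5 results so shifts near the cutoff still register
+     for i in range(top_k):
+         if baseline_ids[i] == compare_ids[i]:
+             stats["exact"] += 1
+         elif baseline_ids[i] in compare_ids:
+             off = abs(compare_ids.index(baseline_ids[i]) - i)
+             stats[f"off_by_{min(off, 5)}"] += 1  # the 5 bucket is reported as "off by 5+"
+         else:
+             stats["missing"] += 1
+
+ stats = Counter()
+ baseline = [0, 181513, 3309, 97636, 6, 104615, 95353, 124967, 115375, 87124]
+ compare = [0, 181513, 3309, 95353, 6, 104615, 97636, 124967, 115375, 87124]
+ score_pair(baseline, compare, top_k=10, stats=stats)
+ print(dict(stats))  # {'exact': 8, 'off_by_3': 2}
+ ```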
145
+
146
+ ```
147
+ Stats for top 10 query results across entire token range:
148
+ exact : 76.18%
149
+ off by 1 : 19.77%
150
+ off by 2 : 2.72%
151
+ off by 3 : 0.54%
152
+ off by 4 : 0.12%
153
+ off by 5+: 0.04%
154
+ missing : 0.63%
155
+
156
+ Stats for top 20 query results across entire token range:
157
+ exact : 65.86%
158
+ off by 1 : 25.00%
159
+ off by 2 : 5.87%
160
+ off by 3 : 1.68%
161
+ off by 4 : 0.53%
162
+ off by 5+: 0.27%
163
+ missing : 0.78%
164
+
165
+ Stats for top 50 query results across entire token range:
166
+ exact : 48.54%
167
+ off by 1 : 29.09%
168
+ off by 2 : 11.35%
169
+ off by 3 : 5.02%
170
+ off by 4 : 2.38%
171
+ off by 5+: 2.36%
172
+ missing : 1.26%
173
+ ```
174
+
175
+ Here are the results for [model_fp16.onnx](https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v2.0/blob/main/onnx/model_fp16.onnx) vs [model_uint8.onnx](https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v2.0/blob/main/onnx/model_uint8.onnx):
176
+
177
+ ```
178
+ Stats for top 10 query results across entire token range:
179
+ exact : 20.54%
180
+ off by 1 : 13.79%
181
+ off by 2 : 8.55%
182
+ off by 3 : 6.37%
183
+ off by 4 : 4.87%
184
+ off by 5+: 31.53%
185
+ missing : 14.34%
186
+
187
+ Stats for top 20 query results across entire token range:
188
+ exact : 11.58%
189
+ off by 1 : 9.46%
190
+ off by 2 : 6.76%
191
+ off by 3 : 5.58%
192
+ off by 4 : 4.70%
193
+ off by 5+: 38.80%
194
+ missing : 23.12%
195
+
196
+ Stats for top 50 query results across entire token range:
197
+ exact : 5.34%
198
+ off by 1 : 5.18%
199
+ off by 2 : 4.09%
200
+ off by 3 : 3.60%
201
+ off by 4 : 3.22%
202
+ off by 5+: 36.17%
203
+ missing : 42.38%
204
+ ```
205
+
206
+ Here are the results for [model.onnx](https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v2.0/blob/main/onnx/model.onnx) vs [model_fp16.onnx](https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v2.0/blob/main/onnx/model_fp16.onnx):
207
+
208
+ ```
209
+ Stats for top 10 query results across entire token range:
210
+ exact : 18.12%
211
+ off by 1 : 11.80%
212
+ off by 2 : 7.41%
213
+ off by 3 : 5.65%
214
+ off by 4 : 4.45%
215
+ off by 5+: 32.29%
216
+ missing : 20.28%
217
+
218
+ Stats for top 20 query results across entire token range:
219
+ exact : 10.08%
220
+ off by 1 : 7.93%
221
+ off by 2 : 5.70%
222
+ off by 3 : 4.77%
223
+ off by 4 : 4.11%
224
+ off by 5+: 37.46%
225
+ missing : 29.96%
226
+
227
+ Stats for top 50 query results across entire token range:
228
+ exact : 4.59%
229
+ off by 1 : 4.28%
230
+ off by 2 : 3.39%
231
+ off by 3 : 3.00%
232
+ off by 4 : 2.73%
233
+ off by 5+: 33.45%
234
+ missing : 48.58%
235
+ ```
236
  # Example inference code
237
 
238
  ```python
import onnxruntime as rt
240
  import transformers
241
 
242
  tokenizer = transformers.AutoTokenizer.from_pretrained(
243
+ "." # path to wherever this model is located
244
  )
245
  session = rt.InferenceSession(
246
  "snowflake2_m_uint8.onnx", providers=["CPUExecutionProvider"]
 
251
  None, {"input_ids": [enc.input_ids], "attention_mask": [enc.attention_mask]}
252
  )
253
  e = embeddings[1][0] # this is the output tensor for sentence_embedding; it is a uint8 array of size 768
254
+ # alternatively, if you change the first argument of session.run to ['sentence_embedding']
255
+ # then you would get the results from embeddings[0][0]
256
  ```
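+
+ Continuing from the example above (same `session` and `enc`), here is that alternative spelled out, requesting only the named `sentence_embedding` output:
+
+ ```python
+ # ask for just the sentence_embedding output; it then comes back at index 0
+ embeddings = session.run(
+     ["sentence_embedding"],
+     {"input_ids": [enc.input_ids], "attention_mask": [enc.attention_mask]},
+ )
+ e = embeddings[0][0]  # the same uint8 array of size 768 as embeddings[1][0] above
+ ```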
benchmark.py ADDED
@@ -0,0 +1,292 @@
1
+ import json
2
+ import onnxruntime as rt
3
+ import transformers
4
+ from qdrant_client import QdrantClient, models
5
+ import queue
6
+ from threading import Thread, Lock
7
+ import time
8
+ from pyatomix import AtomicInt
9
+
10
+ # adjust these settings as needed
11
+ TOKENIZER_PATH = "."
12
+ ORIG_MODEL_PATH = "model_uint8.onnx"
13
+ ORIG_DATATYPE = models.Datatype.FLOAT32
14
+ ORIG_COLLECTION_NAME = "baseline"
15
+ COMPARE_MODEL_PATH = "snowflake2_m_uint8.onnx"
16
+ COMPARE_DATATYPE = models.Datatype.UINT8
17
+ COMPARE_COLLECTION_NAME = "compare"
18
+ EMBEDDING_DIM = 768 # size of the model output
19
+ STAT_RANGES = [
20
+ 10,
21
+ 20,
22
+ 50,
23
+ ] # stats will be calculated for each range: top 10, top 20, etc.
24
+ STATS = {}
25
+ STAT_LOCK = Lock()
26
+ BATCH_SIZE = 1000 # this many token/id pairs will be processed at a time
27
+ THREADS = 8 # number of threads to use
28
+ # Qdrant client settings here
29
+ CLIENT_URL = "http://127.0.0.1"
30
+ CLIENT_PORT = 6333
31
+ CLIENT_GRPC_PORT = 6334
32
+ CLIENT_USE_GRPC = True
33
+ FINISHED = AtomicInt(0)
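+ # shared progress counter: worker threads increment it, the main thread polls it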
34
+
35
+
36
+ def collect_tokens() -> list[str] | None:
37
+ print("Attempting to grab tokens from tokenizer...")
38
+ with open(f"{TOKENIZER_PATH}/tokenizer.json", "r") as f:
39
+ t = f.read()
40
+ j = json.loads(t)
41
+ v = j["model"]["vocab"]
42
+ toks = [x[0] for x in v]
43
+ print(f"Found {len(toks)} tokens.")
44
+ return toks
45
+
46
+
47
+ def init_worker(q: queue.Queue, model_path: str, collection_name: str):
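+ # worker: embeds a chunk of (id, token) pairs with the given model and upserts them into collection_name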
48
+ try:
49
+ session = rt.InferenceSession(model_path, providers=["CPUExecutionProvider"])
50
+ except Exception as e:
51
+ print(f"Error loading ONNX model: {e}")
52
+ return
53
+ tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH)
54
+ client = QdrantClient(
55
+ url=CLIENT_URL,
56
+ port=CLIENT_PORT,
57
+ grpc_port=CLIENT_GRPC_PORT,
58
+ prefer_grpc=CLIENT_USE_GRPC,
59
+ )
60
+ global FINISHED
61
+ while True:
62
+ try:
63
+ chunk = q.get(False)
64
+ except queue.Empty:
65
+ return
66
+ batch = []
67
+ for c in chunk:
68
+ FINISHED += 1
69
+ # c[0] == id, c[1] == token, we want id to always be associated with the same token across models
70
+ enc = tokenizer(c[1]) # this could've been batched...
71
+ embeddings = session.run(
72
+ None,
73
+ {
74
+ "input_ids": [enc.input_ids],
75
+ "attention_mask": [enc.attention_mask],
76
+ },
77
+ )
78
+ batch.append( # [1][0] == sentence_embedding
79
+ models.PointStruct(id=c[0], vector={"dense": embeddings[1][0]})
80
+ )
81
+ client.batch_update_points(
82
+ collection_name=collection_name,
83
+ update_operations=[models.UpsertOperation(upsert=models.PointsList(points=batch))],
84
+ wait=False,
85
+ )
86
+
87
+
88
+ def init_collection(collection_name: str, model_path: str, datatype: models.Datatype) -> bool:
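+ # creates and populates the collection with one embedding per vocab token (skipped if it already exists), then enables HNSW indexing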
89
+ client = QdrantClient(
90
+ url=CLIENT_URL,
91
+ port=CLIENT_PORT,
92
+ grpc_port=CLIENT_GRPC_PORT,
93
+ prefer_grpc=CLIENT_USE_GRPC,
94
+ )
95
+ if client.collection_exists(collection_name):
96
+ info = client.get_collection(collection_name)
97
+ print(f"Collection '{collection_name}' already exists, skipping init.")
98
+ print(f"{info.points_count} points in collection.")
99
+ return True
100
+ res = client.create_collection(
101
+ collection_name=collection_name,
102
+ vectors_config={
103
+ "dense": models.VectorParams(
104
+ size=EMBEDDING_DIM,
105
+ distance=models.Distance.COSINE,
106
+ on_disk=False,
107
+ datatype=datatype,
108
+ ),
109
+ },
110
+ hnsw_config=models.HnswConfigDiff(m=0), # no index
111
+ on_disk_payload=False,
112
+ )
113
+ if not res:
114
+ print(f"Error creating collection.")
115
+ return False
116
+ else:
117
+ print("Collection created.")
118
+ toks = collect_tokens()
119
+ FINISHED.store(0)
120
+ if toks:
121
+ ids = list(range(len(toks)))
122
+ # align Qdrant IDs with the token for later analysis
123
+ pairs = list(zip(ids, toks))
124
+ # lists of (Qdrant ID, token)
125
+ chunks = [pairs[i : i + BATCH_SIZE] for i in range(0, len(pairs), BATCH_SIZE)]
126
+ q = queue.Queue()
127
+ for c in chunks:
128
+ q.put(c)
129
+ for _ in range(THREADS):
130
+ t = Thread(target=init_worker, args=[q, model_path, collection_name])
131
+ t.start()
132
+ count = 0
133
+ while FINISHED.load() < len(toks):
134
+ time.sleep(0.5)
135
+ count += 1
136
+ if count == 20: # update every 10 seconds or so
137
+ print(f"approximately {q.qsize() * BATCH_SIZE} items left in queue...")
138
+ count = 0
139
+ print(f"Done with collection init, {len(toks)} tokens upserted.")
140
+ # enable indexing
141
+ client.update_collection(collection_name=collection_name, hnsw_config=models.HnswConfigDiff(m=16))
142
+ return True
143
+ else:
144
+ print("Failed to grab tokens from tokenizer.")
145
+ return False
146
+
147
+
148
+ def count_mismatches(list1, list2) -> int:
149
+ count = 0
150
+ assert len(list1) == len(list2)
151
+ for i in range(len(list1)):
152
+ if list1[i] != list2[i]:
153
+ count += 1
154
+ return count
155
+
156
+
157
+ def score_results(
158
+ list1: list,
159
+ list2: list,
160
+ ):
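+ # tallies exact / off-by-N / missing counts into STATS for each top-k in STAT_RANGES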
161
+ assert len(list1) == len(list2)
162
+ global STATS
163
+ for x in STAT_RANGES:
164
+ with STAT_LOCK:
165
+ # STATS = { range: {"exact": AtomicInt, ...} }
166
+ d = STATS.get(x)
167
+ if d is None:
168
+ d = {
169
+ "exact": AtomicInt(0),
170
+ "off_by_1": AtomicInt(0),
171
+ "off_by_2": AtomicInt(0),
172
+ "off_by_3": AtomicInt(0),
173
+ "off_by_4": AtomicInt(0),
174
+ "off_by_5": AtomicInt(0),
175
+ "missing": AtomicInt(0),
176
+ }
177
+ STATS[x] = d
178
+ for i in range(x):
179
+ if list1[i] == list2[i]:
180
+ d["exact"] += 1
181
+ else:
182
+ if list1[i] in list2:
183
+ i2 = list2.index(list1[i])
184
+ val = abs(i2 - i)
185
+ if val == 1:
186
+ d["off_by_1"] += 1
187
+ elif val == 2:
188
+ d["off_by_2"] += 1
189
+ elif val == 3:
190
+ d["off_by_3"] += 1
191
+ elif val == 4:
192
+ d["off_by_4"] += 1
193
+ else:
194
+ d["off_by_5"] += 1
195
+ else:
196
+ d["missing"] += 1
197
+
198
+
199
+ def main_worker(q: queue.Queue, limit: int):
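+ # re-embeds each token with both models, queries both collections, and scores the two result ID lists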
200
+ global FINISHED
201
+ tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH)
202
+ orig_session = rt.InferenceSession(ORIG_MODEL_PATH, providers=["CPUExecutionProvider"])
203
+ compare_session = rt.InferenceSession(COMPARE_MODEL_PATH, providers=["CPUExecutionProvider"])
204
+ client = QdrantClient(
205
+ url=CLIENT_URL,
206
+ port=CLIENT_PORT,
207
+ grpc_port=CLIENT_GRPC_PORT,
208
+ prefer_grpc=CLIENT_USE_GRPC,
209
+ )
210
+ while True:
211
+ try:
212
+ chunk = q.get(False)
213
+ except queue.Empty:
214
+ return
215
+ # each c here is a single token string; we re-embed it with both models and query both collections
216
+ for c in chunk:
217
+ enc = tokenizer(c)
218
+ oe = orig_session.run(
219
+ None,
220
+ {"input_ids": [enc.input_ids], "attention_mask": [enc.attention_mask]},
221
+ )
222
+ ce = compare_session.run(
223
+ None,
224
+ {"input_ids": [enc.input_ids], "attention_mask": [enc.attention_mask]},
225
+ )
226
+ oresult = client.query_points(
227
+ collection_name=ORIG_COLLECTION_NAME,
228
+ using="dense",
229
+ query=oe[1][0],
230
+ limit=limit + 5, # for our scoring metric we want to look slightly past the end
231
+ )
232
+ cresult = client.query_points(
233
+ collection_name=COMPARE_COLLECTION_NAME,
234
+ using="dense",
235
+ query=ce[1][0],
236
+ limit=limit + 5,
237
+ )
238
+ oids = [p.id for p in oresult.points]
239
+ cids = [p.id for p in cresult.points]
240
+ score_results(
241
+ oids,
242
+ cids,
243
+ )
244
+ FINISHED += 1
245
+
246
+
247
+ def main():
248
+ if not init_collection(ORIG_COLLECTION_NAME, ORIG_MODEL_PATH, ORIG_DATATYPE):
249
+ print("Failed to initialize original model values, exiting.")
250
+ return
251
+ if not init_collection(COMPARE_COLLECTION_NAME, COMPARE_MODEL_PATH, COMPARE_DATATYPE):
252
+ print("Failed to initialize secondary model values, exiting.")
253
+ return
254
+ toks = collect_tokens()
255
+ limit = 0
256
+ for x in STAT_RANGES:
257
+ if x > limit:
258
+ limit = x
259
+ FINISHED.store(0)
260
+ if toks:
261
+ chunks = [toks[i : i + BATCH_SIZE] for i in range(0, len(toks), BATCH_SIZE)]
262
+ q = queue.Queue()
263
+ for c in chunks:
264
+ q.put(c)
265
+ print("Starting analysis.")
266
+ for _ in range(THREADS):
267
+ t = Thread(
268
+ target=main_worker,
269
+ args=[q, limit],
270
+ )
271
+ t.start()
272
+ count = 0
273
+ while FINISHED.load() < len(toks):
274
+ time.sleep(0.5)
275
+ count += 1
276
+ if count == 20: # update every 10 seconds or so
277
+ print(f"approximately {q.qsize() * BATCH_SIZE} items left in queue...")
278
+ count = 0
279
+ print(f"Done with analysis.")
280
+ with STAT_LOCK:
281
+ for k, v in STATS.items():
282
+ print(f"Stats for top {k} query results across entire token range:")
283
+ print(f"exact : {(float(v["exact"].load()) / (len(toks) * k)) * 100:.2f}%")
284
+ print(f"off by 1 : {(float(v["off_by_1"].load()) / (len(toks) * k)) * 100:.2f}%")
285
+ print(f"off by 2 : {(float(v["off_by_2"].load()) / (len(toks) * k)) * 100:.2f}%")
286
+ print(f"off by 3 : {(float(v["off_by_3"].load()) / (len(toks) * k)) * 100:.2f}%")
287
+ print(f"off by 4 : {(float(v["off_by_4"].load()) / (len(toks) * k)) * 100:.2f}%")
288
+ print(f"off by 5+: {(float(v["off_by_5"].load()) / (len(toks) * k)) * 100:.2f}%")
289
+ print(f"missing : {(float(v["missing"].load()) / (len(toks) * k)) * 100:.2f}%\n")
290
+
291
+
292
+ if __name__ == "__main__":
+     main()