# Compare two ONNX embedding models by embedding every token in the
# tokenizer vocabulary into two Qdrant collections, querying both with the
# same inputs, and scoring how far the ranked results diverge.

import json
import queue
import time
from threading import Lock, Thread

import numpy as np
import onnxruntime as rt
import transformers
from pyatomix import AtomicInt
from qdrant_client import QdrantClient, models

TOKENIZER_PATH = "."
ORIG_MODEL_PATH = "model_uint8.onnx"
ORIG_DATATYPE = models.Datatype.FLOAT32
ORIG_COLLECTION_NAME = "baseline"
COMPARE_MODEL_PATH = "snowflake2_m_uint8.onnx"
COMPARE_DATATYPE = models.Datatype.UINT8
COMPARE_COLLECTION_NAME = "compare"
EMBEDDING_DIM = 768

# Top-k result windows to score: top 10, top 20, and top 50.
STAT_RANGES = [10, 20, 50]
STATS = {}
STAT_LOCK = Lock()
BATCH_SIZE = 1000
THREADS = 8

CLIENT_URL = "http://127.0.0.1"
CLIENT_PORT = 6333
CLIENT_GRPC_PORT = 6334
CLIENT_USE_GRPC = True

# Progress counter shared by all worker threads.
FINISHED = AtomicInt(0)


def collect_tokens() -> list[str] | None:
    """Read the vocabulary out of tokenizer.json and return the token strings."""
    print("Attempting to grab tokens from tokenizer...")
    with open(f"{TOKENIZER_PATH}/tokenizer.json", "r") as f:
        j = json.load(f)
    # Assumes a Unigram-style vocab: a list of [token, score] pairs.
    v = j["model"]["vocab"]
    toks = [x[0] for x in v]
    print(f"Found {len(toks)} tokens.")
    return toks
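
# For reference (assumed shape, abridged): a Unigram tokenizer.json looks like
#   {"model": {"type": "Unigram", "vocab": [["<s>", 0.0], ["the", -3.2], ...]}}
# which is why each token is read from index 0 of its vocab entry.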


def init_worker(q: queue.Queue, model_path: str, collection_name: str):
    """Embed queued (id, token) pairs and upsert them into a collection."""
    try:
        session = rt.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    except Exception as e:
        print(f"Error loading ONNX model: {e}")
        return
    tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH)
    client = QdrantClient(
        url=CLIENT_URL,
        port=CLIENT_PORT,
        grpc_port=CLIENT_GRPC_PORT,
        prefer_grpc=CLIENT_USE_GRPC,
    )
    global FINISHED
    while True:
        try:
            chunk = q.get_nowait()
        except queue.Empty:
            return
        batch = []
        for c in chunk:
            enc = tokenizer(c[1])
            embeddings = session.run(
                None,
                {
                    "input_ids": np.array([enc.input_ids], dtype=np.int64),
                    "attention_mask": np.array([enc.attention_mask], dtype=np.int64),
                },
            )
            batch.append(
                models.PointStruct(id=c[0], vector={"dense": embeddings[1][0]})
            )
            FINISHED += 1
        client.batch_update_points(
            collection_name=collection_name,
            update_operations=[models.UpsertOperation(upsert=models.PointsList(points=batch))],
            wait=False,
        )
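
# Assumed model interface: two outputs, where outputs[0] is the per-token
# hidden state of shape (batch, seq_len, EMBEDDING_DIM) and outputs[1] is the
# pooled sentence embedding of shape (batch, EMBEDDING_DIM); embeddings[1][0]
# above (and oe[1][0]/ce[1][0] below) pick the pooled vector for item 0.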


def init_collection(collection_name: str, model_path: str, datatype: models.Datatype) -> bool:
    """Create a collection and fill it with one embedding per vocabulary token."""
    client = QdrantClient(
        url=CLIENT_URL,
        port=CLIENT_PORT,
        grpc_port=CLIENT_GRPC_PORT,
        prefer_grpc=CLIENT_USE_GRPC,
    )
    if client.collection_exists(collection_name):
        info = client.get_collection(collection_name)
        print(f"Collection '{collection_name}' already exists, skipping init.")
        print(f"{info.points_count} points in collection.")
        return True
    res = client.create_collection(
        collection_name=collection_name,
        vectors_config={
            "dense": models.VectorParams(
                size=EMBEDDING_DIM,
                distance=models.Distance.COSINE,
                on_disk=False,
                datatype=datatype,
            ),
        },
        # m=0 disables HNSW index construction during the bulk upsert.
        hnsw_config=models.HnswConfigDiff(m=0),
        on_disk_payload=False,
    )
    if not res:
        print("Error creating collection.")
        return False
    print("Collection created.")
    toks = collect_tokens()
    if not toks:
        print("Failed to grab tokens from tokenizer.")
        return False
    FINISHED.store(0)
    pairs = list(enumerate(toks))
    chunks = [pairs[i : i + BATCH_SIZE] for i in range(0, len(pairs), BATCH_SIZE)]
    q = queue.Queue()
    for c in chunks:
        q.put(c)
    threads = [Thread(target=init_worker, args=[q, model_path, collection_name]) for _ in range(THREADS)]
    for t in threads:
        t.start()
    count = 0
    while FINISHED.load() < len(toks):
        time.sleep(0.5)
        count += 1
        if count == 20:
            print(f"approximately {q.qsize() * BATCH_SIZE} items left in queue...")
            count = 0
    # Let the workers flush their final batches before re-enabling indexing.
    for t in threads:
        t.join()
    print(f"Done with collection init, {len(toks)} tokens upserted.")
    client.update_collection(collection_name=collection_name, hnsw_config=models.HnswConfigDiff(m=16))
    return True
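
# Creating the collection with m=0 and raising it afterwards defers HNSW graph
# construction until the bulk load is done, the pattern Qdrant's documentation
# suggests for large initial uploads; update_collection(m=16) then triggers
# indexing of all upserted points.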


def count_mismatches(list1, list2) -> int:
    """Count positions at which the two lists disagree (helper, currently unused)."""
    assert len(list1) == len(list2)
    return sum(1 for a, b in zip(list1, list2) if a != b)
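
# Example: count_mismatches([1, 2, 3], [1, 9, 3]) == 1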


def score_results(list1: list, list2: list):
    """Accumulate positional-agreement stats between two ranked ID lists."""
    assert len(list1) == len(list2)
    global STATS
    for x in STAT_RANGES:
        with STAT_LOCK:
            d = STATS.get(x)
            if d is None:
                d = {
                    "exact": AtomicInt(0),
                    "off_by_1": AtomicInt(0),
                    "off_by_2": AtomicInt(0),
                    "off_by_3": AtomicInt(0),
                    "off_by_4": AtomicInt(0),
                    # "off_by_5" really means off by 5 or more.
                    "off_by_5": AtomicInt(0),
                    "missing": AtomicInt(0),
                }
                STATS[x] = d
            for i in range(x):
                if list1[i] == list2[i]:
                    d["exact"] += 1
                elif list1[i] in list2:
                    # How far the baseline result drifted in the compare list.
                    val = abs(list2.index(list1[i]) - i)
                    if val == 1:
                        d["off_by_1"] += 1
                    elif val == 2:
                        d["off_by_2"] += 1
                    elif val == 3:
                        d["off_by_3"] += 1
                    elif val == 4:
                        d["off_by_4"] += 1
                    else:
                        d["off_by_5"] += 1
                else:
                    d["missing"] += 1
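
# Worked example: with list1 = [5, 7, 9] and list2 = [5, 9, 7] scored at
# x = 3, ID 5 matches exactly, while 7 and 9 each appear one position away,
# so "exact" is incremented once and "off_by_1" twice.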


def main_worker(q: queue.Queue, limit: int):
    """Query both collections with each token and score the ranked results."""
    global FINISHED
    tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH)
    orig_session = rt.InferenceSession(ORIG_MODEL_PATH, providers=["CPUExecutionProvider"])
    compare_session = rt.InferenceSession(COMPARE_MODEL_PATH, providers=["CPUExecutionProvider"])
    client = QdrantClient(
        url=CLIENT_URL,
        port=CLIENT_PORT,
        grpc_port=CLIENT_GRPC_PORT,
        prefer_grpc=CLIENT_USE_GRPC,
    )
    while True:
        try:
            chunk = q.get_nowait()
        except queue.Empty:
            return
        for c in chunk:
            enc = tokenizer(c)
            inputs = {
                "input_ids": np.array([enc.input_ids], dtype=np.int64),
                "attention_mask": np.array([enc.attention_mask], dtype=np.int64),
            }
            oe = orig_session.run(None, inputs)
            ce = compare_session.run(None, inputs)
            # Fetch a few extra results so near-misses at the tail of the
            # top-k window can still be located in the compare list.
            oresult = client.query_points(
                collection_name=ORIG_COLLECTION_NAME,
                using="dense",
                query=oe[1][0],
                limit=limit + 5,
            )
            cresult = client.query_points(
                collection_name=COMPARE_COLLECTION_NAME,
                using="dense",
                query=ce[1][0],
                limit=limit + 5,
            )
            oids = [p.id for p in oresult.points]
            cids = [p.id for p in cresult.points]
            score_results(oids, cids)
            FINISHED += 1
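
# query_points returns points ordered by similarity, so oids and cids are
# ranked ID lists and the positional comparison in score_results is meaningful.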


def main():
    if not init_collection(ORIG_COLLECTION_NAME, ORIG_MODEL_PATH, ORIG_DATATYPE):
        print("Failed to initialize original model values, exiting.")
        return
    if not init_collection(COMPARE_COLLECTION_NAME, COMPARE_MODEL_PATH, COMPARE_DATATYPE):
        print("Failed to initialize secondary model values, exiting.")
        return
    toks = collect_tokens()
    if not toks:
        print("Failed to grab tokens from tokenizer, exiting.")
        return
    # Query deep enough to cover the largest stat window.
    limit = max(STAT_RANGES)
    FINISHED.store(0)
    chunks = [toks[i : i + BATCH_SIZE] for i in range(0, len(toks), BATCH_SIZE)]
    q = queue.Queue()
    for c in chunks:
        q.put(c)
    print("Starting analysis.")
    threads = [Thread(target=main_worker, args=[q, limit]) for _ in range(THREADS)]
    for t in threads:
        t.start()
    count = 0
    while FINISHED.load() < len(toks):
        time.sleep(0.5)
        count += 1
        if count == 20:
            print(f"approximately {q.qsize() * BATCH_SIZE} items left in queue...")
            count = 0
    for t in threads:
        t.join()
    print("Done with analysis.")
    with STAT_LOCK:
        for k, v in STATS.items():
            total = len(toks) * k
            print(f"Stats for top {k} query results across entire token range:")
            print(f"exact    : {v['exact'].load() / total * 100:.2f}%")
            print(f"off by 1 : {v['off_by_1'].load() / total * 100:.2f}%")
            print(f"off by 2 : {v['off_by_2'].load() / total * 100:.2f}%")
            print(f"off by 3 : {v['off_by_3'].load() / total * 100:.2f}%")
            print(f"off by 4 : {v['off_by_4'].load() / total * 100:.2f}%")
            print(f"off by 5+: {v['off_by_5'].load() / total * 100:.2f}%")
            print(f"missing  : {v['missing'].load() / total * 100:.2f}%\n")
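
# Each block reports what share of the top-k positions matched exactly,
# drifted by n places, or went missing from the compare model's results;
# the seven percentages for a given k sum to 100% (up to rounding).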


if __name__ == "__main__":
    main()