Ali Kefia commited on
Commit
231da5b
·
1 Parent(s): 8abf5b9
Files changed (1) hide show
  1. usage.py +47 -0
usage.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ from functools import cache
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import polars as pl
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ from embed import embed
11
+
12
+ DATA = Path(os.environ["DATA_DIR"])
13
+
14
+ features = ["content", "meta_title", "meta_description"]
15
+
16
+
17
+ @cache
18
+ def get_model():
19
+ file_name = hf_hub_download("opale-ai/news-classifier", "model/model.pickle")
20
+ with open(file_name, "rb") as f:
21
+ return pickle.load(f)
22
+
23
+
24
+ def record_get():
25
+ df = pl.read_csv(DATA / "eval.csv")
26
+ return {col: val for col, val in zip(df.columns, df.sample().row(0))}
27
+
28
+
29
+ def record_embed(rec):
30
+ embeddings = []
31
+ for f in features:
32
+ embeddings.append(embed([rec[f]]))
33
+ return np.hstack(embeddings)
34
+
35
+
36
+ def main():
37
+ model = get_model()
38
+ record = record_get()
39
+ embeds = record_embed(record)
40
+ (pred,) = model.predict(embeds)
41
+ print(record["content"])
42
+ print(f"is news (real): {record['is_news_article']}")
43
+ print(f"is news (pred): {pred}")
44
+
45
+
46
+ if __name__ == "__main__":
47
+ main()