Ali Kefia
commited on
Commit
·
6505378
1
Parent(s):
4195791
first version
Browse files- .gitattributes +1 -34
- .gitignore +1 -0
- .mise.toml +21 -0
- README.md +90 -0
- data/eval.csv +0 -0
- data/train.csv +0 -0
- embed.py +34 -0
- model/model.pickle +3 -0
- out/confusion_matrix.png +0 -0
- out/preds.csv +45 -0
- out/roc_curve.png +0 -0
- requirements.dev.txt +1 -0
- requirements.txt +6 -0
- ruff.toml +2 -0
- train.py +120 -0
.gitattributes
CHANGED
@@ -1,35 +1,2 @@
|
|
1 |
-
*.
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
*.csv filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
*.pickle filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__pycache__
|
.mise.toml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[env]
|
2 |
+
_.python.venv = ".venv"
|
3 |
+
|
4 |
+
EMBEDDING_MODEL = "Snowflake/snowflake-arctic-embed-xs"
|
5 |
+
EMBEDDING_MODEL_REV = "d8c86521100d3556476a063fc2342036d45c106f"
|
6 |
+
|
7 |
+
DATA_DIR = "{{config_root}}/data"
|
8 |
+
MODEL_DIR = "{{config_root}}/model"
|
9 |
+
OUT_DIR = "{{config_root}}/out"
|
10 |
+
|
11 |
+
[tasks.deps]
|
12 |
+
run = [
|
13 |
+
"uv pip install -r {{config_root}}/requirements.txt",
|
14 |
+
"uv pip install -r {{config_root}}/requirements.dev.txt",
|
15 |
+
]
|
16 |
+
|
17 |
+
[tasks."code:fmt"]
|
18 |
+
run = "ruff format {{config_root}}"
|
19 |
+
|
20 |
+
[tasks."code:lint"]
|
21 |
+
run = "ruff check --fix {{config_root}}"
|
README.md
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 🧠 Article Relevance Classifier (Prototype)
|
2 |
+
|
3 |
+
This project aims to classify news articles as **relevant** (i.e., discussing a new event) or **non-relevant**. The articles are then provided to an LLM pipeline. We should maximize the lowest false positive rate as we don't want the LLMs to be polluted.
|
4 |
+
|
5 |
+
|
6 |
+
## 🧾 Available Features
|
7 |
+
|
8 |
+
For each article, we collect a set of features from both the metadata and the raw content of the web page:
|
9 |
+
|
10 |
+
- **Metadata Title**: The `<title>` tag of the page, often used by browsers and search engines.
|
11 |
+
- **Metadata Description**: The `<meta name="description">` field, typically summarizing the article content.
|
12 |
+
- **Content**: The main textual content of the article, extracted using [trafilatura](https://github.com/adbar/trafilatura).
|
13 |
+
- **Date**: The publication date of the article (when available).
|
14 |
+
- **CSS Title**: A title found in the visible content, usually marked with large or header-style HTML tags (e.g., `<h1>`).
|
15 |
+
- **og:type**: The Open Graph `og:type` property, which often indicates the type of content (e.g., `article`, `video`, `website`).
|
16 |
+
- **Text-to-HTML Ratio**: The ratio between the length of the extracted text and the total HTML size, indicating how content-focused the page is.
|
17 |
+
- **Paragraph Count**: The number of `<p>` tags, giving a rough idea of how much structured text the page contains.
|
18 |
+
- **Link Count**: The total number of hyperlinks (`<a>` tags) on the page.
|
19 |
+
- **Weekday**: The day of the week the article was published, which can help identify publishing patterns.
|
20 |
+
- **Average Link Count of the Website**: The average number of hyperlinks per page across the entire website domain. This helps differentiate content-heavy domains from link-heavy or index-style sites.
|
21 |
+
|
22 |
+
➡️ The **Link Count** feature becomes more meaningful when **compared to the Average Link Count of the Website**. For example, a page with very few links on a generally link-heavy site may indicate that it is an article rather than a hub or landing page.
|
23 |
+
|
24 |
+
⚠️ However, **computing the Average Link Count of the Website requires crawling multiple pages of the same domain**, which is not feasible in a real-time prediction setting (i.e., when you want to classify a single article instantly). For this reason, features like **Average Link Count of the Website** can only be used during offline training and are not available at inference time.
|
25 |
+
|
26 |
+
## 🔍 Approach
|
27 |
+
|
28 |
+
For this first prototype, I decided to use **text embeddings** to semantically represent each article. These embeddings are then used to train a binary classifier.
|
29 |
+
|
30 |
+
Each article is represented using three components:
|
31 |
+
- The **title** (from metadata, up to 512 tokens),
|
32 |
+
- The **description** (from metadata, up to 512 tokens),
|
33 |
+
- The **main content** (extracted using [trafilatura](https://github.com/adbar/trafilatura), up to 512 tokens).
|
34 |
+
|
35 |
+
⚠️ Due to token limits, only the beginning of each text field is used. This may affect classification performance when relevant information appears later in the article.
|
36 |
+
|
37 |
+
For the classifier, I chose a **Support Vector Machine (SVM)** model because:
|
38 |
+
- **K-Nearest Neighbors (KNN)** is too slow at inference time due to the high dimensionality (512 × 3 features),
|
39 |
+
- **Random Forests** risk overfitting when dealing with a large number of features,
|
40 |
+
- **Logistic Regression** is a viable alternative, but SVMs generally perform better on high-dimensional, sparse datasets.
|
41 |
+
|
42 |
+
### 📊 Results (Test 1)
|
43 |
+
|
44 |
+
Below is the confusion matrix for the first test:
|
45 |
+
|
46 |
+

|
47 |
+
|
48 |
+
- **Accuracy**: (15 + 18) / (15 + 7 + 4 + 18) = **72.5%**
|
49 |
+
- **Precision (Relevant)**: 18 / (18 + 7) = **72.0%**
|
50 |
+
- **Recall (Relevant)**: 18 / (18 + 4) = **81.8%**
|
51 |
+
|
52 |
+
These initial results suggest the model can already capture some meaningful signals, although there is room for improvement, especially in reducing false positives.
|
53 |
+
|
54 |
+
## 🧪 Second Test: Chunked Embeddings with Averaging
|
55 |
+
|
56 |
+
To address the limitations of the first test, a second experiment was conducted using **chunked embeddings**:
|
57 |
+
|
58 |
+
- Instead of truncating the text at 512 tokens, each field (title, description, content) is **split into multiple chunks** of up to 512 tokens.
|
59 |
+
- Each chunk is embedded separately.
|
60 |
+
- The final representation is the **average of all the chunk embeddings**.
|
61 |
+
|
62 |
+
This method allows the model to consider a **broader portion of the article**, potentially capturing relevant information that appears later in the text.
|
63 |
+
|
64 |
+
The goal of this second test is to evaluate whether this approach improves classification performance compared to the truncated version.
|
65 |
+
|
66 |
+
### 📊 Results (Test 2)
|
67 |
+
|
68 |
+
Below is the confusion matrix for the second test:
|
69 |
+
|
70 |
+

|
71 |
+
|
72 |
+
- **Accuracy**: (16 + 18) / (16 + 6 + 4 + 18) = **77.3%**
|
73 |
+
- **Precision (Relevant)**: 18 / (18 + 6) = **75.0%**
|
74 |
+
- **Recall (Relevant)**: 18 / (18 + 4) = **81.8%**
|
75 |
+
|
76 |
+
Compared to Test 1, this version shows a slight improvement in both **accuracy** and **precision**, indicating that including more of the article content via chunked embeddings helps reduce false positives and better capture relevant information.
|
77 |
+
|
78 |
+
#### 📈 ROC Curve
|
79 |
+
|
80 |
+
For this test, I also generated the ROC curve:
|
81 |
+
|
82 |
+

|
83 |
+
|
84 |
+
The curve appears to have a stepped shape, which is expected due to the **limited number of test samples**. As a result, it’s difficult to draw strong conclusions from the ROC curve alone.
|
85 |
+
|
86 |
+
However, we may tentatively observe that **lowering the decision threshold** could help reduce false positives — a promising direction to explore in future experiments with more data.
|
87 |
+
|
88 |
+
#### Analysis of the results
|
89 |
+
|
90 |
+
After analyzing the results, it seems that the model has difficulty distinguishing between the content types of articles, specifically whether they are news or not. However, it excels at identifying the structural layout of pages, such as determining if a page is a homepage, article, video, etc. Therefore, adding features like og:type, text-to-HTML ratio, and paragraph count may not be beneficial, as these features are primarily useful for differentiating page structure rather than content type.
|
data/eval.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/train.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
embed.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from functools import cache
|
3 |
+
from itertools import batched
|
4 |
+
from typing import Generator, Iterator
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from numpy.typing import NDArray
|
8 |
+
from sentence_transformers import SentenceTransformer
|
9 |
+
|
10 |
+
|
11 |
+
def split(text: str, max_tokens: int = 512) -> Generator[str, None, None]:
|
12 |
+
# Naive approach - use opale internal chunking techniques (special tokens count)
|
13 |
+
words = text.split()
|
14 |
+
if not words:
|
15 |
+
return
|
16 |
+
for batch in batched(words, max_tokens // 2): # Assuming 2 tokens per word
|
17 |
+
yield " ".join(batch)
|
18 |
+
|
19 |
+
|
20 |
+
@cache
|
21 |
+
def get_model():
|
22 |
+
return SentenceTransformer(
|
23 |
+
os.environ["EMBEDDING_MODEL"], revision=os.environ["EMBEDDING_MODEL_REV"]
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
def embed(texts: Iterator[str], max_tokens: int = 512) -> NDArray:
|
28 |
+
res: list[NDArray] = []
|
29 |
+
for text in texts:
|
30 |
+
embeddings = get_model().encode(
|
31 |
+
list(split(text, max_tokens)), show_progress_bar=False
|
32 |
+
)
|
33 |
+
res.append(np.mean(embeddings, axis=0))
|
34 |
+
return np.array(res)
|
model/model.pickle
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:eef2e93850a38b0e98b68ac0e0a32f4a67408e7327a7646ed9e96019c3dc7583
|
3 |
+
size 5293480
|
out/confusion_matrix.png
ADDED
![]() |
out/preds.csv
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
url,is_news_article,prediction,is_prediction_correct
|
2 |
+
https://quantumcomputingreport.com/quandela-launches-belenos-photonic-quantum-computer-with-doubling-of-qubit-count-and-4000x-power-increase/,true,true,true
|
3 |
+
https://www.nqcc.ac.uk/,false,false,true
|
4 |
+
https://quantumcomputingreport.com/qsensato-raises-e500k-560k-usd-to-advance-integrated-atomic-quantum-sensors-for-precision-sensing/,true,true,true
|
5 |
+
https://quantumcomputingreport.com/zurich-instruments-and-rohde-schwarz-join-australias-national-quantum-computing-testbed-facility/,true,true,true
|
6 |
+
https://quantumcomputingreport.com/hbku-launches-qatars-first-quantum-computing-laboratory-backed-by-10m-mod-grant/,true,true,true
|
7 |
+
https://quantumcomputingreport.com/quantinuum-releases-%ce%bbambeq-gen-ii-for-scalable-interpretable-quantum-nlp/,true,false,false
|
8 |
+
https://quantumcomputingreport.com/quobly-secures-e21m-23-7m-usd-to-industrialize-100-qubit-silicon-quantum-processor/,true,true,true
|
9 |
+
https://quantumcomputingreport.com/semiqon-and-nanoacademic-partner-to-advance-silicon-spin-qubit-research-and-education/,true,true,true
|
10 |
+
https://quantumcomputingreport.com/united-nations-itu-launches-quantum-for-good-to-align-innovation-with-global-impact/,true,false,false
|
11 |
+
https://quantumcomputingreport.com/microsoft-adds-post-quantum-cryptography-to-windows-insider-builds-and-linux/,true,true,true
|
12 |
+
https://www.nqcc.ac.uk/technology-and-research/our-research/,false,false,true
|
13 |
+
https://quantumcomputingreport.com/podcast-with-scott-davis-ceo-and-co-founder-of-vescent/,false,false,true
|
14 |
+
https://quantumzeitgeist.com/building-atoms-the-rise-of-nanotechnology-and-molecular-engineering/,false,true,false
|
15 |
+
https://quantumzeitgeist.com/networked-services-technologies-applications-and-challenges-for-advanced-communication/,false,false,true
|
16 |
+
https://quantumzeitgeist.com/amazon-braket-sdk-and-multi-platform-quantum-development/,false,true,false
|
17 |
+
https://quantumzeitgeist.com/pennylane-and-quantum-machine-learning/,false,false,true
|
18 |
+
https://quantumzeitgeist.com/quantum-physics-meets-spiritual-philosophy-exploring-the-intersection-of-string-theory-and-consciousness/,false,false,true
|
19 |
+
https://quantumzeitgeist.com/quantum-computing-transforms-financial-derivatives-pricing-for-complex-options-and-risk-analysis/,false,true,false
|
20 |
+
https://quantumzeitgeist.com/quantifying-quantum-correlations-in-symmetric-gaussian-states-with-universal-invariants/,true,false,false
|
21 |
+
https://www.horseandhound.co.uk/news/horse-life-threatening-stomach-tumour-saved-pioneering-surgery-894298,true,true,true
|
22 |
+
https://www.maddyness.com/2025/06/02/vivatech-startups-deals-annonces-ce-que-la-mission-french-tech-prevoit-pour-levenement/,false,false,true
|
23 |
+
https://www.cbsnews.com/sanfrancisco/news/padel-a-fast-growing-sport-has-become-a-new-obsession-for-silicon-valley/,false,true,false
|
24 |
+
https://www.cloudcomputing-news.net/news/microsoft-launches-its-first-cloud-region-in-malaysia/,true,true,true
|
25 |
+
https://padelmagazine.fr/best-padel-racket-awards-2025-les-meilleures-raquettes-de-lannee-devoilees/,false,false,true
|
26 |
+
https://www.horseandhound.co.uk/news/polly-dickson-obituary-894506,true,true,true
|
27 |
+
https://www.homeselect.paris/en/blog/devenir-proprietaire,false,false,true
|
28 |
+
https://www.maddyness.com/2020/10/23/salomon-aiach-interview-facebook-startups/,false,false,true
|
29 |
+
https://www.solarpowerportal.co.uk/grid-operators-must-work-together-in-aftermath-of-spain-and-portugal-blackout/,false,true,false
|
30 |
+
https://www.cloudcomputing-news.net/news/podcast/nginx-f5-api-proxy-podcast-apac-sprint-two-point-one-podcast-s02-e30/,false,false,true
|
31 |
+
https://www.farminguk.com/news/vegan-activists-attempt-to-shut-down-royal-highland-parade_66662.html,true,true,true
|
32 |
+
https://dairynews.today/news/world_milk_day_2025_health_innovation_and_sustainability_drive_india_s_milk_movement_9339211.html,false,true,false
|
33 |
+
"https://lerail.com/news/95810-signature-du-second-appel-%C3%A0-projets-gares-de-demain-entre-la-r%C3%A9gion-%C3%AEle-de-france,-%C3%AEle-de-france-mobilit%C3%A9s-et-sncf-gares-connexions",true,false,false
|
34 |
+
https://lerail.com/news/95984-drive-to-zero-2025,false,false,true
|
35 |
+
https://www.horseandhound.co.uk/news/farewell-to-twinshock-warrior-894106,true,true,true
|
36 |
+
https://www.farminguk.com/news/new-ai-driven-test-targets-silent-killer-in-uk-cattle_66604.html,true,true,true
|
37 |
+
https://www.maddyness.com/2019/05/02/growthhacking-chahab-nastar-scaleups/,false,false,true
|
38 |
+
https://www.businesstravelnews.com/Lodging/Hyatt-Creates-New-Unscripted-Collection-Brand,true,false,false
|
39 |
+
https://meuble-info.fr/falmec-gessi-le-duo-gagnant-du-point-deau/,true,false,false
|
40 |
+
https://www.cloudcomputing-news.net/news/podcast/supply-chain-automation-warehousing-distribution-rpa-best-dematic-podcast-s03-e10/,false,false,true
|
41 |
+
https://www.maddyness.com/2025/05/06/mon-petit-placement-tombe-dans-le-giron-de-malakoff-humanis/,true,false,false
|
42 |
+
https://lerail.com/technical-articles/79770-southco-s%C3%A9curisation-du-v%C3%A9hicule-%C3%A9lectrique-infrastructure-de-recharge-et-de-stockage-sur-batterie-de-r%C3%A9seau,false,false,true
|
43 |
+
https://www.watches-news.com/alpine-eagle-41-xp-cs-platinum/,true,true,true
|
44 |
+
https://www.imarcgroup.com/football-market,false,true,false
|
45 |
+
https://www.constructionnews.co.uk/contractors/balfour-beatty/balfour-beatty-court-battle-over-serious-trucks-cartel-ends-17-01-2025/,true,true,true
|
out/roc_curve.png
ADDED
![]() |
requirements.dev.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
python-lsp-server
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
matplotlib
|
2 |
+
numpy
|
3 |
+
polars
|
4 |
+
scikit-learn
|
5 |
+
seaborn
|
6 |
+
sentence-transformers
|
ruff.toml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
[lint]
|
2 |
+
extend-select = ["I"]
|
train.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import pickle
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import numpy as np
|
8 |
+
import polars as pl
|
9 |
+
import seaborn as sns
|
10 |
+
from numpy.typing import NDArray
|
11 |
+
from polars import DataFrame
|
12 |
+
from sklearn.metrics import auc, confusion_matrix, roc_curve
|
13 |
+
from sklearn.svm import SVC
|
14 |
+
|
15 |
+
from embed import embed as _embed
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
logging.basicConfig(level=logging.INFO)
|
20 |
+
|
21 |
+
DATA = Path(os.environ["DATA_DIR"])
|
22 |
+
DATA.mkdir(parents=True, exist_ok=True)
|
23 |
+
MODEL = Path(os.environ["MODEL_DIR"])
|
24 |
+
MODEL.mkdir(parents=True, exist_ok=True)
|
25 |
+
OUT = Path(os.environ["OUT_DIR"])
|
26 |
+
OUT.mkdir(parents=True, exist_ok=True)
|
27 |
+
|
28 |
+
|
29 |
+
def embed(df: DataFrame):
|
30 |
+
logger.info(f"embed start {df.height}")
|
31 |
+
features = ["content", "meta_title", "meta_description"]
|
32 |
+
embeddings = []
|
33 |
+
for col in features:
|
34 |
+
train_texts = df.select(col).to_series().to_list()
|
35 |
+
embeddings.append(_embed(train_texts))
|
36 |
+
res = np.hstack(embeddings)
|
37 |
+
logger.info(f"embed done {res.shape}")
|
38 |
+
return res
|
39 |
+
|
40 |
+
|
41 |
+
def train(df: DataFrame, target: str):
|
42 |
+
logger.info(f"train start {df.height}")
|
43 |
+
X = embed(df)
|
44 |
+
y = df.select(target).to_numpy().ravel()
|
45 |
+
clf = SVC(kernel="linear", probability=True)
|
46 |
+
clf.fit(X, y)
|
47 |
+
logger.info("train done")
|
48 |
+
return clf
|
49 |
+
|
50 |
+
|
51 |
+
def save_prediction(eval_df: DataFrame, y_eval: NDArray, y_pred: NDArray) -> None:
|
52 |
+
pl.DataFrame(
|
53 |
+
{
|
54 |
+
"url": eval_df.select("url").to_series().to_list(),
|
55 |
+
"is_news_article": y_eval,
|
56 |
+
"prediction": y_pred,
|
57 |
+
"is_prediction_correct": y_eval == y_pred,
|
58 |
+
}
|
59 |
+
).write_csv(OUT / "preds.csv")
|
60 |
+
|
61 |
+
|
62 |
+
def save_roc_curve(clf, X: NDArray, y: NDArray):
|
63 |
+
probs = clf.predict_proba(X)[:, 1] # Probability for the positive class
|
64 |
+
fpr, tpr, thresholds = roc_curve(y, probs)
|
65 |
+
roc_auc = auc(fpr, tpr)
|
66 |
+
|
67 |
+
plt.figure(figsize=(6, 5))
|
68 |
+
plt.plot(
|
69 |
+
fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})"
|
70 |
+
)
|
71 |
+
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
|
72 |
+
plt.xlim([0.0, 1.0])
|
73 |
+
plt.ylim([0.0, 1.05])
|
74 |
+
plt.xlabel("False Positive Rate")
|
75 |
+
plt.ylabel("True Positive Rate")
|
76 |
+
plt.title("Receiver Operating Characteristic (ROC)")
|
77 |
+
plt.legend(loc="lower right")
|
78 |
+
plt.tight_layout()
|
79 |
+
plt.savefig(OUT / "roc_curve.png")
|
80 |
+
plt.close()
|
81 |
+
|
82 |
+
|
83 |
+
def save_confusion_matrix(y: NDArray, pred: NDArray):
|
84 |
+
plt.figure(figsize=(5, 4))
|
85 |
+
sns.heatmap(
|
86 |
+
confusion_matrix(y, pred),
|
87 |
+
annot=True,
|
88 |
+
fmt="d",
|
89 |
+
cmap="Blues",
|
90 |
+
xticklabels=["Not Relevant", "Relevant"],
|
91 |
+
yticklabels=["Not Relevant", "Relevant"],
|
92 |
+
)
|
93 |
+
plt.xlabel("Predicted")
|
94 |
+
plt.ylabel("Actual")
|
95 |
+
plt.title("Confusion Matrix")
|
96 |
+
plt.tight_layout()
|
97 |
+
plt.savefig(OUT / "confusion_matrix.png")
|
98 |
+
plt.close()
|
99 |
+
|
100 |
+
|
101 |
+
def main() -> None:
|
102 |
+
target = "is_news_article"
|
103 |
+
train_df = pl.read_csv(DATA / "train.csv")
|
104 |
+
clf = train(train_df, target)
|
105 |
+
with open(MODEL / "model.pickle", "wb") as f:
|
106 |
+
pickle.dump(clf, f)
|
107 |
+
|
108 |
+
eval_df = pl.read_csv(DATA / "eval.csv")
|
109 |
+
logger.info(f"eval start {eval_df.height}")
|
110 |
+
eval_X = embed(eval_df)
|
111 |
+
eval_y = eval_df.select(target).to_numpy().ravel()
|
112 |
+
eval_pred = clf.predict(eval_X)
|
113 |
+
save_prediction(eval_df, eval_y, eval_pred)
|
114 |
+
save_confusion_matrix(eval_y, eval_pred)
|
115 |
+
save_roc_curve(clf, eval_X, eval_y)
|
116 |
+
logger.info("eval done")
|
117 |
+
|
118 |
+
|
119 |
+
if __name__ == "__main__":
|
120 |
+
main()
|