# news-classifier / train.py
# Ali Kefia
# linear -> poly
# f2f47ac
# raw
# history blame
# 2.13 kB
import logging
import pickle
import matplotlib.pyplot as plt
import polars as pl
import seaborn as sns
from numpy.typing import NDArray
from sklearn.metrics import auc, confusion_matrix, roc_curve
from sklearn.svm import SVC
from utils.paths import DATA, IMGS, MODEL
# Configure the root logger at import time so INFO-level records are emitted.
logging.basicConfig(level=logging.INFO)
def save_roc_curve(clf, X: NDArray, y: NDArray) -> None:
    """Plot the ROC curve of ``clf`` on ``(X, y)`` and save it as a PNG.

    The image is written to ``IMGS / "roc_curve.png"``.

    Args:
        clf: Fitted classifier exposing ``predict_proba``; column 1 of its
            output is used as the positive-class score.
        X: Samples passed to ``clf.predict_proba``.
        y: Ground-truth labels passed to ``roc_curve`` alongside the scores.
    """
    probs = clf.predict_proba(X)[:, 1]  # Probability for the positive class
    # roc_curve also returns the decision thresholds; the plot does not use them.
    fpr, tpr, _ = roc_curve(y, probs)
    roc_auc = auc(fpr, tpr)

    # Work on an explicit figure/axes pair rather than pyplot's implicit
    # "current figure", and always release the figure, even if plotting or
    # saving raises.
    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        ax.plot(
            fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})"
        )
        # Diagonal reference line (TPR == FPR).
        ax.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.set_title("Receiver Operating Characteristic (ROC)")
        ax.legend(loc="lower right")
        fig.tight_layout()
        fig.savefig(IMGS / "roc_curve.png")
    finally:
        plt.close(fig)
def save_confusion_matrix(y: NDArray, pred: NDArray) -> None:
    """Render the confusion matrix of ``pred`` against ``y`` and save it as a PNG.

    The image is written to ``IMGS / "confusion_matrix.png"``.

    Args:
        y: Ground-truth labels.
        pred: Predicted labels, aligned with ``y``.
    """
    class_names = ["Not Relevant", "Relevant"]

    # Draw on an explicit figure/axes pair rather than pyplot's implicit
    # "current figure", and always release the figure, even if plotting or
    # saving raises.
    fig, ax = plt.subplots(figsize=(5, 4))
    try:
        sns.heatmap(
            confusion_matrix(y, pred),
            annot=True,
            fmt="d",  # cell annotations are integer counts
            cmap="Blues",
            xticklabels=class_names,
            yticklabels=class_names,
            ax=ax,
        )
        ax.set_xlabel("Predicted")
        ax.set_ylabel("Actual")
        ax.set_title("Confusion Matrix")
        fig.tight_layout()
        fig.savefig(IMGS / "confusion_matrix.png")
    finally:
        plt.close(fig)
def main() -> None:
    """Fit the SVM on the training split, pickle it, and plot eval metrics.

    Reads ``train.parquet`` and ``eval.parquet`` from ``DATA``, writes the
    fitted model to ``MODEL / "model.pickle"``, then saves the confusion
    matrix and the ROC curve computed on the evaluation split.
    """
    feature_col, label_col = "embeds", "is_news"

    # --- training ---
    training = pl.read_parquet(DATA / "train.parquet")
    model = SVC(kernel="poly", probability=True)
    features = training.get_column(feature_col).to_numpy()
    labels = training.get_column(label_col).to_numpy()
    model.fit(features, labels)

    # --- persistence ---
    with open(MODEL / "model.pickle", "wb") as sink:
        pickle.dump(model, sink)

    # --- evaluation ---
    held_out = pl.read_parquet(DATA / "eval.parquet")
    held_out_X = held_out.get_column(feature_col).to_numpy()
    held_out_y = held_out.get_column(label_col).to_numpy()
    predictions = model.predict(held_out_X)

    save_confusion_matrix(held_out_y, predictions)
    save_roc_curve(model, held_out_X, held_out_y)
# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()