"""Train an SVC on precomputed embeddings and save evaluation artifacts.

Fits a polynomial-kernel SVC on ``DATA / "train.parquet"``, pickles it to
``MODEL / "model.pickle"``, then writes a confusion matrix and a ROC curve
for ``DATA / "eval.parquet"`` into ``IMGS``.
"""

import logging
import pickle
from typing import Optional

import matplotlib.pyplot as plt
import polars as pl
import seaborn as sns
from numpy.typing import NDArray
from sklearn.metrics import auc, confusion_matrix, roc_curve
from sklearn.svm import SVC

from utils.paths import DATA, IMGS, MODEL

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def save_roc_curve(clf, X: NDArray, y: NDArray) -> None:
    """Plot the ROC curve of *clf* on ``(X, y)`` and save it as a PNG.

    The figure is written to ``IMGS / "roc_curve.png"``.

    Args:
        clf: Fitted classifier exposing ``predict_proba``.
        X: Feature matrix passed to ``clf.predict_proba``.
        y: True binary labels for the rows of ``X``.
    """
    # Column 1 of predict_proba is the probability of the second class in
    # clf.classes_, which is treated as the positive class here.
    probs = clf.predict_proba(X)[:, 1]
    fpr, tpr, _ = roc_curve(y, probs)
    roc_auc = auc(fpr, tpr)

    # Make sure the output directory exists before spending time on plotting.
    IMGS.mkdir(parents=True, exist_ok=True)
    out_path = IMGS / "roc_curve.png"

    plt.figure(figsize=(6, 5))
    plt.plot(
        fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})"
    )
    # Diagonal reference line: the expected curve of a random classifier.
    plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Receiver Operating Characteristic (ROC)")
    plt.legend(loc="lower right")
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()
    logger.info("Saved ROC curve (AUC = %.2f) to %s", roc_auc, out_path)


def save_confusion_matrix(
    y: NDArray, pred: NDArray, labels: Optional[NDArray] = None
) -> None:
    """Render the confusion matrix of ``pred`` against ``y`` and save it.

    The figure is written to ``IMGS / "confusion_matrix.png"``.

    Args:
        y: True labels.
        pred: Predicted labels, aligned with ``y``.
        labels: Optional explicit label order forwarded to
            ``sklearn.metrics.confusion_matrix``. Passing the classifier's
            ``classes_`` keeps the matrix 2x2 even when ``y`` and ``pred``
            happen to contain a single class, so it always matches the two
            hard-coded tick labels. ``None`` keeps scikit-learn's default
            (the sorted labels that appear in ``y`` or ``pred``).
    """
    matrix = confusion_matrix(y, pred, labels=labels)

    IMGS.mkdir(parents=True, exist_ok=True)
    out_path = IMGS / "confusion_matrix.png"

    plt.figure(figsize=(5, 4))
    sns.heatmap(
        matrix,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=["Not Relevant", "Relevant"],
        yticklabels=["Not Relevant", "Relevant"],
    )
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title("Confusion Matrix")
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()
    logger.info("Saved confusion matrix to %s", out_path)


def main() -> None:
    """Train the classifier, persist it, and write the evaluation plots."""
    train_df = pl.read_parquet(DATA / "train.parquet")

    # probability=True is required for predict_proba, used by the ROC plot.
    clf = SVC(kernel="poly", probability=True)
    # NOTE(review): assumes the "embeds" column converts to a 2-D
    # (n_samples, n_features) array via to_numpy() -- confirm the column dtype.
    clf.fit(
        train_df.get_column("embeds").to_numpy(),
        train_df.get_column("is_news").to_numpy(),
    )

    MODEL.mkdir(parents=True, exist_ok=True)
    model_path = MODEL / "model.pickle"
    with open(model_path, "wb") as f:
        pickle.dump(clf, f)
    logger.info("Saved trained model to %s", model_path)

    eval_df = pl.read_parquet(DATA / "eval.parquet")
    eval_X = eval_df.get_column("embeds").to_numpy()
    eval_y = eval_df.get_column("is_news").to_numpy()
    eval_pred = clf.predict(eval_X)

    # Pin the matrix layout to the classes seen during training.
    save_confusion_matrix(eval_y, eval_pred, labels=clf.classes_)
    save_roc_curve(clf, eval_X, eval_y)


if __name__ == "__main__":
    main()