Upload inference.py with huggingface_hub
Browse files- inference.py +100 -0
inference.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
inference.py
|
4 |
+
|
5 |
+
This script loads the trained sentiment classification model and tokenizer,
|
6 |
+
prompts the user for a review text, and outputs the predicted sentiment
|
7 |
+
(Positive or Negative) with a styled Rich UI.
|
8 |
+
|
9 |
+
Usage:
|
10 |
+
python inference.py
|
11 |
+
# Then follow the prompt to enter your review.
|
12 |
+
"""
|
13 |
+
|
14 |
+
import os
|
15 |
+
import torch
|
16 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
17 |
+
from rich.console import Console
|
18 |
+
from rich.panel import Panel
|
19 |
+
from rich.table import Table
|
20 |
+
from rich.prompt import Prompt
|
21 |
+
from rich.markdown import Markdown
|
22 |
+
from rich import box
|
23 |
+
from rich.status import Status
|
24 |
+
|
25 |
+
# Configure console
|
26 |
+
console = Console()
|
27 |
+
|
28 |
+
# Since model files are now in the current directory, we set MODEL_DIR to "."
|
29 |
+
MODEL_DIR = "."
|
30 |
+
|
31 |
+
def predict_sentiment(review_text: str, model, tokenizer):
|
32 |
+
"""Predict sentiment (Positive/Negative) for a given review using the loaded model and tokenizer."""
|
33 |
+
inputs = tokenizer(review_text, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
|
34 |
+
with torch.no_grad():
|
35 |
+
outputs = model(**inputs)
|
36 |
+
logits = outputs.logits
|
37 |
+
probs = torch.softmax(logits, dim=1)
|
38 |
+
pred_class = torch.argmax(probs, dim=1).item()
|
39 |
+
# pred_class: 0 = Negative, 1 = Positive
|
40 |
+
return pred_class, probs[0].tolist()
|
41 |
+
|
42 |
+
def main():
|
43 |
+
console.rule("[bold magenta]Steam Review Sentiment Inference[/bold magenta]")
|
44 |
+
|
45 |
+
# Intro message
|
46 |
+
intro = Markdown(
|
47 |
+
"""
|
48 |
+
**Welcome!**
|
49 |
+
This tool uses a fine-tuned DistilBERT model to predict whether a given Steam review is *Positive* or *Negative*.
|
50 |
+
|
51 |
+
- Enter a review below and press [bold green]Enter[/bold green].
|
52 |
+
- The model will run inference and display the sentiment prediction.
|
53 |
+
"""
|
54 |
+
)
|
55 |
+
console.print(intro)
|
56 |
+
|
57 |
+
# Prompt user for input review
|
58 |
+
review = Prompt.ask("[bold cyan]Please enter the Steam review text[/bold cyan]", default="This game is amazing!")
|
59 |
+
if not review.strip():
|
60 |
+
console.print("[red]No input provided. Exiting.[/red]")
|
61 |
+
return
|
62 |
+
|
63 |
+
# Check if model directory exists
|
64 |
+
if not os.path.isdir(MODEL_DIR):
|
65 |
+
console.print(f"[red]Model directory not found at: {MODEL_DIR}[/red]")
|
66 |
+
return
|
67 |
+
|
68 |
+
console.print("\n[bold yellow]Loading model and tokenizer...[/bold yellow]")
|
69 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
|
70 |
+
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
|
71 |
+
|
72 |
+
console.print("\n[bold green]Running inference...[/bold green]")
|
73 |
+
with console.status("[bold blue]Thinking...[/bold blue]", spinner="dots"):
|
74 |
+
pred_class, probabilities = predict_sentiment(review, model, tokenizer)
|
75 |
+
|
76 |
+
sentiment_label = "Positive" if pred_class == 1 else "Negative"
|
77 |
+
pos_prob = probabilities[1]
|
78 |
+
neg_prob = probabilities[0]
|
79 |
+
|
80 |
+
# Create a table for probabilities, using `box.ROUNDED`
|
81 |
+
table = Table(title="Sentiment Probabilities", box=box.ROUNDED, expand=False)
|
82 |
+
table.add_column("Sentiment", style="bold cyan", justify="center")
|
83 |
+
table.add_column("Probability", style="bold magenta", justify="center")
|
84 |
+
table.add_row("Positive", f"{pos_prob:.4f}")
|
85 |
+
table.add_row("Negative", f"{neg_prob:.4f}")
|
86 |
+
|
87 |
+
# Create a panel for the final output
|
88 |
+
output_panel = Panel(
|
89 |
+
table,
|
90 |
+
title=f"Predicted Sentiment: [bold green]{sentiment_label}[/bold green]",
|
91 |
+
subtitle="Inference Complete",
|
92 |
+
border_style="bold magenta"
|
93 |
+
)
|
94 |
+
|
95 |
+
console.rule("[bold magenta]Inference Result[/bold magenta]")
|
96 |
+
console.print(output_panel)
|
97 |
+
console.rule()
|
98 |
+
|
99 |
+
if __name__ == "__main__":
|
100 |
+
main()
|