|
|
|
""" |
|
inference.py |
|
|
|
This script loads the trained sentiment classification model and tokenizer, |
|
prompts the user for a review text, and outputs the predicted sentiment |
|
(Positive or Negative) with a styled Rich UI. |
|
|
|
Usage: |
|
python inference.py |
|
# Then follow the prompt to enter your review. |
|
""" |
|
|
|
import os |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
from rich.console import Console |
|
from rich.panel import Panel |
|
from rich.table import Table |
|
from rich.prompt import Prompt |
|
from rich.markdown import Markdown |
|
from rich import box |
|
from rich.status import Status |
|
|
|
|
|
# Shared Rich console through which this script writes all terminal output.
console = Console()

# Directory passed to `from_pretrained` for both the tokenizer and the model;
# "." resolves to the current working directory at run time.
MODEL_DIR = "."
|
|
|
def predict_sentiment(review_text: str, model, tokenizer):
    """Classify one review and report the per-class probabilities.

    The text is tokenized to a fixed length of 128 tokens (longer input is
    truncated, shorter input is padded), passed to ``model`` with gradient
    tracking disabled, and the resulting logits are turned into probabilities
    with a softmax over dimension 1.

    Args:
        review_text: Raw review text to classify.
        model: Callable that accepts the tokenizer output as keyword
            arguments and returns an object exposing a ``logits`` tensor.
        tokenizer: Callable that converts the text into the model inputs.

    Returns:
        A ``(pred_class, probabilities)`` tuple: the index of the highest
        probability class and the list of probabilities for the first (only)
        row of the batch.
    """
    encoded = tokenizer(
        review_text,
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        class_probs = torch.softmax(model(**encoded).logits, dim=1)

    # .item() requires a single prediction, i.e. exactly one review per call.
    predicted_index = torch.argmax(class_probs, dim=1).item()
    return predicted_index, class_probs[0].tolist()
|
|
|
def main():
    """Run one interactive sentiment prediction in the terminal.

    Prints an introduction, prompts for a review (a sample review is offered
    as the default), loads the tokenizer and model from ``MODEL_DIR``,
    classifies the review with ``predict_sentiment`` and renders the class
    probabilities in a Rich table inside a panel.

    Expected failures (blank input, missing model directory, model or
    tokenizer files that cannot be loaded) are reported on the console and
    the function returns early instead of raising.
    """
    console.rule("[bold magenta]Steam Review Sentiment Inference[/bold magenta]")

    # The intro is assembled from explicit lines so that no source-code
    # indentation ends up inside the Markdown (indented text would be rendered
    # as a code block). Emphasis uses Markdown syntax: Rich console markup such
    # as "[bold green]" is not interpreted inside a Markdown renderable and
    # would be shown to the user literally.
    intro = Markdown(
        "\n".join(
            [
                "**Welcome!**",
                "This tool uses a fine-tuned DistilBERT model to predict whether "
                "a given Steam review is *Positive* or *Negative*.",
                "",
                "- Enter a review below and press **Enter**.",
                "- The model will run inference and display the sentiment prediction.",
            ]
        )
    )
    console.print(intro)

    review = Prompt.ask(
        "[bold cyan]Please enter the Steam review text[/bold cyan]",
        default="This game is amazing!",
    )
    # An empty answer falls back to the default above; this guard catches
    # whitespace-only input.
    if not review.strip():
        console.print("[red]No input provided. Exiting.[/red]")
        return

    if not os.path.isdir(MODEL_DIR):
        console.print(f"[red]Model directory not found at: {MODEL_DIR}[/red]")
        return

    console.print("\n[bold yellow]Loading model and tokenizer...[/bold yellow]")
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    except (OSError, ValueError) as err:
        # The directory exists but does not hold loadable model/tokenizer
        # files. markup=False keeps brackets in the error text from being
        # parsed as Rich markup.
        console.print(
            f"Could not load model/tokenizer from '{MODEL_DIR}': {err}",
            style="red",
            markup=False,
        )
        return

    console.print("\n[bold green]Running inference...[/bold green]")
    with console.status("[bold blue]Thinking...[/bold blue]", spinner="dots"):
        pred_class, probabilities = predict_sentiment(review, model, tokenizer)

    # Class index 1 is treated as Positive, index 0 as Negative.
    sentiment_label = "Positive" if pred_class == 1 else "Negative"
    pos_prob = probabilities[1]
    neg_prob = probabilities[0]

    table = Table(title="Sentiment Probabilities", box=box.ROUNDED, expand=False)
    table.add_column("Sentiment", style="bold cyan", justify="center")
    table.add_column("Probability", style="bold magenta", justify="center")
    table.add_row("Positive", f"{pos_prob:.4f}")
    table.add_row("Negative", f"{neg_prob:.4f}")

    output_panel = Panel(
        table,
        title=f"Predicted Sentiment: [bold green]{sentiment_label}[/bold green]",
        subtitle="Inference Complete",
        border_style="bold magenta",
    )

    console.rule("[bold magenta]Inference Result[/bold magenta]")
    console.print(output_panel)
    console.rule()
|
|
|
# Start the interactive flow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|