ericsonwillians's picture
Upload inference.py with huggingface_hub
9947e88 verified
raw
history blame
3.62 kB
#!/usr/bin/env python3
"""
inference.py
This script prompts the user for a review text, loads the fine-tuned
sentiment classification model and tokenizer, and outputs the predicted
sentiment (Positive or Negative) together with both class probabilities
in a styled Rich UI.
Usage:
python inference.py
# Then follow the prompt to enter your review.
"""
import os
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.prompt import Prompt
from rich.markdown import Markdown
from rich import box
from rich.status import Status
# Shared Rich console used for all terminal output.
console = Console()

# Model and tokenizer files are expected in the same directory as this script.
# Resolve that directory from this file's location instead of using ".", which
# is relative to the process's working directory and therefore breaks when the
# script is launched from elsewhere (e.g. `python some/dir/inference.py`).
MODEL_DIR = os.path.dirname(os.path.abspath(__file__))
def predict_sentiment(review_text: str, model, tokenizer, max_length: int = 128):
    """Predict sentiment (Positive/Negative) for a single review.

    Args:
        review_text: Raw review text to classify.
        model: A loaded sequence-classification model whose output exposes
            ``.logits`` and which has a ``.device`` attribute (as models from
            ``AutoModelForSequenceClassification`` do).
        tokenizer: The tokenizer matching ``model``.
        max_length: Maximum number of tokens fed to the model; longer reviews
            are truncated. Defaults to 128.

    Returns:
        A ``(pred_class, probabilities)`` tuple. ``pred_class`` is the index of
        the most probable class (0 = Negative, 1 = Positive) and
        ``probabilities`` is a list of per-class softmax probabilities indexed
        by class.
    """
    # Only one sequence is encoded, so no padding is requested: padding it out
    # to max_length would only add attention-masked tokens that cost compute.
    inputs = tokenizer(
        review_text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    # Put the encoded tensors on the model's device (e.g. if it was moved to GPU).
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)
    pred_class = int(torch.argmax(probs, dim=-1)[0])
    # Class index convention used by this script: 0 = Negative, 1 = Positive.
    return pred_class, probs[0].tolist()
def main():
    """Interactively classify one Steam review and display the result.

    Prints an intro, prompts for a review, loads the tokenizer and model from
    ``MODEL_DIR``, runs ``predict_sentiment`` and prints a Rich panel with the
    predicted label and both class probabilities. Blank input, a missing model
    directory, or model/tokenizer files that cannot be loaded are reported in
    red and end the function early instead of raising.
    """
    console.rule("[bold magenta]Steam Review Sentiment Inference[/bold magenta]")

    # Intro message. Rich console markup such as "[bold green]...[/bold green]"
    # is not interpreted inside a Markdown renderable (it would be shown
    # literally), so emphasis uses Markdown syntax. The text is assembled from
    # unindented pieces because indented lines would render as a code block.
    intro = Markdown(
        "**Welcome!**\n\n"
        "This tool uses a fine-tuned DistilBERT model to predict whether a given "
        "Steam review is *Positive* or *Negative*.\n\n"
        "- Enter a review below and press **Enter**.\n"
        "- The model will run inference and display the sentiment prediction.\n"
    )
    console.print(intro)

    # Prompt user for input review; whitespace-only input is rejected below.
    review = Prompt.ask(
        "[bold cyan]Please enter the Steam review text[/bold cyan]",
        default="This game is amazing!",
    )
    if not review.strip():
        console.print("[red]No input provided. Exiting.[/red]")
        return

    # Check if model directory exists
    if not os.path.isdir(MODEL_DIR):
        console.print(f"[red]Model directory not found at: {MODEL_DIR}[/red]")
        return

    console.print("\n[bold yellow]Loading model and tokenizer...[/bold yellow]")
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    except (OSError, ValueError) as err:
        # The directory can exist but lack usable config/tokenizer/weight
        # files; report that instead of crashing with a traceback.
        # markup=False because the error text may contain "[...]" sequences
        # that Rich would otherwise try to parse as markup.
        console.print(
            f"Failed to load model and tokenizer from {MODEL_DIR}: {err}",
            style="red",
            markup=False,
        )
        return

    console.print("\n[bold green]Running inference...[/bold green]")
    with console.status("[bold blue]Thinking...[/bold blue]", spinner="dots"):
        pred_class, probabilities = predict_sentiment(review, model, tokenizer)

    # Class index convention: 0 = Negative, 1 = Positive.
    is_positive = pred_class == 1
    sentiment_label = "Positive" if is_positive else "Negative"
    # Color the verdict by outcome so a Negative result is not shown in green.
    label_style = "bold green" if is_positive else "bold red"
    neg_prob, pos_prob = probabilities[0], probabilities[1]

    # Create a table for probabilities, using `box.ROUNDED`
    table = Table(title="Sentiment Probabilities", box=box.ROUNDED, expand=False)
    table.add_column("Sentiment", style="bold cyan", justify="center")
    table.add_column("Probability", style="bold magenta", justify="center")
    table.add_row("Positive", f"{pos_prob:.4f}")
    table.add_row("Negative", f"{neg_prob:.4f}")

    # Create a panel for the final output
    output_panel = Panel(
        table,
        title=f"Predicted Sentiment: [{label_style}]{sentiment_label}[/{label_style}]",
        subtitle="Inference Complete",
        border_style="bold magenta",
    )

    console.rule("[bold magenta]Inference Result[/bold magenta]")
    console.print(output_panel)
    console.rule()
# Launch the interactive CLI only when run as a script, not when imported.
if __name__ == "__main__":
    main()