ericsonwillians commited on
Commit
9947e88
·
verified ·
1 Parent(s): b1fe6f9

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +100 -0
inference.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ inference.py
4
+
5
+ This script loads the trained sentiment classification model and tokenizer,
6
+ prompts the user for a review text, and outputs the predicted sentiment
7
+ (Positive or Negative) with a styled Rich UI.
8
+
9
+ Usage:
10
+ python inference.py
11
+ # Then follow the prompt to enter your review.
12
+ """
13
+
14
+ import os
15
+ import torch
16
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
17
+ from rich.console import Console
18
+ from rich.panel import Panel
19
+ from rich.table import Table
20
+ from rich.prompt import Prompt
21
+ from rich.markdown import Markdown
22
+ from rich import box
23
+ from rich.status import Status
24
+
25
+ # Configure console
26
+ console = Console()
27
+
28
+ # Since model files are now in the current directory, we set MODEL_DIR to "."
29
+ MODEL_DIR = "."
30
+
31
+ def predict_sentiment(review_text: str, model, tokenizer):
32
+ """Predict sentiment (Positive/Negative) for a given review using the loaded model and tokenizer."""
33
+ inputs = tokenizer(review_text, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
34
+ with torch.no_grad():
35
+ outputs = model(**inputs)
36
+ logits = outputs.logits
37
+ probs = torch.softmax(logits, dim=1)
38
+ pred_class = torch.argmax(probs, dim=1).item()
39
+ # pred_class: 0 = Negative, 1 = Positive
40
+ return pred_class, probs[0].tolist()
41
+
42
+ def main():
43
+ console.rule("[bold magenta]Steam Review Sentiment Inference[/bold magenta]")
44
+
45
+ # Intro message
46
+ intro = Markdown(
47
+ """
48
+ **Welcome!**
49
+ This tool uses a fine-tuned DistilBERT model to predict whether a given Steam review is *Positive* or *Negative*.
50
+
51
+ - Enter a review below and press [bold green]Enter[/bold green].
52
+ - The model will run inference and display the sentiment prediction.
53
+ """
54
+ )
55
+ console.print(intro)
56
+
57
+ # Prompt user for input review
58
+ review = Prompt.ask("[bold cyan]Please enter the Steam review text[/bold cyan]", default="This game is amazing!")
59
+ if not review.strip():
60
+ console.print("[red]No input provided. Exiting.[/red]")
61
+ return
62
+
63
+ # Check if model directory exists
64
+ if not os.path.isdir(MODEL_DIR):
65
+ console.print(f"[red]Model directory not found at: {MODEL_DIR}[/red]")
66
+ return
67
+
68
+ console.print("\n[bold yellow]Loading model and tokenizer...[/bold yellow]")
69
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
70
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
71
+
72
+ console.print("\n[bold green]Running inference...[/bold green]")
73
+ with console.status("[bold blue]Thinking...[/bold blue]", spinner="dots"):
74
+ pred_class, probabilities = predict_sentiment(review, model, tokenizer)
75
+
76
+ sentiment_label = "Positive" if pred_class == 1 else "Negative"
77
+ pos_prob = probabilities[1]
78
+ neg_prob = probabilities[0]
79
+
80
+ # Create a table for probabilities, using `box.ROUNDED`
81
+ table = Table(title="Sentiment Probabilities", box=box.ROUNDED, expand=False)
82
+ table.add_column("Sentiment", style="bold cyan", justify="center")
83
+ table.add_column("Probability", style="bold magenta", justify="center")
84
+ table.add_row("Positive", f"{pos_prob:.4f}")
85
+ table.add_row("Negative", f"{neg_prob:.4f}")
86
+
87
+ # Create a panel for the final output
88
+ output_panel = Panel(
89
+ table,
90
+ title=f"Predicted Sentiment: [bold green]{sentiment_label}[/bold green]",
91
+ subtitle="Inference Complete",
92
+ border_style="bold magenta"
93
+ )
94
+
95
+ console.rule("[bold magenta]Inference Result[/bold magenta]")
96
+ console.print(output_panel)
97
+ console.rule()
98
+
99
+ if __name__ == "__main__":
100
+ main()