#!/usr/bin/env python3
"""
FinSight AI - Inference script for financial advisory chatbot
This script provides a simple way to interact with the model via command line
"""
import argparse
from typing import Optional

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextStreamer,
    BitsAndBytesConfig,
)

class FinancialAdvisor:
    def __init__(
        self,
        model_id: str = "zahemen9900/finsight-ai",
        use_4bit: bool = True,
        device: Optional[str] = None
    ):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        # Configure quantization if requested and available
        if use_4bit and self.device == "cuda":
            print("Loading model in 4-bit quantization mode")
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16
            )
        else:
            print("Loading model in standard mode")
            bnb_config = None

        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=bnb_config,
            device_map="auto" if self.device == "cuda" else None,
            torch_dtype=torch.bfloat16 if self.device == "cuda" else torch.float32,
        )
        if self.device == "cpu":
            self.model = self.model.to(self.device)
        self.model.eval()

        self.conversation_history = []
        self.system_message = {
            "role": "system",
            "content": (
                "You are FinSight AI, a helpful and knowledgeable financial assistant. "
                "You can provide information and guidance on financial topics, market trends, investment strategies, "
                "and personal finance management. Always strive to be accurate, informative, and helpful. "
                "Remember that you cannot provide personalized financial advice that would require knowing a person's "
                "complete financial situation or future market movements."
            )
        }

    def generate_response(
        self,
        prompt: str,
        temperature: float = 0.7,
        max_new_tokens: int = 512,
        stream: bool = True
    ) -> str:
        """Generate response from the model"""
        # Manage conversation history (keep last 5 exchanges)
        if len(self.conversation_history) > 10:
            self.conversation_history = self.conversation_history[-10:]

        # Create messages with history
        messages = [self.system_message] + self.conversation_history
        messages.append({"role": "user", "content": prompt})

        # Format prompt using chat template
        formatted_prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
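
        # Note: the chat template above renders each message with the model's
        # own role markers; add_generation_prompt=True appends the tokens that
        # cue the start of the assistant's turn.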

        # Encode the input
        inputs = self.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=4096
        ).to(self.device)

        # Setup streamer if requested
        streamer = TextStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True
        ) if stream else None

        # Generate response
        with torch.inference_mode():
            output_ids = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=0.95,
                streamer=streamer,
                pad_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.1
            )
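
        # Note on the sampling settings above: do_sample with temperature=0.7
        # and top_p=0.95 trades determinism for variety, while
        # repetition_penalty=1.1 discourages the model from looping on the
        # same phrases.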

        # Decode only the newly generated tokens. Decode in both modes so the
        # actual reply (not a placeholder) is kept in the conversation history
        # and remains available as context for later turns.
        response = self.tokenizer.decode(
            output_ids[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True
        )
        if not stream:
            # When streaming, the streamer has already printed the reply
            print("\nAssistant:", response)

        # Update conversation history
        self.conversation_history.append({"role": "user", "content": prompt})
        self.conversation_history.append({"role": "assistant", "content": response})
        return response

    def start_chat_loop(self):
        """Start an interactive chat session."""
        print("\nWelcome to FinSight AI - Your Financial Advisory Assistant!")
        print("Type 'quit', 'exit', or 'q' (or press Ctrl+C) to end the conversation.")
        print("Type 'clear' to reset the conversation history.\n")

        while True:
            try:
                user_input = input("\nYou: ").strip()
                if not user_input:
                    continue
                if user_input.lower() in ["quit", "exit", "q"]:
                    break
                if user_input.lower() == "clear":
                    self.conversation_history = []
                    print("Conversation history cleared.")
                    continue
                print("\nAssistant: ", end="", flush=True)
                self.generate_response(user_input)
            except KeyboardInterrupt:
                print("\nExiting chat...")
                break
            except Exception as e:
                print(f"\nError: {e}")
                continue

        print("\nThank you for using FinSight AI. Goodbye!")

def main():
    parser = argparse.ArgumentParser(description="FinSight AI Inference Script")
    parser.add_argument(
        "--model_id",
        type=str,
        default="zahemen9900/finsight-ai",
        help="Model ID or path to load"
    )
    parser.add_argument(
        "--no_quantize",
        action="store_true",
        help="Disable 4-bit quantization (uses more memory)"
    )
    parser.add_argument(
        "--query",
        type=str,
        help="Single-query mode: provide a question and get one response"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature (higher = more random)"
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=512,
        help="Maximum number of new tokens to generate"
    )
    args = parser.parse_args()

    advisor = FinancialAdvisor(
        model_id=args.model_id,
        use_4bit=not args.no_quantize
    )

    # Single-query mode
    if args.query:
        advisor.generate_response(
            args.query,
            temperature=args.temperature,
            max_new_tokens=args.max_tokens
        )
    # Interactive chat mode
    else:
        advisor.start_chat_loop()


if __name__ == "__main__":
    main()