
Updated README.md with detailed model information for FinSight AI, a financial advisory chatbot: added sections on model details, usage examples, training details, limitations, and future improvements; changed the license from Apache-2.0 to MIT and updated the language and tags for better categorization.
da00d1b
#!/usr/bin/env python3
"""
FinSight AI - Inference script for the financial advisory chatbot.
This script provides a simple way to interact with the model via the command line.
"""

import argparse
from typing import Optional

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextStreamer,
    BitsAndBytesConfig,
)
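
# NOTE: loading in 4-bit mode relies on the bitsandbytes library being
# installed (e.g. `pip install bitsandbytes`); if it is unavailable,
# construct the advisor with use_4bit=False. On CPU the quantization
# config is skipped automatically.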


class FinancialAdvisor:
    def __init__(
        self,
        model_id: str = "zahemen9900/finsight-ai",
        use_4bit: bool = True,
        device: Optional[str] = None,
    ):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Configure quantization if requested and available
        if use_4bit and self.device == "cuda":
            print("Loading model in 4-bit quantization mode")
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
        else:
            print("Loading model in standard mode")
            bnb_config = None

        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=bnb_config,
            device_map="auto" if self.device == "cuda" else None,
            torch_dtype=torch.bfloat16 if self.device == "cuda" else torch.float32,
        )
        if self.device == "cpu":
            self.model = self.model.to(self.device)
        self.model.eval()

        self.conversation_history = []
        self.system_message = {
            "role": "system",
            "content": (
                "You are FinSight AI, a helpful and knowledgeable financial assistant. "
                "You can provide information and guidance on financial topics, market trends, investment strategies, "
                "and personal finance management. Always strive to be accurate, informative, and helpful. "
                "Remember that you cannot provide personalized financial advice that would require knowing a person's "
                "complete financial situation or future market movements."
            ),
        }

    def generate_response(
        self,
        prompt: str,
        temperature: float = 0.7,
        max_new_tokens: int = 512,
        stream: bool = True,
    ) -> str:
        """Generate a response from the model."""
        # Manage conversation history (keep the last 5 exchanges, i.e. 10 messages)
        if len(self.conversation_history) > 10:
            self.conversation_history = self.conversation_history[-10:]

        # Create messages with history
        messages = [self.system_message] + self.conversation_history
        messages.append({"role": "user", "content": prompt})

        # Format prompt using the model's chat template
        formatted_prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Encode the input
        inputs = self.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=4096,
        ).to(self.device)

        # Set up a streamer if requested; TextStreamer prints tokens to stdout
        # as they are generated, skipping the echoed prompt and special tokens
        streamer = TextStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        ) if stream else None

        # Generate the response
        with torch.inference_mode():
            output_ids = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=0.95,                # nucleus sampling
                streamer=streamer,
                pad_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.1,    # discourage verbatim repetition
            )

        # Decode only the newly generated tokens; generate() returns the full
        # sequence even when streaming, so the history records the real reply
        # in both modes
        response = self.tokenizer.decode(
            output_ids[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True,
        )
        if not stream:
            print("\nAssistant:", response)

        # Update conversation history
        self.conversation_history.append({"role": "user", "content": prompt})
        self.conversation_history.append({"role": "assistant", "content": response})
        return response

    def start_chat_loop(self):
        """Start an interactive chat session."""
        print("\nWelcome to FinSight AI - Your Financial Advisory Assistant!")
        print("Type 'quit', 'exit', or press Ctrl+C to end the conversation.")
        print("Type 'clear' to reset the conversation history.\n")

        while True:
            try:
                user_input = input("\nYou: ").strip()
                if not user_input:
                    continue
                if user_input.lower() in ["quit", "exit", "q"]:
                    break
                if user_input.lower() == "clear":
                    self.conversation_history = []
                    print("Conversation history cleared.")
                    continue
                print("\nAssistant: ", end="", flush=True)
                self.generate_response(user_input)
            except KeyboardInterrupt:
                print("\nExiting chat...")
                break
            except Exception as e:
                print(f"\nError: {e}")
                continue

        print("\nThank you for using FinSight AI. Goodbye!")


def main():
    parser = argparse.ArgumentParser(description="FinSight AI Inference Script")
    parser.add_argument(
        "--model_id",
        type=str,
        default="zahemen9900/finsight-ai",
        help="Model ID or path to load",
    )
    parser.add_argument(
        "--no_quantize",
        action="store_true",
        help="Disable 4-bit quantization (uses more memory)",
    )
    parser.add_argument(
        "--query",
        type=str,
        help="Single query mode: provide a question and get one response",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature (higher = more random)",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=512,
        help="Maximum number of new tokens to generate",
    )
    args = parser.parse_args()

    advisor = FinancialAdvisor(
        model_id=args.model_id,
        use_4bit=not args.no_quantize,
    )

    # Single query mode
    if args.query:
        advisor.generate_response(
            args.query,
            temperature=args.temperature,
            max_new_tokens=args.max_tokens,
        )
    # Interactive chat mode
    else:
        advisor.start_chat_loop()


if __name__ == "__main__":
    main()
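
For reference, a minimal usage sketch. The filename `inference.py` and the importable module name are assumptions for illustration; substitute whatever the script is saved as in your checkout.

# Command-line usage (filename assumed):
#   python inference.py --query "What is dollar-cost averaging?" --max_tokens 256
#   python inference.py --no_quantize    # full precision, uses more memory
#   python inference.py                  # interactive chat mode

# Programmatic usage, assuming the script is importable as `inference`:
from inference import FinancialAdvisor

advisor = FinancialAdvisor(use_4bit=True)
answer = advisor.generate_response(
    "How does compound interest work?",
    temperature=0.7,
    max_new_tokens=256,
    stream=False,  # return the text instead of streaming it to stdout
)
print(answer)

With stream=False the reply is returned (and printed once) rather than streamed token by token, which is the mode to use when embedding the advisor in another application.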