#!/usr/bin/env python3
"""
FinSight AI - Inference script for financial advisory chatbot
This script provides a simple way to interact with the model via command line
"""

import argparse
import torch
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    TextStreamer,
    BitsAndBytesConfig
)

class FinancialAdvisor:
    def __init__(
        self,
        model_id: str = "zahemen9900/finsight-ai",
        use_4bit: bool = True,
        device: str | None = None
    ):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        
        # Configure quantization if requested and available
        if use_4bit and self.device == "cuda":
            print("Loading model in 4-bit quantization mode")
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16
            )
        else:
            print("Loading model in standard mode")
            bnb_config = None
        
        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=bnb_config,
            device_map="auto" if self.device == "cuda" else None,
            torch_dtype=torch.bfloat16 if self.device == "cuda" else torch.float32,
        )
        
        if self.device == "cpu":
            self.model = self.model.to(self.device)
            
        self.model.eval()
        self.conversation_history = []
        self.system_message = {
            "role": "system",
            "content": (
                "You are FinSight AI, a helpful and knowledgeable financial assistant. "
                "You can provide information and guidance on financial topics, market trends, investment strategies, "
                "and personal finance management. Always strive to be accurate, informative, and helpful. "
                "Remember that you cannot provide personalized financial advice that would require knowing a person's "
                "complete financial situation or future market movements."
            )
        }
    
    def generate_response(
        self,
        prompt: str,
        temperature: float = 0.7,
        max_new_tokens: int = 512,
        stream: bool = True
    ) -> str:
        """Generate response from the model"""
        # Trim conversation history to the last 5 exchanges (10 messages)
        if len(self.conversation_history) > 10:
            self.conversation_history = self.conversation_history[-10:]
            
        # Create messages with history
        messages = [self.system_message] + self.conversation_history
        messages.append({"role": "user", "content": prompt})
        
        # Format prompt using chat template
        formatted_prompt = self.tokenizer.apply_chat_template(
            messages, 
            tokenize=False,
            add_generation_prompt=True
        )
        
        # Encode the input
        inputs = self.tokenizer(
            formatted_prompt, 
            return_tensors="pt",
            truncation=True,
            max_length=4096
        ).to(self.device)
        
        # Setup streamer if requested
        streamer = TextStreamer(
            self.tokenizer, 
            skip_prompt=True,
            skip_special_tokens=True
        ) if stream else None
        
        # Generate response
        with torch.inference_mode():
            output_ids = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=0.95,
                streamer=streamer,
                pad_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.1
            )
        
        # Decode the generated tokens in all cases so the conversation history
        # stores the real reply (when streaming, the streamer has already
        # printed it token by token)
        response = self.tokenizer.decode(
            output_ids[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True
        )
        if not stream:
            print("\nAssistant:", response)

        # Update conversation history
        self.conversation_history.append({"role": "user", "content": prompt})
        self.conversation_history.append({"role": "assistant", "content": response})
        
        return response
        
    def start_chat_loop(self):
        """Start an interactive chat session"""
        print("\nWelcome to FinSight AI - Your Financial Advisory Assistant!")
        print("Type 'quit', 'exit', or press Ctrl+C to end the conversation.\n")
        
        while True:
            try:
                user_input = input("\nYou: ").strip()
                if user_input.lower() in ["quit", "exit", "q"]:
                    break
                    
                if user_input.lower() == "clear":
                    self.conversation_history = []
                    print("Conversation history cleared.")
                    continue
                
                print("\nAssistant: ", end="", flush=True)
                self.generate_response(user_input)
                
            except KeyboardInterrupt:
                print("\nExiting chat...")
                break
            except Exception as e:
                print(f"\nError: {e}")
                continue
                
        print("\nThank you for using FinSight AI. Goodbye!")

def main():
    parser = argparse.ArgumentParser(description="FinSight AI Inference Script")
    parser.add_argument(
        "--model_id", 
        type=str, 
        default="zahemen9900/finsight-ai",
        help="Model ID or path to load"
    )
    parser.add_argument(
        "--no_quantize", 
        action="store_true",
        help="Disable 4-bit quantization (uses more memory)"
    )
    parser.add_argument(
        "--query", 
        type=str,
        help="Single query mode: provide a question and get one response"
    )
    parser.add_argument(
        "--temperature", 
        type=float, 
        default=0.7,
        help="Sampling temperature (higher = more random)"
    )
    parser.add_argument(
        "--max_tokens", 
        type=int, 
        default=512,
        help="Maximum number of new tokens to generate"
    )
    
    args = parser.parse_args()
    
    advisor = FinancialAdvisor(
        model_id=args.model_id, 
        use_4bit=not args.no_quantize
    )
    
    # Single query mode
    if args.query:
        advisor.generate_response(
            args.query, 
            temperature=args.temperature, 
            max_new_tokens=args.max_tokens
        )
    # Interactive chat mode
    else:
        advisor.start_chat_loop()

if __name__ == "__main__":
    main()