File size: 9,516 Bytes
418c329
 
 
 
 
 
 
 
 
 
 
 
 
 
8f998ed
418c329
 
8f998ed
418c329
8f998ed
 
 
 
 
 
 
 
 
 
 
 
418c329
 
8f998ed
 
 
 
 
 
 
 
418c329
 
 
 
 
 
 
 
 
 
8f998ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418c329
 
 
 
8f998ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418c329
 
 
 
 
 
 
 
 
 
 
8f998ed
 
 
 
 
418c329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f998ed
418c329
 
8f998ed
418c329
 
8f998ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418c329
8f998ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418c329
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
import datetime
import os
import secrets
import sqlite3
from contextlib import closing

import torch
from dotenv import load_dotenv
from flask import Flask, request, jsonify, render_template, session
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModelForCausalLM

# ✅ Load environment variables
# Reads a .env file (if present) into os.environ, so SECRET_KEY below can be
# supplied there. Must run before the os.getenv() call.
load_dotenv()

# ✅ Configure Flask app
app = Flask(__name__, static_folder='static', template_folder='templates')
# The secret key signs the session cookie that holds 'chat_id'. When SECRET_KEY
# is unset, os.urandom(24) makes a new key on every import, so browser sessions
# from a previous run stop validating after a restart.
app.secret_key = os.getenv('SECRET_KEY', os.urandom(24))  # Use .env for secret key
# NOTE(review): CORS(app) is called with no options — confirm which origins
# should actually be allowed to call this API.
CORS(app)

# ✅ Automatically find the latest trained model folder
def get_latest_model_dir(base_dir="chatbot_model"):
    """Return the absolute path of the newest ``trained_model_*`` folder in *base_dir*.

    "Newest" means the greatest folder name in lexicographic order. This only
    matches creation order when the suffixes sort chronologically (e.g.
    zero-padded timestamps) — NOTE(review): confirm the training script's
    naming scheme.

    Args:
        base_dir: directory scanned for ``trained_model_*`` sub-directories.

    Returns:
        Absolute path of the selected model directory.

    Raises:
        FileNotFoundError: if *base_dir* does not exist or holds no
            ``trained_model_*`` sub-directory.
        OSError: if *base_dir* cannot be listed for another reason.
    """
    try:
        candidates = [
            name
            for name in os.listdir(base_dir)
            # Plain files that happen to match the prefix are ignored.
            if name.startswith("trained_model_")
            and os.path.isdir(os.path.join(base_dir, name))
        ]
        if not candidates:
            raise FileNotFoundError(
                f"❌ No trained model found in {base_dir}/ directory."
            )
        # max() picks the same entry as sorted(...)[-1] without a full sort.
        return os.path.abspath(os.path.join(base_dir, max(candidates)))
    except Exception as e:
        print(f"❌ Error finding model directory: {e}")
        raise

# ✅ Load model from latest folder
# Runs at import time. The app is useless without a model, so any failure is
# logged and re-raised, which aborts startup.
try:
    model_path = get_latest_model_dir()
    print(f"📦 Loading model from: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)
except Exception as e:
    print(f"❌ Failed to load model: {e}")
    raise

# ✅ Setup device
# Use the GPU when one is available. /chat moves its input tensors to this
# same `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# Evaluation mode for inference (affects layers such as dropout).
model.eval()

# ✅ SQLite DB setup
# Relative path: sqlite3 resolves it against the process's current working
# directory, not against this file's location.
DB_FILE = "chat_history.db"

def insert_chat(session_id, user_msg, bot_msg, db_file=None):
    """Append one user/bot exchange to the ``chat`` table.

    This is best effort: any error is logged and swallowed, so a storage
    failure never turns an otherwise successful reply into an error.

    Args:
        session_id: id of the chat session the exchange belongs to.
        user_msg: the user's message text.
        bot_msg: the generated reply text.
        db_file: SQLite database path; ``None`` means the module-level ``DB_FILE``.
    """
    try:
        path = DB_FILE if db_file is None else db_file
        # closing() really closes the connection. sqlite3's own context manager
        # (the inner `with conn`) only commits or rolls back and leaves it open.
        with closing(sqlite3.connect(path)) as conn:
            with conn:
                conn.execute(
                    "INSERT INTO chat (session_id, timestamp, user, bot) VALUES (?, ?, ?, ?)",
                    (session_id, datetime.datetime.now().isoformat(), user_msg, bot_msg),
                )
        print(f"✅ Inserted chat for session_id={session_id}")
    except Exception as e:
        print(f"❌ Error inserting chat: {e}")

def fetch_history(session_id, limit=100, db_file=None):
    """Return up to *limit* ``(timestamp, user, bot)`` rows for *session_id*, oldest first.

    Rows come back in insertion order (``ORDER BY id ASC``), so once a session
    exceeds *limit* messages this returns its first *limit* messages.

    Args:
        session_id: chat session to read.
        limit: maximum number of rows to return.
        db_file: SQLite database path; ``None`` means the module-level ``DB_FILE``.

    Returns:
        A list of ``(timestamp, user, bot)`` tuples; ``[]`` on any database
        error (the error is logged).
    """
    try:
        path = DB_FILE if db_file is None else db_file
        # closing() guarantees the connection is released; a read needs no commit.
        with closing(sqlite3.connect(path)) as conn:
            rows = conn.execute(
                "SELECT timestamp, user, bot FROM chat WHERE session_id=? ORDER BY id ASC LIMIT ?",
                (session_id, limit),
            ).fetchall()
        print(f"📜 Fetched {len(rows)} messages for session_id={session_id}")
        return rows
    except Exception as e:
        print(f"❌ Error fetching history for session_id={session_id}: {e}")
        return []

def get_all_sessions(db_file=None):
    """Return every ``(session_id, created_at)`` row, newest ``created_at`` first.

    Args:
        db_file: SQLite database path; ``None`` means the module-level ``DB_FILE``.

    Returns:
        A list of ``(session_id, created_at)`` tuples; ``[]`` on any database
        error (the error is logged).
    """
    try:
        path = DB_FILE if db_file is None else db_file
        # closing() guarantees the connection is released; a read needs no commit.
        with closing(sqlite3.connect(path)) as conn:
            sessions = conn.execute(
                "SELECT session_id, created_at FROM sessions ORDER BY created_at DESC"
            ).fetchall()
        print(f"📋 Fetched {len(sessions)} sessions")
        return sessions
    except Exception as e:
        print(f"❌ Error fetching sessions: {e}")
        return []

def create_new_session(db_file=None):
    """Insert a new row into ``sessions`` and return its id.

    Ids look like ``session_YYYYmmdd_HHMMSS_<8 hex chars>``. The random suffix
    keeps ids unique when several sessions are created within the same second
    (two first-time visitors, or repeated "new session" clicks). Without it, the
    second INSERT would violate the ``session_id`` PRIMARY KEY and the request
    would fail.

    Args:
        db_file: SQLite database path; ``None`` means the module-level ``DB_FILE``.

    Returns:
        The new session id.

    Raises:
        Exception: any database error is logged and re-raised.
    """
    try:
        path = DB_FILE if db_file is None else db_file
        # One timestamp for both the id and created_at keeps them consistent.
        now = datetime.datetime.now()
        session_id = f"{now.strftime('session_%Y%m%d_%H%M%S')}_{secrets.token_hex(4)}"
        # closing() closes the connection; `with conn` commits the INSERT.
        with closing(sqlite3.connect(path)) as conn:
            with conn:
                conn.execute(
                    "INSERT INTO sessions (session_id, created_at) VALUES (?, ?)",
                    (session_id, now.isoformat()),
                )
        print(f"✅ Created new session: {session_id}")
        return session_id
    except Exception as e:
        print(f"❌ Error creating new session: {e}")
        raise

def delete_session(session_id, db_file=None):
    """Delete a session and all of its chat rows, and forget it if it is active.

    Both DELETE statements run in one transaction, committed together.
    Database errors are logged and re-raised, so the caller can report the
    failure instead of answering "deleted". This reads the Flask ``session``,
    so it must be called inside a request context.

    Args:
        session_id: id of the session to remove.
        db_file: SQLite database path; ``None`` means the module-level ``DB_FILE``.

    Raises:
        Exception: any database error (after logging).
    """
    try:
        path = DB_FILE if db_file is None else db_file
        # closing() closes the connection; `with conn` commits both deletes.
        with closing(sqlite3.connect(path)) as conn:
            with conn:
                conn.execute("DELETE FROM chat WHERE session_id=?", (session_id,))
                conn.execute("DELETE FROM sessions WHERE session_id=?", (session_id,))
        print(f"🗑️ Deleted session_id={session_id}")
    except Exception as e:
        print(f"❌ Error deleting session_id={session_id}: {e}")
        raise
    # Reset session if deleted session is active
    if session.get('chat_id') == session_id:
        session.pop('chat_id', None)
        print(f"🔄 Reset active session_id={session_id}")

# ✅ Route for home page
@app.route('/')
def home():
    """Serve the chat UI, first giving the visitor a chat session if they lack one.

    Returns the rendered ``index.html``, or a JSON error with status 500 if the
    session could not be created.
    """
    try:
        if 'chat_id' in session:
            return render_template('index.html')
        fresh_id = create_new_session()
        session['chat_id'] = fresh_id
        print(f"🏠 Initialized session_id={fresh_id} for new user")
        return render_template('index.html')
    except Exception as err:
        print(f"❌ Error in /: {err}")
        return jsonify({"error": "Failed to initialize session"}), 500

# ✅ Create new session and return it
# NOTE(review): this GET creates server-side state; link prefetching could
# trigger it. Consider also accepting POST.
@app.route('/new_session', methods=['GET'])
def new_session():
    """Create a session, make it the caller's active one and return its id as JSON."""
    try:
        fresh_id = create_new_session()
        session['chat_id'] = fresh_id
        print(f"➕ New session created: {fresh_id}")
        payload = {"session_id": fresh_id}
        return jsonify(payload)
    except Exception as err:
        print(f"❌ Error in /new_session: {err}")
        return jsonify({"error": str(err)}), 500

# ✅ POST endpoint for chatbot
@app.route('/chat', methods=['POST'])
def chat():
    """Generate a reply to the posted message and record the exchange.

    Request body: a JSON object with a string ``"message"`` field.

    Returns:
        ``{"response": reply}`` on success.
        ``({"error": ...}, 400)`` when the body is not a JSON object, or when
        ``message`` is missing, empty or not a string.
        ``({"error": str(e)}, 500)`` on any other failure.
    """
    try:
        # silent=True yields None for a missing or malformed JSON body instead of
        # raising, so such requests get the 400 below rather than a 500.
        data = request.get_json(silent=True)
        message = data.get("message", "") if isinstance(data, dict) else ""
        user_input = message.strip() if isinstance(message, str) else ""

        if not user_input:
            return jsonify({"error": "Missing or empty 'message' in request"}), 400

        session_id = session.get("chat_id")
        if not session_id:
            session_id = create_new_session()
            session['chat_id'] = session_id
            print(f"🔄 Created new session_id={session_id} for chat")

        prompt = f"<|prompter|> {user_input} <|endoftext|><|assistant|>"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        prompt_length = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            output = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=800,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.7,
                repetition_penalty=1.15,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        # For a causal LM, generate() returns prompt + continuation, so decode
        # only the continuation. Decoding everything and splitting on
        # "<|assistant|>" echoes the user's prompt back whenever the role markers
        # are special tokens: skip_special_tokens strips the very marker the
        # split relies on.
        generated = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
        # Drop any follow-up turn the model starts on its own, plus stray markers.
        reply = generated.split("<|prompter|>")[0].replace("<|assistant|>", "").strip()

        if not reply:
            reply = "⚠️ Sorry, I couldn't generate a response. Please try again."

        insert_chat(session_id, user_input, reply)

        return jsonify({"response": reply})

    except Exception as e:
        print(f"❌ Error in /chat: {e}")
        return jsonify({"error": str(e)}), 500

# ✅ GET current session chat
@app.route('/history', methods=['GET'])
def history():
    """Return the active session's exchanges as JSON, or ``[]`` when there is none."""
    try:
        active_id = session.get("chat_id")
        if active_id:
            return jsonify([
                {"timestamp": stamp, "user": question, "bot": answer}
                for stamp, question, answer in fetch_history(active_id)
            ])
        print("⚠️ No session_id found in session")
        return jsonify([])
    except Exception as err:
        print(f"❌ Error in /history: {err}")
        return jsonify({"error": str(err)}), 500

# ✅ GET chat for specific session
# NOTE(review): any client that knows (or lists) a session id can read it here;
# there is no ownership check.
@app.route('/history/<session_id>', methods=['GET'])
def session_history(session_id):
    """Return the stored exchanges of the given session as a JSON list."""
    try:
        messages = []
        for stamp, question, answer in fetch_history(session_id):
            messages.append({"timestamp": stamp, "user": question, "bot": answer})
        return jsonify(messages)
    except Exception as err:
        print(f"❌ Error in /history/{session_id}: {err}")
        return jsonify({"error": str(err)}), 500

# ✅ GET all sessions
@app.route('/sessions', methods=['GET'])
def list_sessions():
    """Return every stored session, newest first, as JSON objects."""
    try:
        rows = get_all_sessions()
        payload = [dict(session_id=sid, created_at=created) for sid, created in rows]
        return jsonify(payload)
    except Exception as err:
        print(f"❌ Error in /sessions: {err}")
        return jsonify({"error": str(err)}), 500

# ✅ DELETE session
@app.route('/sessions/<session_id>', methods=['DELETE'])
def delete_session_route(session_id):
    """Delete the given session and its messages, answering with a status object."""
    try:
        delete_session(session_id)
        outcome = {"status": "deleted"}
        return jsonify(outcome)
    except Exception as err:
        print(f"❌ Error in /sessions/{session_id}: {err}")
        return jsonify({"error": str(err)}), 500

def init_db(db_file=None):
    """Create the ``chat`` and ``sessions`` tables, plus a ``chat.session_id`` index.

    Idempotent: every statement uses ``IF NOT EXISTS``. Only the ``__main__``
    entry point below calls this. If the app is started another way (e.g. by a
    WSGI server importing this module), call ``init_db()`` before serving
    requests; otherwise ``create_new_session`` fails on the missing
    ``sessions`` table.

    Args:
        db_file: SQLite database path; ``None`` means the module-level ``DB_FILE``.

    Raises:
        Exception: any database error is logged and re-raised.
    """
    try:
        path = DB_FILE if db_file is None else db_file
        # closing() closes the connection; `with conn` commits on success.
        with closing(sqlite3.connect(path)) as conn:
            with conn:
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS chat (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        session_id TEXT NOT NULL,
                        timestamp TEXT NOT NULL,
                        user TEXT,
                        bot TEXT
                    )
                ''')
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS sessions (
                        session_id TEXT PRIMARY KEY,
                        created_at TEXT NOT NULL
                    )
                ''')
                # fetch_history and delete_session filter chat rows by session_id.
                conn.execute(
                    "CREATE INDEX IF NOT EXISTS idx_chat_session_id ON chat (session_id)"
                )
        print("✅ Database initialized successfully")
    except Exception as e:
        print(f"❌ Error initializing database: {e}")
        raise


# ✅ Initialize database and run server
if __name__ == '__main__':
    init_db()
    app.run(debug=True, port=5005)