File size: 9,516 Bytes
418c329 8f998ed 418c329 8f998ed 418c329 8f998ed 418c329 8f998ed 418c329 8f998ed 418c329 8f998ed 418c329 8f998ed 418c329 8f998ed 418c329 8f998ed 418c329 8f998ed 418c329 8f998ed 418c329 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 |
import datetime
import os
import sqlite3
from contextlib import closing

import torch
from dotenv import load_dotenv
from flask import Flask, request, jsonify, render_template, session
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModelForCausalLM
# ✅ Load environment variables from a .env file (if present) into os.environ
load_dotenv()

# ✅ Configure Flask app: assets come from `static/`, Jinja templates from `templates/`
app = Flask(
    __name__,
    template_folder='templates',
    static_folder='static',
)
# Prefer SECRET_KEY from the environment / .env; otherwise fall back to 24 random
# bytes generated at startup. NOTE(review): with the fallback, session cookies
# signed by an earlier process (or another worker) will not validate.
app.secret_key = os.getenv('SECRET_KEY', default=os.urandom(24))
# Enable CORS handling for the app's routes.
CORS(app)
# ✅ Automatically find the latest trained model folder
def get_latest_model_dir(base_dir="chatbot_model"):
    """Return the absolute path of the newest ``trained_model_*`` sub-directory.

    Only directories directly inside *base_dir* whose names start with
    ``trained_model_`` are candidates. "Newest" is decided by natural sort
    order: runs of digits compare by numeric value, so ``trained_model_10``
    beats ``trained_model_9`` (a plain string sort would pick ``_9``), while
    fixed-width names such as ``trained_model_20240315_090000`` order exactly
    as they would under a plain string sort.

    Args:
        base_dir: Directory to scan. Defaults to ``"chatbot_model"``.

    Returns:
        Absolute path of the selected directory.

    Raises:
        FileNotFoundError: If *base_dir* has no matching directory, or (from
            ``os.listdir``) if *base_dir* itself does not exist.
        OSError: If *base_dir* cannot be listed for another reason.
        Every error is printed before being re-raised.
    """
    import re  # only needed for the natural-sort key below

    def _natural_key(name):
        # re.split with a capturing group alternates text (even indices) and
        # digit runs (odd indices), so keys of different names always compare
        # str-to-str and int-to-int. The raw name breaks numeric ties
        # ("_01" vs "_1") the same way a plain string sort would.
        parts = re.split(r"(\d+)", name)
        return [int(p) if i % 2 else p for i, p in enumerate(parts)], name

    try:
        subdirs = [
            d for d in os.listdir(base_dir)
            if d.startswith("trained_model_") and os.path.isdir(os.path.join(base_dir, d))
        ]
        if not subdirs:
            raise FileNotFoundError(f"❌ No trained model found in {base_dir}/ directory.")
        latest = max(subdirs, key=_natural_key)
        return os.path.abspath(os.path.join(base_dir, latest))
    except Exception as e:
        print(f"❌ Error finding model directory: {e}")
        raise
# ✅ Load model from latest folder
def _load_latest_model():
    """Locate the newest trained model folder and load its tokenizer and model.

    Returns:
        ``(model_path, tokenizer, model)``.
    """
    path = get_latest_model_dir()
    print(f"📦 Loading model from: {path}")
    tok = AutoTokenizer.from_pretrained(path)
    mdl = AutoModelForCausalLM.from_pretrained(path)
    return path, tok, mdl


try:
    model_path, tokenizer, model = _load_latest_model()
except Exception as e:
    # Report the failure, then let it propagate so the module import fails.
    print(f"❌ Failed to load model: {e}")
    raise
# ✅ Setup device: use the GPU when CUDA is available, otherwise the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
# Inference only: eval() switches off training-time behavior such as dropout.
model.eval()
# ✅ SQLite DB setup
# Relative path, so the database file is resolved against the process's
# current working directory.
DB_FILE: str = "chat_history.db"
def insert_chat(session_id, user_msg, bot_msg):
    """Append one user/bot exchange to the ``chat`` table.

    The row is stamped with the current local time in ISO-8601 format.
    Best-effort: any error is printed and swallowed, never raised.

    Args:
        session_id: Session the exchange belongs to.
        user_msg: Text the user sent.
        bot_msg: Reply that was returned to the user.
    """
    try:
        # closing() guarantees the connection is released; the inner `with conn`
        # only commits (or rolls back on error). sqlite3's connection context
        # manager does not close the connection by itself.
        with closing(sqlite3.connect(DB_FILE)) as conn:
            with conn:
                conn.execute(
                    "INSERT INTO chat (session_id, timestamp, user, bot) VALUES (?, ?, ?, ?)",
                    (session_id, datetime.datetime.now().isoformat(), user_msg, bot_msg),
                )
        print(f"✅ Inserted chat for session_id={session_id}")
    except Exception as e:
        print(f"❌ Error inserting chat: {e}")
def fetch_history(session_id, limit=100):
    """Return the most recent messages of a session, oldest first.

    Args:
        session_id: Session whose messages to read.
        limit: Maximum number of messages to return. When a session has more,
            the newest *limit* messages are kept, not the oldest ones.

    Returns:
        List of ``(timestamp, user, bot)`` tuples in insertion (``id``) order,
        or ``[]`` if the query fails (the error is printed).
    """
    try:
        with closing(sqlite3.connect(DB_FILE)) as conn:
            # Inner query takes the newest `limit` rows; the outer ORDER BY restores
            # chronological order. A plain "ORDER BY id ASC LIMIT ?" would freeze
            # long conversations at their first `limit` messages.
            rows = conn.execute(
                "SELECT timestamp, user, bot FROM ("
                " SELECT id, timestamp, user, bot FROM chat"
                " WHERE session_id=? ORDER BY id DESC LIMIT ?"
                ") ORDER BY id ASC",
                (session_id, limit),
            ).fetchall()
        print(f"📜 Fetched {len(rows)} messages for session_id={session_id}")
        return rows
    except Exception as e:
        print(f"❌ Error fetching history for session_id={session_id}: {e}")
        return []
def get_all_sessions():
    """Return every stored session as ``(session_id, created_at)`` tuples.

    Rows are ordered by ``created_at`` descending. If the query fails, the
    error is printed and an empty list is returned instead of raising.
    """
    try:
        # closing() releases the connection; sqlite3's own context manager
        # would only manage the transaction, not close the connection.
        with closing(sqlite3.connect(DB_FILE)) as conn:
            sessions = conn.execute(
                "SELECT session_id, created_at FROM sessions ORDER BY created_at DESC"
            ).fetchall()
        print(f"📋 Fetched {len(sessions)} sessions")
        return sessions
    except Exception as e:
        print(f"❌ Error fetching sessions: {e}")
        return []
def create_new_session():
    """Insert a new row into ``sessions`` and return its id.

    Ids have the form ``session_YYYYMMDD_HHMMSS`` (local time). That format has
    only one-second resolution, so two sessions created in the same second would
    produce the same id. When an insert is rejected with an IntegrityError, a
    numeric suffix is appended (``..._1``, ``..._2``, ...) and the insert retried,
    up to a fixed number of attempts.

    Returns:
        The new session id.

    Raises:
        Any database error after printing it, including ``sqlite3.IntegrityError``
        if every attempt was rejected.
    """
    max_attempts = 100
    try:
        now = datetime.datetime.now()
        base_id = now.strftime("session_%Y%m%d_%H%M%S")
        created_at = now.isoformat()
        # closing() releases the connection; `with conn` commits the insert.
        with closing(sqlite3.connect(DB_FILE)) as conn:
            with conn:
                for attempt in range(max_attempts):
                    session_id = base_id if attempt == 0 else f"{base_id}_{attempt}"
                    try:
                        conn.execute(
                            "INSERT INTO sessions (session_id, created_at) VALUES (?, ?)",
                            (session_id, created_at),
                        )
                        break
                    except sqlite3.IntegrityError:
                        # Treated as an id collision; retry with the next suffix.
                        if attempt == max_attempts - 1:
                            raise
        print(f"✅ Created new session: {session_id}")
        return session_id
    except Exception as e:
        print(f"❌ Error creating new session: {e}")
        raise
def delete_session(session_id):
    """Delete a session row and all chat messages stored under it.

    If the deleted session is the one held in the Flask session cookie
    (``session['chat_id']``), that key is removed as well.

    Args:
        session_id: Id of the session to delete.

    Returns:
        True when the deletes and the active-session check completed; False if
        any step raised. Errors are printed, never propagated.
    """
    try:
        # closing() releases the connection; `with conn` commits both DELETEs
        # together (or rolls both back on error). sqlite3's own connection
        # context manager does not close the connection.
        with closing(sqlite3.connect(DB_FILE)) as conn:
            with conn:
                conn.execute("DELETE FROM chat WHERE session_id=?", (session_id,))
                conn.execute("DELETE FROM sessions WHERE session_id=?", (session_id,))
        print(f"🗑️ Deleted session_id={session_id}")
        # Reset session if deleted session is active
        if session.get('chat_id') == session_id:
            session.pop('chat_id', None)
            print(f"🔄 Reset active session_id={session_id}")
    except Exception as e:
        print(f"❌ Error deleting session_id={session_id}: {e}")
        return False
    return True
# ✅ Route for home page
@app.route('/')
def home():
    """Serve the chat UI, first creating a chat session for cookie-less visitors.

    Returns the rendered ``index.html``; on any failure, a JSON error with
    HTTP 500.
    """
    try:
        if 'chat_id' in session:
            return render_template('index.html')
        session['chat_id'] = create_new_session()
        print(f"🏠 Initialized session_id={session['chat_id']} for new user")
        return render_template('index.html')
    except Exception as err:
        print(f"❌ Error in /: {err}")
        return jsonify({"error": "Failed to initialize session"}), 500
# ✅ Create new session and return it
@app.route('/new_session', methods=['GET'])
def new_session():
    """Start a fresh chat session and make it the caller's active one.

    Returns ``{"session_id": ...}``, or ``{"error": ...}`` with HTTP 500.
    """
    try:
        sid = create_new_session()
        session['chat_id'] = sid
        print(f"➕ New session created: {sid}")
        payload = {"session_id": sid}
        return jsonify(payload)
    except Exception as err:
        print(f"❌ Error in /new_session: {err}")
        return jsonify({"error": str(err)}), 500
def _generate_reply(user_input):
    """Return the model's reply to a single user message.

    Builds a one-turn prompt, samples up to 800 new tokens, and returns the
    cleaned text, or a fallback apology if nothing usable was produced.
    Uses the module-level ``tokenizer``, ``model`` and ``device``.
    """
    prompt = f"<|prompter|> {user_input} <|endoftext|><|assistant|>"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=800,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
            repetition_penalty=1.15,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    # For a causal LM, generate() returns prompt + continuation, so slice off the
    # prompt tokens before decoding. Decoding the whole sequence is fragile: with
    # skip_special_tokens=True, <|assistant|> is dropped if the tokenizer treats
    # it as a special token, and the split below would then return the echoed
    # prompt instead of the answer.
    prompt_len = inputs["input_ids"].shape[-1]
    decoded = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    # Cut off any further turn the model starts on its own.
    reply = decoded.split("<|assistant|>")[-1].split("<|prompter|>")[0].strip()
    return reply or "⚠️ Sorry, I couldn't generate a response. Please try again."


# ✅ POST endpoint for chatbot
@app.route('/chat', methods=['POST'])
def chat():
    """Answer a chat message and record the exchange in the active session.

    Expects a JSON object body such as ``{"message": "..."}``. A session is
    created on the fly if the cookie does not carry one.

    Returns:
        ``{"response": reply}`` on success; HTTP 400 with an ``error`` key when
        the body is not a JSON object or ``message`` is missing, empty, or not
        a string; HTTP 500 with an ``error`` key if anything else fails.
    """
    try:
        # silent=True: a non-JSON body or wrong Content-Type yields None (-> 400)
        # instead of raising and being reported as a 500.
        data = request.get_json(silent=True)
        raw_message = data.get("message", "") if isinstance(data, dict) else ""
        if not isinstance(raw_message, str):
            return jsonify({"error": "'message' must be a string"}), 400
        user_input = raw_message.strip()
        if not user_input:
            return jsonify({"error": "Missing or empty 'message' in request"}), 400
        session_id = session.get("chat_id")
        if not session_id:
            session_id = create_new_session()
            session['chat_id'] = session_id
            print(f"🔄 Created new session_id={session_id} for chat")
        reply = _generate_reply(user_input)
        insert_chat(session_id, user_input, reply)
        return jsonify({"response": reply})
    except Exception as e:
        print(f"❌ Error in /chat: {e}")
        return jsonify({"error": str(e)}), 500
# ✅ GET current session chat
@app.route('/history', methods=['GET'])
def history():
    """Return the active session's messages as a JSON list.

    Each item has ``timestamp``, ``user`` and ``bot`` keys. An empty list is
    returned when the cookie carries no session id; errors yield HTTP 500.
    """
    try:
        active_id = session.get("chat_id")
        if active_id:
            messages = [
                {"timestamp": ts, "user": user_text, "bot": bot_text}
                for ts, user_text, bot_text in fetch_history(active_id)
            ]
            return jsonify(messages)
        print("⚠️ No session_id found in session")
        return jsonify([])
    except Exception as err:
        print(f"❌ Error in /history: {err}")
        return jsonify({"error": str(err)}), 500
# ✅ GET chat for specific session
@app.route('/history/<session_id>', methods=['GET'])
def session_history(session_id):
    """Return the stored messages of the session named in the URL.

    NOTE(review): there is no check that *session_id* belongs to the caller,
    so any client that knows or guesses an id can read that conversation.
    """
    try:
        return jsonify([
            {"timestamp": ts, "user": user_text, "bot": bot_text}
            for ts, user_text, bot_text in fetch_history(session_id)
        ])
    except Exception as err:
        print(f"❌ Error in /history/{session_id}: {err}")
        return jsonify({"error": str(err)}), 500
# ✅ GET all sessions
@app.route('/sessions', methods=['GET'])
def list_sessions():
    """List stored sessions as ``{"session_id", "created_at"}`` objects.

    Order follows ``get_all_sessions()``. NOTE(review): every session in the
    database is returned to every caller; there is no per-user filtering.
    """
    try:
        payload = []
        for sid, created in get_all_sessions():
            payload.append({"session_id": sid, "created_at": created})
        return jsonify(payload)
    except Exception as err:
        print(f"❌ Error in /sessions: {err}")
        return jsonify({"error": str(err)}), 500
# ✅ DELETE session
@app.route('/sessions/<session_id>', methods=['DELETE'])
def delete_session_route(session_id):
    """Delete a session and its messages.

    Returns ``{"status": "deleted"}`` on success, or ``{"error": ...}`` with
    HTTP 500 when the deletion is reported as failed or raises.
    """
    try:
        # delete_session() prints its own errors rather than raising, so check its
        # return value: an explicit False means failure; anything else (including
        # None) is treated as success.
        if delete_session(session_id) is False:
            return jsonify({"error": f"Failed to delete session {session_id}"}), 500
        return jsonify({"status": "deleted"})
    except Exception as e:
        print(f"❌ Error in /sessions/{session_id}: {e}")
        return jsonify({"error": str(e)}), 500
# ✅ Initialize database and run server
def init_db(db_file=None):
    """Create the ``chat`` and ``sessions`` tables if they do not exist yet.

    Safe to call repeatedly (``CREATE TABLE IF NOT EXISTS``).

    Args:
        db_file: SQLite file to initialize; defaults to the module's ``DB_FILE``.

    Raises:
        Any error from SQLite, after printing it, so startup fails loudly when
        the database cannot be prepared.
    """
    target = DB_FILE if db_file is None else db_file
    try:
        # closing() releases the connection; `with conn` commits the DDL.
        with closing(sqlite3.connect(target)) as conn:
            with conn:
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS chat (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        session_id TEXT NOT NULL,
                        timestamp TEXT NOT NULL,
                        user TEXT,
                        bot TEXT
                    )
                ''')
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS sessions (
                        session_id TEXT PRIMARY KEY,
                        created_at TEXT NOT NULL
                    )
                ''')
        print("✅ Database initialized successfully")
    except Exception as e:
        print(f"❌ Error initializing database: {e}")
        raise


# Ensure the schema exists at import time, not only when this file is executed
# directly, so the app also works under `flask run` or a WSGI server.
init_db()

if __name__ == '__main__':
    # NOTE(review): debug=True enables the Werkzeug interactive debugger and the
    # auto-reloader (which re-executes this module, including model loading, in a
    # child process). Do not expose this server publicly as-is.
    app.run(debug=True, port=5005)