# Jasleen05's picture
# Update app.py
# 8f998ed verified
import datetime
import os
import re
import secrets
import sqlite3
from contextlib import closing

import torch
from dotenv import load_dotenv
from flask import Flask, request, jsonify, render_template, session
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModelForCausalLM
# ✅ Pull variables from a local .env file (if present) into os.environ
load_dotenv()
# ✅ Create the Flask app; assets and Jinja templates come from the static/ and
# templates/ folders resolved relative to this module's directory
app = Flask(__name__, template_folder='templates', static_folder='static')
# Prefer SECRET_KEY from the environment. The fallback is 24 random bytes drawn at
# import time, so signed session cookies stop validating after a restart (and could
# differ between separately started worker processes) unless SECRET_KEY is set.
app.config['SECRET_KEY'] = os.getenv('SECRET_KEY', os.urandom(24))
CORS(app)
# ✅ Automatically find the latest trained model folder
def get_latest_model_dir(base_dir="chatbot_model"):
    """Return the absolute path of the newest ``trained_model_*`` folder in *base_dir*.

    Candidates are direct subdirectories whose names start with ``trained_model_``.
    They are ranked with a natural sort (runs of digits compare as integers), so
    ``trained_model_10`` ranks after ``trained_model_9``. For fixed-width numeric
    suffixes (e.g. zero-padded timestamps) this agrees with plain string order.

    Args:
        base_dir: Directory to scan. Defaults to ``"chatbot_model"``.

    Returns:
        Absolute path of the selected model directory.

    Raises:
        FileNotFoundError: if *base_dir* contains no matching subdirectory.
        OSError: if *base_dir* cannot be listed (e.g. it does not exist).
    """
    def _natural_key(name):
        # re.split with a capturing group alternates text / digit-run pieces, so
        # every odd index is a digit run; keys are therefore always comparable.
        pieces = re.split(r"([0-9]+)", name)
        return [int(piece) if idx % 2 else piece for idx, piece in enumerate(pieces)]

    try:
        subdirs = [
            d for d in os.listdir(base_dir)
            if d.startswith("trained_model_") and os.path.isdir(os.path.join(base_dir, d))
        ]
        if not subdirs:
            raise FileNotFoundError(f"❌ No trained model found in {base_dir}/ directory.")
        # Tiebreak on the raw name so e.g. "_01" vs "_1" still resolves deterministically.
        latest = max(subdirs, key=lambda name: (_natural_key(name), name))
        return os.path.abspath(os.path.join(base_dir, latest))
    except Exception as e:
        print(f"❌ Error finding model directory: {e}")
        raise
# ✅ Load the tokenizer and model from the newest trained_model_* checkpoint
try:
    model_path = get_latest_model_dir()
    print(f"📦 Loading model from: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)
except Exception as load_error:
    # Fatal by design: nothing below can work without a tokenizer and model.
    print(f"❌ Failed to load model: {load_error}")
    raise
# ✅ Use the GPU when CUDA is available, otherwise the CPU, and put the model in
# evaluation mode (affects layers such as dropout) since it is only used for inference.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
model.to(device)
model.eval()
# ✅ SQLite DB setup
# Path of the SQLite file holding the ``chat`` and ``sessions`` tables. Defaults to
# "chat_history.db" (relative to the current working directory); set CHAT_DB_FILE
# to place it elsewhere, e.g. on a persistent volume. An empty value is ignored.
DB_FILE = os.getenv("CHAT_DB_FILE") or "chat_history.db"
def insert_chat(session_id, user_msg, bot_msg, db_file=None):
    """Append one user/bot exchange to the ``chat`` table.

    The row is stamped with the local time via ``datetime.now().isoformat()``.
    Failures are logged and swallowed; the caller is not notified.

    Args:
        session_id: Session the exchange belongs to.
        user_msg: The user's message text.
        bot_msg: The bot's reply text.
        db_file: SQLite file to write to; ``None`` (default) means ``DB_FILE``.
    """
    db_path = DB_FILE if db_file is None else db_file
    try:
        # closing() actually closes the connection; ``with conn`` alone only
        # commits (or rolls back) and leaves it open.
        with closing(sqlite3.connect(db_path)) as conn:
            with conn:
                conn.execute(
                    "INSERT INTO chat (session_id, timestamp, user, bot) VALUES (?, ?, ?, ?)",
                    (session_id, datetime.datetime.now().isoformat(), user_msg, bot_msg),
                )
        print(f"✅ Inserted chat for session_id={session_id}")
    except Exception as e:
        print(f"❌ Error inserting chat: {e}")
def fetch_history(session_id, limit=100, db_file=None):
    """Return the most recent messages of *session_id*, oldest first.

    Args:
        session_id: Session whose ``chat`` rows to read.
        limit: Maximum number of rows to return (the newest ones are kept).
            ``None`` means no limit.
        db_file: SQLite file to read; ``None`` (default) means ``DB_FILE``.

    Returns:
        A list of ``(timestamp, user, bot)`` tuples in insertion order, or ``[]``
        if the query fails (the error is logged).
    """
    db_path = DB_FILE if db_file is None else db_file
    # SQLite treats a negative LIMIT as "no upper bound".
    sql_limit = -1 if limit is None else limit
    try:
        with closing(sqlite3.connect(db_path)) as conn:
            # Pick the newest `limit` rows in the subquery, then restore
            # chronological order. A plain ASC + LIMIT would keep the oldest rows
            # and hide new turns once a chat grows past the limit.
            rows = conn.execute(
                "SELECT timestamp, user, bot FROM ("
                " SELECT id, timestamp, user, bot FROM chat"
                " WHERE session_id=? ORDER BY id DESC LIMIT ?"
                ") ORDER BY id ASC",
                (session_id, sql_limit),
            ).fetchall()
        print(f"📜 Fetched {len(rows)} messages for session_id={session_id}")
        return rows
    except Exception as e:
        print(f"❌ Error fetching history for session_id={session_id}: {e}")
        return []
def get_all_sessions(db_file=None):
    """Return every ``(session_id, created_at)`` row from ``sessions``, newest first.

    Ordering is by the ``created_at`` text. That text is chronological for the
    ISO-8601 strings written by ``create_new_session``.

    Args:
        db_file: SQLite file to read; ``None`` (default) means ``DB_FILE``.

    Returns:
        A list of tuples, or ``[]`` if the query fails (the error is logged).
    """
    db_path = DB_FILE if db_file is None else db_file
    try:
        # closing() guarantees the connection is released even on error.
        with closing(sqlite3.connect(db_path)) as conn:
            sessions = conn.execute(
                "SELECT session_id, created_at FROM sessions ORDER BY created_at DESC"
            ).fetchall()
        print(f"📋 Fetched {len(sessions)} sessions")
        return sessions
    except Exception as e:
        print(f"❌ Error fetching sessions: {e}")
        return []
def create_new_session(db_file=None):
    """Insert a new row into ``sessions`` and return its generated id.

    Ids have the form ``session_YYYYmmdd_HHMMSS_<8 hex chars>``. The timestamp
    prefix keeps them human-readable. The random suffix stops two sessions created
    in the same second from colliding on the PRIMARY KEY, and makes ids hard to guess.

    Args:
        db_file: SQLite file to write to; ``None`` (default) means ``DB_FILE``.

    Returns:
        The new session id.

    Raises:
        Exception: any database error is logged and re-raised.
    """
    db_path = DB_FILE if db_file is None else db_file
    try:
        # Compute the time once so the id and created_at describe the same instant.
        now = datetime.datetime.now()
        session_id = now.strftime("session_%Y%m%d_%H%M%S") + "_" + secrets.token_hex(4)
        with closing(sqlite3.connect(db_path)) as conn:
            with conn:  # commit on success, roll back on error
                conn.execute(
                    "INSERT INTO sessions (session_id, created_at) VALUES (?, ?)",
                    (session_id, now.isoformat()),
                )
        print(f"✅ Created new session: {session_id}")
        return session_id
    except Exception as e:
        print(f"❌ Error creating new session: {e}")
        raise
def delete_session(session_id, db_file=None):
    """Delete *session_id*'s chat rows and its ``sessions`` record together.

    If the deleted session is the caller's active one (Flask ``session['chat_id']``),
    that key is removed as well. This touches the Flask session, so it needs an
    active request context.

    Args:
        session_id: Session to delete. Unknown ids are a no-op at the DB level.
        db_file: SQLite file to modify; ``None`` (default) means ``DB_FILE``.

    Raises:
        Exception: any error is logged and re-raised, so callers can report the
        failure instead of a false "deleted" status.
    """
    db_path = DB_FILE if db_file is None else db_file
    try:
        with closing(sqlite3.connect(db_path)) as conn:
            # One transaction: both DELETEs commit together or roll back together.
            with conn:
                conn.execute("DELETE FROM chat WHERE session_id=?", (session_id,))
                conn.execute("DELETE FROM sessions WHERE session_id=?", (session_id,))
        print(f"🗑️ Deleted session_id={session_id}")
        # Reset session if deleted session is active
        if session.get('chat_id') == session_id:
            session.pop('chat_id', None)
            print(f"🔄 Reset active session_id={session_id}")
    except Exception as e:
        print(f"❌ Error deleting session_id={session_id}: {e}")
        raise
# ✅ Route for home page
@app.route('/')
def home():
    """Render index.html, first creating and storing a chat session if the visitor has none.

    A failure while creating the session yields a JSON error with status 500.
    """
    try:
        needs_session = 'chat_id' not in session
        if needs_session:
            new_id = create_new_session()
            session['chat_id'] = new_id
            print(f"🏠 Initialized session_id={new_id} for new user")
        return render_template('index.html')
    except Exception as err:
        print(f"❌ Error in /: {err}")
        return jsonify({"error": "Failed to initialize session"}), 500
# ✅ Create new session and return it
@app.route('/new_session', methods=['GET'])
def new_session():
    """Create a session, make it the caller's active one, and return ``{"session_id": ...}``.

    NOTE(review): this GET request changes state (it inserts a DB row); consider
    also accepting POST if the frontend can be updated.
    """
    try:
        created_id = create_new_session()
        session['chat_id'] = created_id
        print(f"➕ New session created: {created_id}")
        payload = {"session_id": created_id}
        return jsonify(payload)
    except Exception as err:
        print(f"❌ Error in /new_session: {err}")
        return jsonify({"error": str(err)}), 500
def _generate_reply(user_input):
    """Generate the assistant's answer to one user message.

    Builds the ``<|prompter|> ... <|endoftext|><|assistant|>`` prompt and samples up
    to 800 new tokens from the loaded causal LM. Returns only the newly generated
    text, cut at any ``<|prompter|>`` marker and stripped. Returns ``""`` when
    nothing usable was produced.
    """
    prompt = f"<|prompter|> {user_input} <|endoftext|><|assistant|>"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        output = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=800,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
            repetition_penalty=1.15,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # For a decoder-only model, generate() returns prompt + continuation. Decode only
    # the continuation. Splitting the full text on "<|assistant|>" fails when the role
    # markers are special tokens, because skip_special_tokens=True strips them and
    # the user's prompt would be echoed back in the reply.
    generated = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    # If the marker survives decoding as plain text, whatever follows it is a
    # model-invented next user turn; keep only the first assistant answer.
    return generated.split("<|prompter|>")[0].strip()


# ✅ POST endpoint for chatbot
@app.route('/chat', methods=['POST'])
def chat():
    """Answer ``{"message": <str>}`` with ``{"response": <reply>}`` and store the exchange.

    Returns 400 when the body is not a JSON object with a non-empty string
    ``message``, and 500 with the error text for any other failure. Creates a
    session first if the caller has no active one.
    """
    try:
        # silent=True yields None for a missing or invalid JSON body (instead of
        # raising), so malformed requests reach the 400 below rather than a 500.
        data = request.get_json(silent=True)
        message = data.get("message", "") if isinstance(data, dict) else ""
        user_input = message.strip() if isinstance(message, str) else ""
        if not user_input:
            return jsonify({"error": "Missing or empty 'message' in request"}), 400
        session_id = session.get("chat_id")
        if not session_id:
            session_id = create_new_session()
            session['chat_id'] = session_id
            print(f"🔄 Created new session_id={session_id} for chat")
        reply = _generate_reply(user_input)
        if not reply:
            reply = "⚠️ Sorry, I couldn't generate a response. Please try again."
        insert_chat(session_id, user_input, reply)
        return jsonify({"response": reply})
    except Exception as e:
        print(f"❌ Error in /chat: {e}")
        return jsonify({"error": str(e)}), 500
# ✅ GET current session chat
@app.route('/history', methods=['GET'])
def history():
    """Return the active session's messages as a JSON list of {timestamp, user, bot}.

    An empty list is returned when the caller has no active session.
    """
    try:
        active_id = session.get("chat_id")
        if active_id:
            messages = [
                {"timestamp": ts, "user": user_text, "bot": bot_text}
                for ts, user_text, bot_text in fetch_history(active_id)
            ]
            return jsonify(messages)
        print("⚠️ No session_id found in session")
        return jsonify([])
    except Exception as err:
        print(f"❌ Error in /history: {err}")
        return jsonify({"error": str(err)}), 500
# ✅ GET chat for specific session
@app.route('/history/<session_id>', methods=['GET'])
def session_history(session_id):
    """Return the messages of the session named in the URL as JSON {timestamp, user, bot} items.

    NOTE(review): nothing here ties *session_id* to the caller's own cookie
    session, so any client that knows an id can read that conversation.
    """
    try:
        payload = []
        for ts, user_text, bot_text in fetch_history(session_id):
            payload.append({"timestamp": ts, "user": user_text, "bot": bot_text})
        return jsonify(payload)
    except Exception as err:
        print(f"❌ Error in /history/{session_id}: {err}")
        return jsonify({"error": str(err)}), 500
# ✅ GET all sessions
@app.route('/sessions', methods=['GET'])
def list_sessions():
    """Return every stored session as JSON [{session_id, created_at}, ...] in get_all_sessions order."""
    try:
        listing = []
        for sid, created in get_all_sessions():
            listing.append({"session_id": sid, "created_at": created})
        return jsonify(listing)
    except Exception as err:
        print(f"❌ Error in /sessions: {err}")
        return jsonify({"error": str(err)}), 500
# ✅ DELETE session
@app.route('/sessions/<session_id>', methods=['DELETE'])
def delete_session_route(session_id):
    """Call delete_session for the id in the URL and reply {"status": "deleted"}.

    Any exception raised by delete_session becomes a 500 with the error text.
    """
    try:
        delete_session(session_id)
        outcome = {"status": "deleted"}
        return jsonify(outcome)
    except Exception as err:
        print(f"❌ Error in /sessions/{session_id}: {err}")
        return jsonify({"error": str(err)}), 500
# ✅ Initialize database and run server
def init_db(db_file=None):
    """Create the ``chat`` and ``sessions`` tables if they do not already exist.

    Safe to call repeatedly (``CREATE TABLE IF NOT EXISTS``).

    Args:
        db_file: SQLite file to initialize; ``None`` (default) means ``DB_FILE``.

    Raises:
        Exception: any database error is logged and re-raised.
    """
    db_path = DB_FILE if db_file is None else db_file
    try:
        with closing(sqlite3.connect(db_path)) as conn:
            with conn:  # commit on success, roll back on error
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS chat (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        session_id TEXT NOT NULL,
                        timestamp TEXT NOT NULL,
                        user TEXT,
                        bot TEXT
                    )
                ''')
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS sessions (
                        session_id TEXT PRIMARY KEY,
                        created_at TEXT NOT NULL
                    )
                ''')
        print("✅ Database initialized successfully")
    except Exception as e:
        print(f"❌ Error initializing database: {e}")
        raise


# Run at import time as well: `flask run` and WSGI servers never execute the
# __main__ branch below, and without this every DB helper would hit "no such table".
init_db()

if __name__ == '__main__':
    # NOTE(review): debug=True enables the interactive Werkzeug debugger and the
    # reloader. The reloader re-runs this module in a child process, so the model is
    # loaded twice. Do not expose this entry point publicly.
    app.run(debug=True, port=5005)