|
from flask import Flask, request, jsonify, render_template, session
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
import torch
|
|
from flask_cors import CORS
|
|
import os
|
|
import sqlite3
|
|
import datetime
|
|
from dotenv import load_dotenv
|
|
|
|
|
|
# Pull variables from a local .env file into the process environment before
# any configuration value is read.
load_dotenv()

# Flask application serving static assets from ./static and templates from
# ./templates.
app = Flask(__name__, template_folder='templates', static_folder='static')

# Key used to sign the session cookie. When SECRET_KEY is not set, a random
# key is generated for this process.
app.secret_key = os.environ.get('SECRET_KEY', os.urandom(24))

# Enable CORS on the app with flask_cors defaults.
CORS(app)
|
|
|
|
|
|
def get_latest_model_dir(base_dir="chatbot_model"):
    """Return the absolute path of the newest ``trained_model_*`` directory.

    Candidates are sub-directories of ``base_dir`` whose name starts with
    ``trained_model_``. "Newest" is decided by name using a natural sort:
    digit runs compare as integers, so ``trained_model_10`` ranks after
    ``trained_model_9``. Fixed-width timestamp suffixes order exactly as they
    would under a plain string sort.

    Args:
        base_dir: Directory that contains the trained model folders.

    Returns:
        Absolute path of the selected model directory.

    Raises:
        FileNotFoundError: If ``base_dir`` holds no matching directory.
        OSError: If ``base_dir`` itself cannot be listed.
    """
    import re  # local import: only needed for the natural-sort key

    def natural_key(name):
        # re.split with a capturing group always yields text, digits, text, ...
        # so keys of different names stay type-aligned position by position.
        return [int(part) if part.isdigit() else part
                for part in re.split(r"(\d+)", name)]

    try:
        subdirs = [
            d for d in os.listdir(base_dir)
            if d.startswith("trained_model_") and os.path.isdir(os.path.join(base_dir, d))
        ]
        if not subdirs:
            # Report the directory that was actually searched, not a fixed name.
            raise FileNotFoundError(f"❌ No trained model found in {base_dir}/ directory.")
        latest = max(subdirs, key=natural_key)
        return os.path.abspath(os.path.join(base_dir, latest))
    except Exception as e:
        print(f"❌ Error finding model directory: {e}")
        raise
|
|
|
|
|
|
# Load the tokenizer and causal LM from the newest trained model directory.
# Any failure is logged and re-raised, so the module does not finish importing
# without a model.
try:
    model_path = get_latest_model_dir()
    print(f"📦 Loading model from: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)
except Exception as e:
    print(f"❌ Failed to load model: {e}")
    raise

# Use the GPU when torch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model.to(device)
# Switch the model to evaluation mode.
model.eval()
|
|
|
|
|
|
# Path of the SQLite database file passed to sqlite3.connect() by the
# persistence helpers and the startup schema creation in this file.
DB_FILE = "chat_history.db"
|
|
|
|
def insert_chat(session_id, user_msg, bot_msg):
    """Store one user/bot exchange in the ``chat`` table.

    The row is stamped with the current local time in ISO-8601 form. This is
    a best-effort write: any error is printed and swallowed so the caller can
    carry on, and nothing is returned.

    Args:
        session_id: Identifier of the chat session the exchange belongs to.
        user_msg: Text sent by the user.
        bot_msg: Reply produced by the bot.
    """
    try:
        conn = sqlite3.connect(DB_FILE)
        try:
            # `with conn` only scopes the transaction (commit on success,
            # rollback on error); it does not close the connection.
            with conn:
                conn.execute(
                    "INSERT INTO chat (session_id, timestamp, user, bot) VALUES (?, ?, ?, ?)",
                    (session_id, datetime.datetime.now().isoformat(), user_msg, bot_msg),
                )
        finally:
            # Release the database handle explicitly on every path.
            conn.close()
        print(f"✅ Inserted chat for session_id={session_id}")
    except Exception as e:
        print(f"❌ Error inserting chat: {e}")
|
|
|
|
def fetch_history(session_id, limit=100):
    """Return stored exchanges for one session, in insertion order.

    Rows are ordered by ascending ``id`` and capped at ``limit``, so the
    result is the first ``limit`` exchanges of the session.

    Args:
        session_id: Identifier of the chat session to read.
        limit: Maximum number of rows to return.

    Returns:
        A list of ``(timestamp, user, bot)`` tuples. An empty list is returned
        when the session has no rows or when the query fails (the error is
        printed, not raised).
    """
    try:
        conn = sqlite3.connect(DB_FILE)
        try:
            rows = conn.execute(
                "SELECT timestamp, user, bot FROM chat WHERE session_id=? ORDER BY id ASC LIMIT ?",
                (session_id, limit),
            ).fetchall()
        finally:
            # The sqlite3 connection context manager would not close the
            # handle, so close it explicitly on every path.
            conn.close()
        print(f"📜 Fetched {len(rows)} messages for session_id={session_id}")
        return rows
    except Exception as e:
        print(f"❌ Error fetching history for session_id={session_id}: {e}")
        return []
|
|
|
|
def get_all_sessions():
    """Return every stored session, newest first.

    Returns:
        A list of ``(session_id, created_at)`` tuples ordered by
        ``created_at`` descending. An empty list is returned when there are
        no sessions or when the query fails (the error is printed, not
        raised).
    """
    try:
        conn = sqlite3.connect(DB_FILE)
        try:
            sessions = conn.execute(
                "SELECT session_id, created_at FROM sessions ORDER BY created_at DESC"
            ).fetchall()
        finally:
            # The sqlite3 connection context manager would not close the
            # handle, so close it explicitly on every path.
            conn.close()
        print(f"📋 Fetched {len(sessions)} sessions")
        return sessions
    except Exception as e:
        print(f"❌ Error fetching sessions: {e}")
        return []
|
|
|
|
def create_new_session():
    """Create a new chat session row and return its identifier.

    The identifier has the form ``session_YYYYMMDD_HHMMSS_<8 hex chars>``.
    The random suffix keeps identifiers unique when several sessions are
    created within the same second; a timestamp-only id would collide on the
    ``sessions`` primary key in that case.

    Returns:
        The identifier of the newly inserted session.

    Raises:
        Exception: Any database error is printed and re-raised.
    """
    import uuid  # local import: only needed to build the unique suffix

    try:
        # Read the clock once so the id and created_at describe the same instant.
        now = datetime.datetime.now()
        session_id = f"{now.strftime('session_%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
        conn = sqlite3.connect(DB_FILE)
        try:
            # `with conn` commits or rolls back; it does not close the connection.
            with conn:
                conn.execute(
                    "INSERT INTO sessions (session_id, created_at) VALUES (?, ?)",
                    (session_id, now.isoformat()),
                )
        finally:
            conn.close()
        print(f"✅ Created new session: {session_id}")
        return session_id
    except Exception as e:
        print(f"❌ Error creating new session: {e}")
        raise
|
|
|
|
def delete_session(session_id):
    """Delete a session and all of its chat rows.

    Both deletes run in one transaction. If the deleted session is the one
    stored under ``chat_id`` in the Flask session, that key is removed too.

    Args:
        session_id: Identifier of the session to remove.

    Raises:
        Exception: Any failure is printed and re-raised, so the caller can
            report the error instead of claiming the delete succeeded.
    """
    try:
        conn = sqlite3.connect(DB_FILE)
        try:
            # `with conn` commits both deletes together or rolls both back;
            # it does not close the connection.
            with conn:
                conn.execute("DELETE FROM chat WHERE session_id=?", (session_id,))
                conn.execute("DELETE FROM sessions WHERE session_id=?", (session_id,))
        finally:
            conn.close()
        print(f"🗑️ Deleted session_id={session_id}")

        if session.get('chat_id') == session_id:
            session.pop('chat_id', None)
            print(f"🔄 Reset active session_id={session_id}")
    except Exception as e:
        print(f"❌ Error deleting session_id={session_id}: {e}")
        # Propagate the failure: swallowing it would make the caller report
        # success for a delete that did not happen.
        raise
|
|
|
|
|
|
@app.route('/')
def home():
    """Render the chat page, creating a chat session first if none is active.

    Returns:
        The rendered ``index.html`` template, or a JSON error body with
        status 500 when session creation or rendering fails.
    """
    try:
        has_active_session = 'chat_id' in session
        if not has_active_session:
            session['chat_id'] = create_new_session()
            print(f"🏠 Initialized session_id={session['chat_id']} for new user")
        page = render_template('index.html')
    except Exception as e:
        print(f"❌ Error in /: {e}")
        return jsonify({"error": "Failed to initialize session"}), 500
    return page
|
|
|
|
|
|
@app.route('/new_session', methods=['GET'])
def new_session():
    """Create a new chat session and make it the active one.

    Returns:
        JSON ``{"session_id": ...}`` on success, or JSON ``{"error": ...}``
        with status 500 when creation fails.
    """
    try:
        created_id = create_new_session()
        session['chat_id'] = created_id
        print(f"➕ New session created: {created_id}")
        payload = jsonify({"session_id": created_id})
    except Exception as e:
        print(f"❌ Error in /new_session: {e}")
        return jsonify({"error": str(e)}), 500
    return payload
|
|
|
|
|
|
@app.route('/chat', methods=['POST'])
def chat():
    """Generate a bot reply for the posted message and store the exchange.

    Expects a JSON object with a non-empty string under ``"message"``. If no
    chat session is active, one is created and stored under ``chat_id``.

    Returns:
        JSON ``{"response": ...}`` on success.
        JSON ``{"error": ...}`` with status 400 when the body is not a JSON
        object or ``"message"`` is missing, empty or not a string.
        JSON ``{"error": ...}`` with status 500 on any other failure.
    """
    try:
        # silent=True makes an unparsable or non-JSON body come back as None
        # instead of raising, so bad input is answered with 400 rather than
        # falling into the generic 500 handler below.
        data = request.get_json(silent=True)
        message = data.get("message") if isinstance(data, dict) else None
        user_input = message.strip() if isinstance(message, str) else ""

        if not user_input:
            return jsonify({"error": "Missing or empty 'message' in request"}), 400

        session_id = session.get("chat_id")
        if not session_id:
            session_id = create_new_session()
            session['chat_id'] = session_id
            print(f"🔄 Created new session_id={session_id} for chat")

        prompt = f"<|prompter|> {user_input} <|endoftext|><|assistant|>"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        prompt_length = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            output = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=800,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.7,
                repetition_penalty=1.15,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # For a causal LM, generate() returns the prompt tokens followed by
        # the continuation. Decode only the continuation: if the role markers
        # are registered as special tokens, skip_special_tokens strips them
        # and the marker-based split below could no longer cut the user's own
        # text out of the reply.
        generated_ids = output[0][prompt_length:]
        decoded = tokenizer.decode(generated_ids, skip_special_tokens=True)
        # Still trim on the markers in case they survive decoding as plain text.
        reply = decoded.split("<|assistant|>")[-1].split("<|prompter|>")[0].strip()

        if not reply:
            reply = "⚠️ Sorry, I couldn't generate a response. Please try again."

        insert_chat(session_id, user_input, reply)

        return jsonify({"response": reply})

    except Exception as e:
        print(f"❌ Error in /chat: {e}")
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
@app.route('/history', methods=['GET'])
def history():
    """Return the stored exchanges of the currently active chat session.

    Returns:
        A JSON list of ``{"timestamp", "user", "bot"}`` objects. The list is
        empty when no session is active. On failure, JSON ``{"error": ...}``
        with status 500.
    """
    try:
        active_id = session.get("chat_id")
        if active_id:
            entries = [
                {"timestamp": stamp, "user": user_text, "bot": bot_text}
                for stamp, user_text, bot_text in fetch_history(active_id)
            ]
            return jsonify(entries)

        print("⚠️ No session_id found in session")
        return jsonify([])
    except Exception as e:
        print(f"❌ Error in /history: {e}")
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
@app.route('/history/<session_id>', methods=['GET'])
def session_history(session_id):
    """Return the stored exchanges of the session named in the URL.

    Args:
        session_id: Session identifier taken from the URL path.

    Returns:
        A JSON list of ``{"timestamp", "user", "bot"}`` objects, or JSON
        ``{"error": ...}`` with status 500 on failure.
    """
    try:
        entries = []
        for stamp, user_text, bot_text in fetch_history(session_id):
            entries.append({"timestamp": stamp, "user": user_text, "bot": bot_text})
        return jsonify(entries)
    except Exception as e:
        print(f"❌ Error in /history/{session_id}: {e}")
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
@app.route('/sessions', methods=['GET'])
def list_sessions():
    """Return every stored session as JSON.

    Returns:
        A JSON list of ``{"session_id", "created_at"}`` objects in the order
        produced by ``get_all_sessions()``, or JSON ``{"error": ...}`` with
        status 500 on failure.
    """
    try:
        listing = []
        for sid, created in get_all_sessions():
            listing.append({"session_id": sid, "created_at": created})
        return jsonify(listing)
    except Exception as e:
        print(f"❌ Error in /sessions: {e}")
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
@app.route('/sessions/<session_id>', methods=['DELETE'])
def delete_session_route(session_id):
    """Delete the session named in the URL by calling ``delete_session``.

    Args:
        session_id: Session identifier taken from the URL path.

    Returns:
        JSON ``{"status": "deleted"}`` when ``delete_session`` returns without
        raising, otherwise JSON ``{"error": ...}`` with status 500.
    """
    try:
        delete_session(session_id)
        outcome = {"status": "deleted"}
        return jsonify(outcome)
    except Exception as e:
        print(f"❌ Error in /sessions/{session_id}: {e}")
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
if __name__ == '__main__':
    # Create the two tables the persistence helpers rely on, if missing.
    # NOTE(review): this only runs when the file is executed as a script; if
    # the app is ever served through a WSGI server importing `app`, the
    # schema is not created here - confirm how the app is deployed.
    try:
        conn = sqlite3.connect(DB_FILE)
        try:
            # `with conn` commits both statements or rolls them back; it does
            # not close the connection, hence the explicit close below.
            with conn:
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS chat (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        session_id TEXT NOT NULL,
                        timestamp TEXT NOT NULL,
                        user TEXT,
                        bot TEXT
                    )
                ''')
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS sessions (
                        session_id TEXT PRIMARY KEY,
                        created_at TEXT NOT NULL
                    )
                ''')
        finally:
            conn.close()
        print("✅ Database initialized successfully")
    except Exception as e:
        print(f"❌ Error initializing database: {e}")
        raise

    # Debug mode enables the interactive Werkzeug debugger, which lets anyone
    # who can reach the server execute code, so it is opt-in through the
    # FLASK_DEBUG environment variable instead of being always on.
    debug_enabled = os.getenv("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
    app.run(debug=debug_enabled, port=5005)