"""Excel-backed RAG chatbot: a FastAPI service that answers questions over a
spreadsheet using a Llama 2 chat model, FAISS retrieval, and LangChain."""

import asyncio
import os

import pandas as pd
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from huggingface_hub import login
from langchain.chains import RetrievalQA
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import DataFrameLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Authenticate with the Hugging Face Hub; required for gated models such as
# meta-llama/Llama-2-7b-chat-hf.
login(token=os.getenv("HUGGINGFACE_HUB_TOKEN"))
os.environ["LANGCHAIN_TRACING_V2"] = "true" |
|
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com" |
|
os.environ["LANGCHAIN_API_KEY"] = "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" |
|
os.environ["LANGCHAIN_PROJECT"] = "yotta-vm-chatbot" |
|
|
|
|
|


def load_llama2_chat_model():
    """Load the Llama 2 chat model and wrap it in a LangChain-compatible LLM."""
    model_name = "meta-llama/Llama-2-7b-chat-hf"
    print("Loading Llama 2 Chat model...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    pipeline_model = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,  # generation budget per answer; tune as needed
    )
    return HuggingFacePipeline(pipeline=pipeline_model)


llama_model = load_llama2_chat_model()


def fetch_excel_data(file_path):
    """
    Fetch data from a local Excel file and prepare documents for the vector
    store. DataFrameLoader uses the "Description" column as page content and
    carries every remaining column (including "Title") into document metadata.
    """
    print("Loading data from Excel file...")
    df = pd.read_excel(file_path)
    loader = DataFrameLoader(df, page_content_column="Description")
    documents = loader.load()
    return documents


def update_vector_store(file_path):
    """
    Load data from Excel and build a fresh FAISS vector store.
    """
    documents = fetch_excel_data(file_path)
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vector_store = FAISS.from_documents(documents, embeddings)
    return vector_store


excel_file_path = "certificate_details_chatbot_2.xlsx"
vector_store = update_vector_store(excel_file_path)
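
# Rebuilding embeddings for the whole sheet on every refresh can be slow for
# large files. If that becomes a problem, LangChain's FAISS wrapper can
# persist the index between runs via vector_store.save_local(path) and
# FAISS.load_local(path, embeddings); a suggestion, not part of the
# original flow.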


retriever = vector_store.as_retriever()
# output_key is required here: the chain returns both "result" and
# "source_documents", and without it ConversationBufferMemory cannot tell
# which output to store and raises a ValueError.
memory = ConversationBufferMemory(input_key="query", output_key="result")

qa_chain = RetrievalQA.from_chain_type(
    llm=llama_model,
    retriever=retriever,
    memory=memory,
    return_source_documents=True,
    chain_type_kwargs={
        "prompt": PromptTemplate(
            input_variables=["context", "question"],
            template=(
                "Use the following context to answer the question:\n"
                "{context}\n\nQuestion: {question}\nAnswer:"
            ),
        )
    },
)


async def periodic_sync(interval: int = 3600):
    """
    Periodically reload the Excel file and swap the refreshed index into the
    live retriever so the running chain serves the latest data.
    """
    global vector_store
    while True:
        try:
            vector_store = update_vector_store(excel_file_path)
            retriever.vectorstore = vector_store  # hot-swap the new index
            print("Vector store updated with the latest Excel data.")
        except Exception as e:
            print(f"Error updating vector store: {e}")
        await asyncio.sleep(interval)


app = FastAPI()


class QueryRequest(BaseModel):
    query: str


@app.get("/")
def root():
    return {"message": "Welcome to the Excel-based Chatbot with RAG and Llama Integration!"}
@app.post("/query") |
|
async def query(request: QueryRequest): |
|
try: |
|
response = qa_chain({"query": request.query}) |
|
return { |
|
"answer": response['result'], |
|
"source_documents": [ |
|
{"page_content": doc.page_content, "metadata": doc.metadata} |
|
for doc in response["source_documents"] |
|
] |
|
} |
|
except Exception as e: |
|
raise HTTPException(status_code=500, detail=f"Error processing the query: {str(e)}") |
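
# Example request (illustrative; assumes the server is running locally on
# port 8000, and the placeholder question should match your spreadsheet):
#   curl -X POST http://localhost:8000/query \
#        -H "Content-Type: application/json" \
#        -d '{"query": "Which certificates are listed?"}'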


@app.post("/token")
async def token():
    # Placeholder only: returns a static token. Replace with real credential
    # validation and token issuance before exposing this service.
    return {"access_token": "secure_token_123", "token_type": "bearer"}


@app.on_event("startup")
async def start_background_tasks():
    asyncio.create_task(periodic_sync())
@app.get("/interface") |
|
def interface(): |
|
"""Return a simple HTML interface for interacting with the chatbot.""" |
|
return { |
|
"html": """ |
|
<html> |
|
<head><title>Chatbot Interface</title></head> |
|
<body> |
|
<h1>Chat with the Bot</h1> |
|
<form method="post" action="/query"> |
|
<label for="query">Enter your query:</label><br> |
|
<input type="text" id="query" name="query"/><br><br> |
|
<button type="submit">Submit</button> |
|
</form> |
|
</body> |
|
</html> |
|
""" |
|
} |
|
|
|


if __name__ == "__main__":
    import uvicorn

    print("Starting the chatbot server...")
    uvicorn.run(app, host="0.0.0.0", port=8000)