|
|
|
import pickle |
|
import json |
|
import numpy as np |
|
from sklearn.feature_extraction.text import TfidfVectorizer |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
|
class BuildingMaterialsInference:
    """Answer free-text building-material queries from a pickled lookup model.

    The pickled payload is a dict with the keys ``knowledge_base``,
    ``vectorizer``, ``query_vectors``, ``queries`` and ``responses``.
    As used by this class:

    * ``knowledge_base`` maps a lower-cased query string to a JSON string
      holding a ``results`` list.
    * ``vectorizer`` exposes ``transform`` and ``query_vectors`` holds the
      vectors the user query is compared against.
    * ``responses[i]`` is the JSON string associated with row ``i`` of
      ``query_vectors``.
    * ``queries`` is stored but not read by any method in this class.
    """

    def __init__(self, model_path='building_materials_model.pkl'):
        """Load the pickled model data.

        Args:
            model_path: Path of the pickle file. The default is resolved
                against the current working directory.

        Raises:
            OSError: If the file cannot be opened.
            KeyError: If the payload lacks one of the expected keys.
        """
        # NOTE(security): pickle.load can execute arbitrary code while
        # deserialising. Only point ``model_path`` at a trusted file.
        with open(model_path, 'rb') as f:
            model_data = pickle.load(f)

        self.knowledge_base = model_data['knowledge_base']
        self.vectorizer = model_data['vectorizer']
        self.query_vectors = model_data['query_vectors']
        self.queries = model_data['queries']
        self.responses = model_data['responses']

    def __call__(self, inputs):
        """Hugging Face inference API compatible method.

        Args:
            inputs: Either a dict carrying the query under ``'query'``
                (preferred) or ``'inputs'``, or any other object, which is
                converted with ``str``.

        Returns:
            The dict produced by :meth:`search_materials`.
        """
        if isinstance(inputs, dict):
            query = inputs.get('query', inputs.get('inputs', ''))
            # A payload such as {"inputs": null} or a numeric value must not
            # crash on ``.lower()`` further down.
            query = '' if query is None else str(query)
        else:
            query = str(inputs)

        return self.search_materials(query)

    @staticmethod
    def _price_sort_key(result):
        """Return a numeric sort key for ``result['price']``.

        Strings such as ``'£1,234.50'`` and plain numbers are converted to
        ``float``. Anything that cannot be parsed sorts last.
        """
        # NOTE(review): a result with no 'price' key is treated as '£0' and
        # therefore sorts first, unlike an unparseable price - confirm that
        # this is intended.
        raw = result.get('price', '£0')
        if isinstance(raw, (int, float)) and not isinstance(raw, bool):
            value = float(raw)
        else:
            try:
                value = float(raw.replace('£', '').replace(',', ''))
            except (AttributeError, TypeError, ValueError):
                return float('inf')
        # NaN compares false with everything and would corrupt the sort.
        if value != value:
            return float('inf')
        return value

    def search_materials(self, user_query, top_k=3, *, min_similarity=0.1,
                         max_results=5):
        """Search for building materials based on user query.

        An exact (case-insensitive) hit in the knowledge base is returned
        as-is with confidence 1.0. Otherwise the query is vectorised, the
        ``top_k`` most similar stored queries are selected, and their
        results are merged (first result per supplier wins) and sorted by
        ascending price.

        Args:
            user_query: Query text.
            top_k: Number of most-similar stored queries to merge.
            min_similarity: A stored query contributes only if its cosine
                similarity is strictly greater than this value.
            max_results: Maximum number of merged results to return.

        Returns:
            A dict with ``results`` (list), ``confidence`` (float) and
            ``query_matched`` (the query exactly as passed in). On the
            similarity path ``confidence`` is the highest similarity found,
            even when it is below ``min_similarity``.
        """
        user_query_lower = user_query.lower()

        # Fast path: exact match against the knowledge base.
        if user_query_lower in self.knowledge_base:
            response_data = json.loads(self.knowledge_base[user_query_lower])
            return {
                "results": response_data.get('results', []),
                "confidence": 1.0,
                "query_matched": user_query
            }

        # Fallback: cosine similarity against the stored query vectors.
        user_vector = self.vectorizer.transform([user_query_lower])
        similarities = cosine_similarity(user_vector, self.query_vectors)[0]

        # A negative limit would slice from the wrong end; treat it as zero.
        if top_k is not None and top_k < 0:
            top_k = 0
        top_indices = np.argsort(similarities)[::-1][:top_k]

        # Merge results from the best matches, one entry per supplier.
        all_results = []
        seen_suppliers = set()
        for idx in top_indices:
            if not similarities[idx] > min_similarity:
                continue
            response_data = json.loads(self.responses[idx])
            for result in response_data.get('results', []):
                supplier = result.get('supplier', '')
                if supplier not in seen_suppliers:
                    all_results.append(result)
                    seen_suppliers.add(supplier)

        # list.sort is stable, so equal prices keep their relevance order.
        all_results.sort(key=self._price_sort_key)

        return {
            "results": all_results[:max(max_results, 0)],
            "confidence": float(similarities.max()) if len(similarities) > 0 else 0.0,
            "query_matched": user_query
        }
|
|
|
|
|
# Module-level instance, built at import time. Constructing it reads
# 'building_materials_model.pkl' relative to the current working directory,
# so importing this module fails if that file is missing.
model = BuildingMaterialsInference()
|
|