dejanseo committed on
Commit
a85f7db
·
verified ·
1 Parent(s): 3e77509

Upload demo.py

Browse files
Files changed (1) hide show
  1. 43/demo.py +266 -0
43/demo.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import streamlit as st
import tensorflow as tf
import sentencepiece as spm
import numpy as np
from scipy.spatial.distance import cosine
import pandas as pd
from openTSNE import TSNE
import plotly.express as px
import plotly.graph_objects as go

# Set Streamlit layout to wide mode.
st.set_page_config(layout="wide")

# Shrink the default padding of the main container. The HTML is kept at
# column 0 so Markdown does not interpret it as an indented code block.
st.markdown("""
<style>
.block-container {
    padding-top: 1rem;
    padding-bottom: 0rem;
    padding-left: 1rem;
    padding-right: 1rem;
}
</style>
""", unsafe_allow_html=True)

# Model artefacts, resolved relative to the process working directory.
tflite_model_path = "model.tflite"
spm_model_path = "sentencepiece.model"

# SentencePiece tokenizer used to turn sentences into token ids.
sp = spm.SentencePieceProcessor()
sp.load(spm_model_path)

# TFLite interpreter that produces the sentence embeddings.
interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
interpreter.allocate_tensors()

# Tensor metadata; the code below only uses entry 0 of each list.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
required_input_length = 64  # Fixed length of 64 tokens
39
+
40
+ # Function to preprocess text input
41
def preprocess_text(text, sp, required_length):
    """Tokenize *text* into a fixed-length row of token ids.

    Args:
        text: Sentence to encode.
        sp: Tokenizer exposing ``encode(text, out_type=int)`` that returns a
            list of integer token ids.
        required_length: Exact number of ids in the returned row.

    Returns:
        An ``int32`` NumPy array of shape ``(1, required_length)``. Sequences
        longer than ``required_length`` are truncated; shorter ones are
        right-padded with ``0``.
    """
    token_ids = sp.encode(text, out_type=int)[:required_length]
    # NOTE(review): 0 is used as the padding id -- confirm it matches what the
    # model was trained with.
    padding = [0] * (required_length - len(token_ids))
    return np.array(token_ids + padding, dtype=np.int32).reshape(1, -1)
45
+
46
+ # Function to generate embeddings
47
def generate_embeddings(text):
    """Run the TFLite model on *text* and return its output as a 1-D array.

    Relies on the module-level ``sp`` tokenizer, ``interpreter``,
    ``input_details``, ``output_details`` and ``required_input_length``.

    Args:
        text: Sentence to embed.

    Returns:
        The first output tensor of the interpreter, flattened.
    """
    input_data = preprocess_text(text, sp, required_input_length)
    # Only the first input/output tensor of the model is used.
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    embedding = interpreter.get_tensor(output_details[0]['index'])
    return embedding.flatten()
53
+
54
+ # Function to calculate similarity scores between sentences
55
def calculate_similarity(embedding1, embedding2):
    """Return the cosine similarity (``1 - cosine distance``) of two vectors.

    Args:
        embedding1: First 1-D vector.
        embedding2: Second 1-D vector.

    Returns:
        The cosine similarity, or ``0.0`` when either vector is all zeros
        (the cosine is undefined there, so such a vector is treated as
        matching nothing instead of propagating NaN).
    """
    if not (np.any(embedding1) and np.any(embedding2)):
        return 0.0
    return 1 - cosine(embedding1, embedding2)
57
+
58
# Predefined sentence sets
preset_sentences_a = [
    "Dan Petrovic predicted conversational search in 2013.",
    "Understanding user intent is key to effective SEO.",
    "Dejan SEO has been a leader in data-driven SEO.",
    "Machine learning is transforming search engines.",
    "The future of search is AI-driven and personalized.",
    "Search algorithms are evolving to better match user intent.",
    "AI technologies enhance digital marketing strategies."
]

preset_sentences_b = [
    "Advances in machine learning reshape how search engines operate.",
    "Personalized content is becoming more prevalent with AI.",
    "Customer behavior insights are crucial for marketing strategies.",
    "Dan Petrovic anticipated the rise of chat-based search interactions.",
    "Dejan SEO is recognized for innovative SEO research and analysis.",
    "Quantum computing is advancing rapidly in the tech world.",
    "Studying user behavior can improve the effectiveness of online ads."
]

# Seed the text areas with the presets the first time a session runs.
if "input_text_a" not in st.session_state:
    st.session_state["input_text_a"] = "\n".join(preset_sentences_a)
if "input_text_b" not in st.session_state:
    st.session_state["input_text_b"] = "\n".join(preset_sentences_b)


def _clear_fields():
    """Empty both text areas; runs as a button callback before the rerun."""
    st.session_state["input_text_a"] = ""
    st.session_state["input_text_b"] = ""


# Clear button to reset text areas. The text areas are bound to session state
# through their keys, so the callback clears them on every click, not only
# the first one.
st.button("Clear Fields", on_click=_clear_fields)

# Side-by-side layout for Set A and Set B inputs
col1, col2 = st.columns(2)

with col1:
    st.subheader("Set A Sentences")
    input_text_a = st.text_area("Set A", key="input_text_a", height=200)

with col2:
    st.subheader("Set B Sentences")
    input_text_b = st.text_area("Set B", key="input_text_b", height=200)

# Slider to control t-SNE iteration steps
iterations = st.slider("Number of t-SNE Iterations (Higher values = more refined clusters)", 250, 1000, step=250)

# Similarity threshold slider
similarity_threshold = st.slider("Similarity Threshold", 0.0, 1.0, 0.5, 0.05)

# Submit button
if st.button("Calculate Similarity"):
    # One sentence per non-blank line.
    sentences_a = [line.strip() for line in input_text_a.split("\n") if line.strip()]
    sentences_b = [line.strip() for line in input_text_b.split("\n") if line.strip()]

    if sentences_a and sentences_b:
        # Generate embeddings for both sets
        embeddings_a = [generate_embeddings(sentence) for sentence in sentences_a]
        embeddings_b = [generate_embeddings(sentence) for sentence in sentences_b]

        # Combine sentences and embeddings for both sets (Set A rows first)
        all_sentences = sentences_a + sentences_b
        all_embeddings = np.array(embeddings_a + embeddings_b)
        labels = ["Set A"] * len(sentences_a) + ["Set B"] * len(sentences_b)

        # Similarity matrix: rows index Set A, columns index Set B
        similarity_matrix = np.zeros((len(sentences_a), len(sentences_b)))
        for i, emb_a in enumerate(embeddings_a):
            for j, emb_b in enumerate(embeddings_b):
                similarity_matrix[i, j] = calculate_similarity(emb_a, emb_b)

        # Greedy one-to-one matching: walk all (A, B) pairs from most to least
        # similar and keep a pair only if neither sentence is already matched.
        pairs = [(i, j, sim) for (i, j), sim in np.ndenumerate(similarity_matrix)]
        pairs.sort(key=lambda pair: pair[2], reverse=True)

        used_a = set()
        used_b = set()
        matches = []
        for i, j, sim in pairs:
            if i not in used_a and j not in used_b and sim >= similarity_threshold:
                matches.append((i, j, sim))
                used_a.add(i)
                used_b.add(j)

        # --------------------------------------
        # 1) SHOW MATCH TABLE AT THE TOP USING st.dataframe (FILLING THE SCREEN)
        # --------------------------------------
        if not matches:
            st.warning("No sentence pairs exceeded the similarity threshold.")
        else:
            # Matched pairs, with the 1-based position of each sentence in its set
            df_matches = pd.DataFrame(
                [
                    (i + 1, sentences_a[i], j + 1, sentences_b[j], round(sim, 3))
                    for (i, j, sim) in matches
                ],
                columns=["Set A Order", "Set A Sentence", "Set B Order", "Set B Sentence", "Similarity"]
            )
            st.subheader("Matched Sentences (Above Threshold)")
            st.dataframe(df_matches, use_container_width=True)

        # --------------------------------------
        # 2) THEN PERFORM T-SNE AND SHOW 3D PLOT
        # --------------------------------------
        # NOTE(review): with very few sentences in total, a 3-component PCA
        # initialization / t-SNE fit may fail -- confirm the minimum and guard.
        perplexity_value = min(5, len(all_sentences) - 1)

        tsne = TSNE(
            n_components=3,
            perplexity=perplexity_value,
            n_iter=iterations,
            initialization="pca",
            random_state=42
        )
        tsne_results = tsne.fit(all_embeddings)

        # Prepare DataFrame for Plotly
        df_tsne = pd.DataFrame({
            "Sentence": all_sentences,
            "Set": labels,
            "X": tsne_results[:, 0],
            "Y": tsne_results[:, 1],
            "Z": tsne_results[:, 2]
        })

        # Create 3D scatter plot with connections
        fig = go.Figure()

        # One marker trace per set (Set A in blue, Set B in red)
        for set_name, set_color in (("Set A", "blue"), ("Set B", "red")):
            subset = df_tsne[df_tsne["Set"] == set_name]
            fig.add_trace(go.Scatter3d(
                x=subset["X"],
                y=subset["Y"],
                z=subset["Z"],
                text=subset["Sentence"],
                mode='markers',
                name=set_name,
                marker=dict(size=5, color=set_color)
            ))

        # Connect every (A, B) pair whose similarity reaches the threshold.
        # This covers all such pairs, not only the greedy matches above.
        offset_b = len(sentences_a)  # Set B rows follow Set A rows in tsne_results
        for (i, j), sim in np.ndenumerate(similarity_matrix):
            if sim >= similarity_threshold:
                pos_a = tsne_results[i]
                pos_b = tsne_results[j + offset_b]
                fig.add_trace(go.Scatter3d(
                    x=[pos_a[0], pos_b[0]],
                    y=[pos_a[1], pos_b[1]],
                    z=[pos_a[2], pos_b[2]],
                    mode='lines',
                    # Line opacity encodes the similarity score.
                    line=dict(color=f'rgba(150,150,150,{sim})', width=2),
                    name=f'Similarity: {sim:.2f}',
                    showlegend=False
                ))

        fig.update_layout(
            title="3D Visualization of Sentence Similarity with Connections",
            width=1200,
            height=800,
            scene=dict(
                xaxis_title="t-SNE Dimension 1",
                yaxis_title="t-SNE Dimension 2",
                zaxis_title="t-SNE Dimension 3"
            )
        )
        st.plotly_chart(fig)

        # --------------------------------------
        # 3) SIMILARITY HEATMAP
        # --------------------------------------
        fig_heatmap = go.Figure(data=go.Heatmap(
            z=similarity_matrix,
            x=[f"B{i+1}" for i in range(len(sentences_b))],
            y=[f"A{i+1}" for i in range(len(sentences_a))],
            colorscale="Viridis",
            text=np.round(similarity_matrix, 2),
            texttemplate="%{text}",
            textfont={"size": 10},
            hoverongaps=False
        ))

        fig_heatmap.update_layout(
            title="Similarity Heatmap between Set A and Set B",
            width=None,  # Full width
            height=400,
            margin=dict(l=20, r=20, t=40, b=20),
            xaxis_title="Set B Sentences",
            yaxis_title="Set A Sentences"
        )

        # Stretch to the container so the heatmap actually renders full width.
        st.plotly_chart(fig_heatmap, use_container_width=True)

    else:
        st.warning("Please enter sentences in both Set A and Set B.")