Commit
·
b38d2da
1
Parent(s):
d19a591
output serialization
Browse files- handler.py +8 -3
handler.py
CHANGED
@@ -36,9 +36,14 @@ class EndpointHandler:
|
|
36 |
with torch.no_grad():
|
37 |
outputs = self.model(**inputs)
|
38 |
|
39 |
-
# Process outputs -
|
40 |
-
#
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
|
44 |
if __name__ == "__main__":
|
|
|
36 |
with torch.no_grad():
|
37 |
outputs = self.model(**inputs)
|
38 |
|
39 |
+
# Process outputs - convert tensors to serializable format
|
40 |
+
# Extract the last hidden state and convert to list for JSON serialization
|
41 |
+
last_hidden_state = outputs.last_hidden_state
|
42 |
+
|
43 |
+
# Convert to Python list (serializable) - using the mean of the embeddings as a simple approach
|
44 |
+
embedding = last_hidden_state.mean(dim=1).cpu().numpy().tolist()
|
45 |
+
|
46 |
+
return [{"input_text": input_text, "embedding": embedding}]
|
47 |
|
48 |
|
49 |
if __name__ == "__main__":
|