Update handler.py
Browse files- handler.py +16 -6
handler.py
CHANGED
@@ -1,14 +1,23 @@
|
|
1 |
-
from transformers import
|
2 |
import torch
|
3 |
from typing import Dict, List, Any
|
4 |
|
5 |
class EndpointHandler:
|
6 |
def __init__(self, path: str = "dejanseo/LinkBERT"):
|
7 |
-
#
|
8 |
-
self.
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
self.model.eval() # Set model to evaluation mode
|
11 |
|
|
|
|
|
|
|
12 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
13 |
# Extract input text from the request
|
14 |
inputs = data.get("inputs", "")
|
@@ -29,7 +38,7 @@ class EndpointHandler:
|
|
29 |
# Reconstruct the text with annotations for token classification
|
30 |
result = []
|
31 |
for token, pred in zip(tokens, predictions):
|
32 |
-
if pred == 1: #
|
33 |
result.append(f"<u>{token}</u>")
|
34 |
else:
|
35 |
result.append(token)
|
@@ -39,4 +48,5 @@ class EndpointHandler:
|
|
39 |
# Return the processed text in a structured format
|
40 |
return [{"text": reconstructed_text}]
|
41 |
|
42 |
-
# Note:
|
|
|
|
1 |
+
from transformers import BertForTokenClassification, BertTokenizer, AutoConfig
|
2 |
import torch
|
3 |
from typing import Dict, List, Any
|
4 |
|
5 |
class EndpointHandler:
|
6 |
def __init__(self, path: str = "dejanseo/LinkBERT"):
    """Load the fine-tuned token-classification model and its tokenizer.

    path: Hugging Face hub id (or local directory) of the saved model.
    """
    # The saved config carries the fine-tuned head's settings (num_labels etc.),
    # so it is loaded first and passed explicitly to the model constructor.
    self.config = AutoConfig.from_pretrained(path)
    self.model = BertForTokenClassification.from_pretrained(path, config=self.config)

    # Inference only: switch off dropout and other train-time behavior.
    self.model.eval()

    # NOTE(review): the tokenizer is pinned to the base "bert-large-cased"
    # vocabulary rather than loaded from `path` — confirm the fine-tuned repo
    # does not ship its own (possibly different) tokenizer files.
    self.tokenizer = BertTokenizer.from_pretrained("bert-large-cased")
|
20 |
+
|
21 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
22 |
# Extract input text from the request
|
23 |
inputs = data.get("inputs", "")
|
|
|
38 |
# Reconstruct the text with annotations for token classification
|
39 |
result = []
|
40 |
for token, pred in zip(tokens, predictions):
|
41 |
+
if pred == 1: # Adjust this based on your classification needs
|
42 |
result.append(f"<u>{token}</u>")
|
43 |
else:
|
44 |
result.append(token)
|
|
|
48 |
# Return the processed text in a structured format
|
49 |
return [{"text": reconstructed_text}]
|
50 |
|
51 |
+
# Note: Ensure the path "dejanseo/LinkBERT" is correctly pointing to your model's location
|
52 |
+
# If the model is locally saved, adjust the path accordingly
|