Token Classification
Transformers
PyTorch
English
bert
fill-mask
dejanseo committed on
Commit
173d81c
·
verified ·
1 Parent(s): 6df8d56

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +16 -6
handler.py CHANGED
@@ -1,14 +1,23 @@
1
- from transformers import AutoModelForTokenClassification, AutoTokenizer
2
  import torch
3
  from typing import Dict, List, Any
4
 
5
  class EndpointHandler:
6
  def __init__(self, path: str = "dejanseo/LinkBERT"):
7
- # Initialize tokenizer and model with the specified path
8
- self.tokenizer = AutoTokenizer.from_pretrained(path)
9
- self.model = AutoModelForTokenClassification.from_pretrained(path)
 
 
 
 
 
 
10
  self.model.eval() # Set model to evaluation mode
11
 
 
 
 
12
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
13
  # Extract input text from the request
14
  inputs = data.get("inputs", "")
@@ -29,7 +38,7 @@ class EndpointHandler:
29
  # Reconstruct the text with annotations for token classification
30
  result = []
31
  for token, pred in zip(tokens, predictions):
32
- if pred == 1: # Assuming '1' is the label for the class of interest
33
  result.append(f"<u>{token}</u>")
34
  else:
35
  result.append(token)
@@ -39,4 +48,5 @@ class EndpointHandler:
39
  # Return the processed text in a structured format
40
  return [{"text": reconstructed_text}]
41
 
42
- # Note: You'll need to replace 'path' with the actual path or identifier of your model when initializing the EndpointHandler.
 
 
1
+ from transformers import BertForTokenClassification, BertTokenizer, AutoConfig
2
  import torch
3
  from typing import Dict, List, Any
4
 
5
  class EndpointHandler:
6
  def __init__(self, path: str = "dejanseo/LinkBERT"):
7
+ # Load the configuration from the saved model
8
+ self.config = AutoConfig.from_pretrained(path)
9
+
10
+ # Make sure to specify the correct model name for bert-large-cased
11
+ # Adjust num_labels according to your model's configuration
12
+ self.model = BertForTokenClassification.from_pretrained(
13
+ path,
14
+ config=self.config
15
+ )
16
  self.model.eval() # Set model to evaluation mode
17
 
18
+ # Load the tokenizer for bert-large-cased
19
+ self.tokenizer = BertTokenizer.from_pretrained("bert-large-cased")
20
+
21
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
22
  # Extract input text from the request
23
  inputs = data.get("inputs", "")
 
38
  # Reconstruct the text with annotations for token classification
39
  result = []
40
  for token, pred in zip(tokens, predictions):
41
+ if pred == 1: # Adjust this based on your classification needs
42
  result.append(f"<u>{token}</u>")
43
  else:
44
  result.append(token)
 
48
  # Return the processed text in a structured format
49
  return [{"text": reconstructed_text}]
50
 
51
+ # Note: Ensure the path "dejanseo/LinkBERT" is correctly pointing to your model's location
52
+ # If the model is locally saved, adjust the path accordingly