Token Classification · Transformers · PyTorch · English · bert · fill-mask
dejanseo committed · Commit f0c6f33 (verified) · 1 parent: c36177e

Upload handler.py

Files changed (1): handler.py (+48, -0)
handler.py ADDED
@@ -0,0 +1,48 @@
import torch
from transformers import BertTokenizer, BertForTokenClassification

# Initialize the model and tokenizer
model_name = "dejanseo/LinkBERT"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForTokenClassification.from_pretrained(model_name)

def model_init(model, device='cpu'):
    """Move the model to the target device and switch to evaluation mode."""
    model.to(device)
    model.eval()
    return model

# This function will be called to load the model
def init():
    # If your model requires any specific initialization, handle it here
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_init(model, device=device)

# This function will be called to process requests
def process(inputs):
    # Preprocess input data
    input_data = inputs["inputs"]
    encoded = tokenizer(input_data, return_tensors="pt", add_special_tokens=True, truncation=True)
    input_ids = encoded["input_ids"]

    # Run the model (passing the full encoding keeps the attention mask)
    with torch.no_grad():
        outputs = model(**encoded)
    predictions = torch.argmax(outputs.logits, dim=-1)

    # Postprocess model outputs
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])[1:-1]  # Exclude CLS and SEP tokens
    predictions = predictions[0][1:-1]
    result = []
    for token, pred in zip(tokens, predictions):
        if pred.item() == 1:
            result.append(f"<u>{token}</u>")
        else:
            result.append(token)

    # Join tokens back into a string, merging WordPiece continuations
    reconstructed_text = " ".join(result).replace(" ##", "")

    return {"result": reconstructed_text}

# Note: The actual function signatures for init() and process() might need to be adapted based on Hugging Face's requirements.
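
As the trailing note suggests, Hugging Face Inference Endpoints normally expect a custom handler to be an EndpointHandler class exposing __init__(self, path) and __call__(self, data), rather than free-standing init()/process() functions. A minimal sketch of the same logic in that shape, not part of the uploaded file and assuming that handler convention, could look like this:

    from typing import Any, Dict

    import torch
    from transformers import BertTokenizer, BertForTokenClassification


    class EndpointHandler:
        def __init__(self, path: str = ""):
            # `path` is the local checkout of the repository; fall back to the hub id if empty.
            model_name = path or "dejanseo/LinkBERT"
            self.tokenizer = BertTokenizer.from_pretrained(model_name)
            self.model = BertForTokenClassification.from_pretrained(model_name)
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model.to(self.device)
            self.model.eval()

        def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
            # Same logic as process(): tag tokens predicted as class 1 with <u>...</u>.
            text = data["inputs"]
            encoded = self.tokenizer(text, return_tensors="pt", truncation=True).to(self.device)
            with torch.no_grad():
                logits = self.model(**encoded).logits
            predictions = torch.argmax(logits, dim=-1)[0][1:-1]
            tokens = self.tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])[1:-1]
            marked = [f"<u>{t}</u>" if p.item() == 1 else t for t, p in zip(tokens, predictions)]
            return {"result": " ".join(marked).replace(" ##", "")}

A request payload such as {"inputs": "some text"} would then return the reconstructed string, with predicted link tokens underlined, under the "result" key.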