|
--- |
|
license: mit |
|
language: |
|
- en |
|
library_name: transformers |
|
pipeline_tag: document-question-answering |
|
--- |
|
|
|
Fine-tuned on the DocVQA dataset (40,000 questions).
|
|
|
```python |
|
import json |
|
from glob import glob |
|
from transformers import AutoProcessor, AutoModelForDocumentQuestionAnswering |
|
|
|
import torch |
|
import numpy as np |
|
|
|
# Checkpoint identifier handed to both ``from_pretrained`` calls below, so the
# processor and the model always come from the same checkpoint.
model_name = "TusharGoel/LayoutLMv2-finetuned-docvqa"

# Loaded once at import time; ``pipeline`` below reads both as module globals.
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path=model_name)
model = AutoModelForDocumentQuestionAnswering.from_pretrained(
    pretrained_model_name_or_path=model_name
)
|
|
|
|
|
def pipeline(question, words, boxes, **kwargs):
    """Answer ``question`` about a single document and return ``(answer, score)``.

    The inputs are encoded with the module-level ``processor``, run through the
    module-level ``model``, and the highest-scoring answer span is mapped back
    to a slice of ``words``. The answer text and the score always describe the
    same span.

    Args:
        question: Question text; passed to the processor as the first text
            sequence.
        words: List of the document's words; passed as the second sequence.
            The answer is a contiguous slice of this list joined by spaces.
        boxes: Bounding boxes for ``words``, forwarded as ``boxes=``.
        **kwargs: Must contain ``images``, the image input for the processor.

    Returns:
        A ``(answer, score)`` tuple. ``score`` is a Python float: the start
        probability times the end probability of the chosen span. ``("", 0.0)``
        is returned when no span inside ``words`` is available or when
        encoding/inference raises (the failure is logged, not re-raised).

    Raises:
        KeyError: If ``images`` is missing from ``kwargs``.
    """
    # Function-scope import so this snippet needs no extra top-level import.
    import logging

    images = kwargs["images"]
    max_answer_len = 20  # longest answer span considered, counted in tokens

    try:
        encoding = processor(
            images,
            question,
            words,
            boxes=boxes,
            return_token_type_ids=True,
            return_tensors="pt",
            truncation=True,
        )
        # Per-token lookups for the first (only) item of the batch.
        word_ids = encoding.word_ids(0)
        sequence_ids = encoding.sequence_ids(0)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = model(**encoding)

        # Give padded positions a very low logit so they get ~0 probability.
        padding = encoding["attention_mask"].numpy() == 0
        start_probs = _softmax(np.where(padding, -10000.0, outputs.start_logits.numpy()))[0]
        end_probs = _softmax(np.where(padding, -10000.0, outputs.end_logits.numpy()))[0]

        # span_scores[i, j] = P(start == i) * P(end == j); keep only spans with
        # i <= j (triu) that are at most max_answer_len tokens long (tril).
        span_scores = np.outer(start_probs, end_probs)
        span_scores = np.tril(np.triu(span_scores), max_answer_len - 1)

        # Only tokens of the second sequence (``words``) that map to a word can
        # delimit an answer. Word ids of question tokens index the question,
        # not ``words``, and special tokens have no word id at all.
        answerable = np.array(
            [seq == 1 and wid is not None for seq, wid in zip(sequence_ids, word_ids)]
        )
        span_scores = span_scores * answerable[:, None] * answerable[None, :]

        tok_start, tok_end = np.unravel_index(np.argmax(span_scores), span_scores.shape)
        score = float(span_scores[tok_start, tok_end])
        if score <= 0.0:
            # Every admissible span was masked out (e.g. no context tokens).
            return "", 0.0

        answer = " ".join(words[word_ids[tok_start] : word_ids[tok_end] + 1])
        return answer, score
    except Exception:
        # Best-effort contract: callers always get a tuple back. The failure is
        # logged with its traceback instead of being silently discarded.
        logging.getLogger(__name__).warning(
            "pipeline: inference failed; returning empty answer", exc_info=True
        )
        return "", 0.0


def _softmax(logits):
    """Return the softmax of ``logits`` over the last axis (max-shifted for stability)."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)
|
``` |