|
from transformers.models.bert import BertPreTrainedModel, BertModel |
|
from configuration_multitask import MultiTaskConfig |
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
from sklearn.metrics import f1_score |
|
import warnings |
|
import os |
|
# Import-time side effects of this module.
#
# Install a catch-all "ignore" filter for Python warnings. This also hides
# warnings coming from transformers / sklearn (e.g. metric warnings that
# f1_score may emit).
warnings.filterwarnings(action='ignore')

# Make CUDA kernel launches synchronous so device-side errors are reported at
# the call that caused them. This is a debugging aid that slows GPU execution.
# NOTE(review): confirm it is still wanted outside of debugging sessions.
os.environ.update(CUDA_LAUNCH_BLOCKING="1")
|
|
|
class MultiTaskModel(BertPreTrainedModel):
    """BERT encoder shared by four jointly trained prediction heads.

    * ``multiturn_classifier``: one binary logit per example, from the pooled
      output.
    * ``task_classifier``: ``config.num_task_label`` logits per example, from
      the pooled output.
    * ``length_classifier``: ``config.num_length_label`` logits per example,
      from the pooled output.
    * ``multiquestion_tagger``: one binary logit per token, from the last
      hidden state. Only tokens with ``token_type_ids == 1`` are scored.
    """

    config_class = MultiTaskConfig

    # Probability cut-offs for the binary heads. The heads emit raw logits
    # (their losses use BCE-with-logits), so a sigmoid is applied before the
    # comparison. Comparing raw logits with 0.5 would amount to a probability
    # cut-off of sigmoid(0.5) ~= 0.62.
    multiturn_threshold = 0.5
    multiquestion_threshold = 0.5

    # Number of segment-B token ids returned, at inference, for each token the
    # tagger marks positive (the slice starts at that token).
    multiquestion_span_tokens = 10

    # Label name -> index maps used to decode inference predictions.
    # Task: question answering / dialogue summarization / document
    # summarization. Length: long / short / medium sentence.
    _LABEL_CONFIG = {"task_label_dict": {"질의응답": 0, "대화요약": 1, "문서요약": 2},
                     "length_label_dict": {"긴문장": 0, "짧은문장": 1, "중간문장": 2},
                     "multiturn_label_dict": {"NO": 0, "YES": 1}}
    # Inverse maps (index -> label name), built once instead of on every call.
    _ID2TASK = {idx: name for name, idx in _LABEL_CONFIG["task_label_dict"].items()}
    _ID2LENGTH = {idx: name for name, idx in _LABEL_CONFIG["length_label_dict"].items()}
    _ID2MULTITURN = {idx: name for name, idx in _LABEL_CONFIG["multiturn_label_dict"].items()}

    def __init__(self, config: MultiTaskConfig):
        super().__init__(config)
        self.config = config

        # BertConfig leaves classifier_dropout as None by default, which
        # nn.Dropout rejects. Fall back to the encoder's hidden dropout, as
        # the Hugging Face BERT heads do.
        dropout = getattr(config, "classifier_dropout", None)
        if dropout is None:
            dropout = config.hidden_dropout_prob

        # Attribute names and registration order are kept as-is: they
        # determine the state_dict keys of saved checkpoints.
        self.bert = BertModel(config)

        self.bert_drop_task = nn.Dropout(dropout)
        self.task_classifier = nn.Linear(config.hidden_size, config.num_task_label)

        self.bert_drop_length = nn.Dropout(dropout)
        self.length_classifier = nn.Linear(config.hidden_size, config.num_length_label)

        self.bert_drop_multiturn = nn.Dropout(dropout)
        self.multiturn_classifier = nn.Linear(config.hidden_size, 1)

        self.bert_drop_multiquestion = nn.Dropout(dropout)
        self.multiquestion_tagger = nn.Linear(config.hidden_size, 1)

    def binary_class_loss_fn(self, logits, labels):
        """Return the mean BCE-with-logits loss; logits are flattened to 1-D and labels cast to float."""
        return F.binary_cross_entropy_with_logits(logits.view(-1), labels.float())

    def multi_class_loss_fn(self, logits, labels):
        """Return the mean cross-entropy loss of class ``logits`` against integer ``labels``."""
        return F.cross_entropy(logits, labels)

    def binary_tag_loss_fn(self, logits, labels, token_type_ids):
        """Return the token-tagging loss restricted to ``token_type_ids == 1`` positions.

        Args:
            logits: per-token logits, iterated per example (assumed
                ``(batch, seq)``; TODO confirm for external callers).
            labels: float targets with the same layout as ``logits``.
            token_type_ids: segment ids with the same layout as ``logits``.

        Returns:
            The mean BCE-with-logits loss over each example's segment-B tokens,
            averaged over examples. An example without segment-B tokens is
            skipped, because its mean over zero tokens would be NaN. If no
            example has any, a zero that stays attached to the graph is
            returned.
        """
        per_example = []
        for logit, label, segment in zip(logits, labels, token_type_ids):
            mask = segment == 1
            if not mask.any():
                continue
            per_example.append(F.binary_cross_entropy_with_logits(logit[mask], label[mask]))

        if not per_example:
            return logits.sum() * 0.0
        return torch.stack(per_example).mean()

    def forward(self, inputs_tokens=None, attention_mask=None, token_type_ids=None,
                multiturn_labels=None, task_labels=None, length_labels=None,
                multiquestion_labels=None):
        """Run the encoder and all four heads.

        Provide either all four label tensors (training/evaluation) or none of
        them (inference).

        Returns:
            With labels:
                ``(loss, multiturn_acc, task_acc, length_acc,
                multiquestion_acc, macro_f1)``. ``loss`` is the unweighted sum
                of the four head losses; the metrics are Python floats in
                percent.
            Without labels:
                ``(multiturn_label, task_label, length_label,
                multiquestion_spans)``. The first three are label names; the
                last is a list of token-id lists. This path calls ``.item()``
                and therefore needs a batch of exactly one example.

        Raises:
            ValueError: if only some of the label tensors are given.
        """
        # NOTE(review): token_type_ids are used below to select segment-B
        # tokens but are NOT passed to the encoder, so BERT sees all-zero
        # segment ids. Passing them would change the outputs of checkpoints
        # trained with this code; confirm before changing.
        outputs = self.bert(input_ids=inputs_tokens, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        last_hidden_state = outputs.last_hidden_state

        # Head evaluation order is unchanged, so dropout RNG consumption matches.
        multiturn_logits = self.multiturn_classifier(self.bert_drop_multiturn(pooled_output))
        task_logits = self.task_classifier(self.bert_drop_task(pooled_output))
        length_logits = self.length_classifier(self.bert_drop_length(pooled_output))
        multiquestion_logits = self.multiquestion_tagger(self.bert_drop_multiquestion(last_hidden_state))

        labels = (multiturn_labels, task_labels, length_labels, multiquestion_labels)
        if all(label is not None for label in labels):
            return self._training_outputs(
                multiturn_logits=multiturn_logits,
                task_logits=task_logits,
                length_logits=length_logits,
                multiquestion_logits=multiquestion_logits,
                token_type_ids=token_type_ids,
                multiturn_labels=multiturn_labels,
                task_labels=task_labels,
                length_labels=length_labels,
                multiquestion_labels=multiquestion_labels,
            )
        if all(label is None for label in labels):
            return self._inference_outputs(
                inputs_tokens=inputs_tokens,
                token_type_ids=token_type_ids,
                multiturn_logits=multiturn_logits,
                task_logits=task_logits,
                length_logits=length_logits,
                multiquestion_logits=multiquestion_logits,
            )
        # Previously this case fell through and silently returned None.
        raise ValueError(
            "forward() expects either all four label tensors (training/evaluation) "
            "or none of them (inference)."
        )

    def _training_outputs(self, multiturn_logits, task_logits, length_logits,
                          multiquestion_logits, token_type_ids, multiturn_labels,
                          task_labels, length_labels, multiquestion_labels):
        """Compute the summed multi-task loss and per-head batch metrics (percent)."""
        # Drop only the trailing head dimension. A bare squeeze() would also
        # drop the batch dimension when the batch holds a single example.
        tag_logits = multiquestion_logits.squeeze(-1)

        multiturn_pred = (torch.sigmoid(multiturn_logits) > self.multiturn_threshold).float().view(-1)
        multiturn_acc = torch.sum(multiturn_pred == multiturn_labels).item() / len(multiturn_labels) * 100

        multiquestion_acc, macro_f1 = self._tagging_metrics(tag_logits, multiquestion_labels, token_type_ids)

        task_pred = task_logits.argmax(dim=1)
        task_acc = torch.sum(task_pred == task_labels).item() / len(task_labels) * 100

        length_pred = length_logits.argmax(dim=1)
        length_acc = torch.sum(length_pred == length_labels).item() / len(length_labels) * 100

        loss = (self.binary_class_loss_fn(multiturn_logits, multiturn_labels)
                + self.multi_class_loss_fn(task_logits, task_labels)
                + self.multi_class_loss_fn(length_logits, length_labels)
                + self.binary_tag_loss_fn(tag_logits, multiquestion_labels.float(), token_type_ids))

        return loss, multiturn_acc, task_acc, length_acc, multiquestion_acc, macro_f1

    def _tagging_metrics(self, tag_logits, tag_labels, token_type_ids):
        """Return the tagger's (accuracy, macro-F1) on segment-B tokens, averaged per example, in percent.

        Examples with no ``token_type_ids == 1`` positions are skipped
        (previously a division by zero). If no example can be scored, the
        result is ``(0.0, 0.0)``.
        """
        acc_total = 0.0
        f1_total = 0.0
        scored = 0
        for logit, label, segment in zip(tag_logits, tag_labels, token_type_ids):
            mask = segment == 1
            if not mask.any():
                continue
            logit, label = logit[mask], label[mask]
            pred = (torch.sigmoid(logit) > self.multiquestion_threshold).float()
            acc_total += torch.sum(pred == label).item() / len(label)
            f1_total += f1_score(label.cpu().numpy(), pred.cpu().numpy(), average='macro')
            scored += 1

        if scored == 0:
            return 0.0, 0.0
        return acc_total / scored * 100, f1_total / scored * 100

    def _inference_outputs(self, inputs_tokens, token_type_ids, multiturn_logits,
                           task_logits, length_logits, multiquestion_logits):
        """Decode label names and tagged token spans for a single-example batch.

        ``.item()`` raises if the batch holds more than one example.
        """
        multiturn_pred = int((torch.sigmoid(multiturn_logits) > self.multiturn_threshold).view(-1).item())
        multiturn_pred_result = self._ID2MULTITURN.get(multiturn_pred)

        task_pred_result = self._ID2TASK.get(int(task_logits.argmax(dim=1).item()))
        length_pred_result = self._ID2LENGTH.get(int(length_logits.argmax(dim=1).item()))

        # Keep only segment-B positions. The mask flattens them into 1-D
        # sequences of logits and token ids that stay aligned with each other.
        segment_mask = token_type_ids == 1
        segment_logits = multiquestion_logits.squeeze(-1)[segment_mask]
        segment_tokens = inputs_tokens[segment_mask]

        positive_positions = (
            (torch.sigmoid(segment_logits) > self.multiquestion_threshold)
            .nonzero(as_tuple=True)[0]
            .tolist()
        )
        # Tensor slicing clamps at the end of the sequence, so no bounds
        # handling is needed (the old try/except fallback could never fire).
        span = self.multiquestion_span_tokens
        multiquestion_result = [
            segment_tokens[pos:pos + span].cpu().tolist() for pos in positive_positions
        ]

        # NOTE: a TorchServe variant instead returned the raw tuple
        # (multiturn_logits, task_logits, length_logits, multiquestion_logits).
        return multiturn_pred_result, task_pred_result, length_pred_result, multiquestion_result