File size: 7,235 Bytes
b531794 29c7915 b531794 29c7915 b531794 29c7915 b531794 29c7915 b531794 29c7915 b531794 29c7915 b531794 29c7915 b531794 29c7915 b531794 29c7915 b531794 29c7915 b531794 29c7915 b531794 29c7915 b531794 5ef6d29 b531794 29c7915 b531794 29c7915 b531794 29c7915 2db1525 29c7915 b531794 29c7915 2db1525 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
from transformers.models.bert import BertPreTrainedModel, BertModel
from configuration_multitask import MultiTaskConfig
import torch
import torch.nn.functional as F
from torch import nn
from sklearn.metrics import f1_score
import warnings
import os
# Silence every warning category for the whole process (this affects all
# modules imported after this point, not only this file).
warnings.filterwarnings('ignore')
# Ask the CUDA runtime to launch kernels synchronously so device-side errors
# surface at the offending call.
# NOTE(review): this is normally a debugging aid and is expected to slow GPU
# execution -- confirm it is intended for production use.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
class MultiTaskModel(BertPreTrainedModel):
    """BERT encoder shared by four jointly trained heads.

    Heads:

    * multiturn     -- binary classifier on the pooled output (one logit).
    * task          -- ``config.num_task_label``-way classifier on the pooled
      output.
    * length        -- ``config.num_length_label``-way classifier on the
      pooled output.
    * multiquestion -- per-token binary tagger on the last hidden state (one
      logit per token); only tokens whose ``token_type_ids`` value is 1 are
      scored or reported.
    """

    config_class = MultiTaskConfig

    # Probability cut-offs for the two binary heads.
    _MULTITURN_THRESHOLD = 0.5
    _MULTIQUESTION_THRESHOLD = 0.5
    # Number of token ids returned for each tagged multi-question position.
    _MULTIQUESTION_SPAN = 10

    # id -> label tables used by the local (label-free) prediction path.
    _TASK_ID_TO_LABEL = {0: "질의응답", 1: "대화요약", 2: "문서요약"}
    _LENGTH_ID_TO_LABEL = {0: "긴문장", 1: "짧은문장", 2: "중간문장"}
    _MULTITURN_ID_TO_LABEL = {0: "NO", 1: "YES"}

    def __init__(self, config: MultiTaskConfig):
        """Build the encoder and one dropout + linear head per task."""
        super().__init__(config)
        self.config = config
        self.bert = BertModel(config)
        self.bert_drop_task = nn.Dropout(config.classifier_dropout)
        self.task_classifier = nn.Linear(config.hidden_size, config.num_task_label)
        self.bert_drop_length = nn.Dropout(config.classifier_dropout)
        self.length_classifier = nn.Linear(config.hidden_size, config.num_length_label)
        self.bert_drop_multiturn = nn.Dropout(config.classifier_dropout)
        self.multiturn_classifier = nn.Linear(config.hidden_size, 1)
        self.bert_drop_multiquestion = nn.Dropout(config.classifier_dropout)
        self.multiquestion_tagger = nn.Linear(config.hidden_size, 1)

    def binary_class_loss_fn(self, logits, labels):
        """Return BCE-with-logits between flattened ``logits`` and ``labels``."""
        return F.binary_cross_entropy_with_logits(logits.view(-1), labels.float())

    def multi_class_loss_fn(self, logits, labels):
        """Return the cross-entropy between ``logits`` and class-index ``labels``."""
        return F.cross_entropy(logits, labels)

    def binary_tag_loss_fn(self, logits, labels, token_type_ids):
        """Return the per-sample token-tagging loss averaged over the batch.

        For each sample only the positions where ``token_type_ids == 1`` enter
        the BCE-with-logits loss. A sample without any such position adds
        nothing to the sum (a mean over zero elements would be NaN and would
        poison the total loss).

        Args:
            logits: per-token logits, one row per sample (trailing singleton
                dimension already removed by the caller).
            labels: float targets aligned with ``logits``.
            token_type_ids: segment ids aligned with ``logits``.
        """
        tagging_loss = logits.new_zeros(())
        for logit, label, token_type_id in zip(logits, labels, token_type_ids):
            selected = token_type_id == 1
            if not selected.any():
                continue
            tagging_loss = tagging_loss + F.binary_cross_entropy_with_logits(
                logit[selected], label[selected]
            )
        return tagging_loss / len(logits)

    def forward(self, inputs_tokens=None, attention_mask=None, token_type_ids=None,
                multiturn_labels=None, task_labels=None, length_labels=None,
                multiquestion_labels=None):
        """Run the encoder and all four heads.

        With all four label tensors supplied, returns
        ``(loss, multiturn_acc, task_acc, length_acc, multiquestion_acc,
        macro_f1)`` where ``loss`` is the sum of the four head losses and the
        remaining values are percentages.

        With no labels supplied, returns
        ``(multiturn_label, task_label, length_label, multiquestion_result)``
        for a single example: three label strings and a list of token-id
        windows, one per tagged position.

        Binary predictions compare ``sigmoid(logit)`` with the 0.5 cut-off,
        which matches the BCE-with-logits training objective.

        Raises:
            ValueError: if only some of the label tensors are supplied, or if
                labels are supplied without ``token_type_ids``.
        """
        # NOTE(review): token_type_ids is intentionally not forwarded to the
        # encoder here (it is only used as a token mask below); the original
        # code did the same, so changing it would alter trained behavior.
        outputs = self.bert(input_ids=inputs_tokens, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        last_hidden_state = outputs.last_hidden_state

        multiturn_logits = self.multiturn_classifier(self.bert_drop_multiturn(pooled_output))
        task_logits = self.task_classifier(self.bert_drop_task(pooled_output))
        length_logits = self.length_classifier(self.bert_drop_length(pooled_output))
        multiquestion_logits = self.multiquestion_tagger(
            self.bert_drop_multiquestion(last_hidden_state)
        )

        all_labels = (multiturn_labels, task_labels, length_labels, multiquestion_labels)
        if all(label is not None for label in all_labels):
            return self._loss_and_metrics(
                multiturn_logits, task_logits, length_logits, multiquestion_logits,
                token_type_ids, multiturn_labels, task_labels, length_labels,
                multiquestion_labels,
            )
        if all(label is None for label in all_labels):
            # For TorchServe deployment the raw logits (multiturn_logits,
            # task_logits, length_logits, multiquestion_logits) can be
            # returned here instead of decoded labels.
            return self._predict_single(
                inputs_tokens, token_type_ids, multiturn_logits, task_logits,
                length_logits, multiquestion_logits,
            )
        raise ValueError(
            "Provide either all of multiturn_labels, task_labels, length_labels "
            "and multiquestion_labels, or none of them."
        )

    def _loss_and_metrics(self, multiturn_logits, task_logits, length_logits,
                          multiquestion_logits, token_type_ids, multiturn_labels,
                          task_labels, length_labels, multiquestion_labels):
        """Compute the summed training loss and per-head metrics (percent)."""
        if token_type_ids is None:
            raise ValueError("token_type_ids is required when labels are provided.")

        # Metrics are reporting-only, so keep them out of the autograd graph.
        with torch.no_grad():
            # Multiturn
            multiturn_pred = (
                torch.sigmoid(multiturn_logits) > self._MULTITURN_THRESHOLD
            ).float().view(-1)
            multiturn_acc = (
                torch.sum(multiturn_pred == multiturn_labels).item()
                / len(multiturn_labels) * 100
            )

            # Multiquestion: score only samples that have segment-1 tokens.
            acc_sum = 0.0
            f1_sum = 0.0
            scored = 0
            for logit, label, token_type_id in zip(
                    multiquestion_logits, multiquestion_labels, token_type_ids):
                selected = token_type_id == 1
                if not selected.any():
                    continue
                label = label[selected]
                pred = (
                    torch.sigmoid(logit[selected]) > self._MULTIQUESTION_THRESHOLD
                ).float().view(-1)
                acc_sum += torch.sum(pred == label).item() / len(label)
                f1_sum += f1_score(label.cpu().numpy(), pred.cpu().numpy(), average='macro')
                scored += 1
            multiquestion_acc = acc_sum / scored * 100 if scored else 0.0
            macro_f1 = f1_sum / scored * 100 if scored else 0.0

            # Task
            _, task_pred = torch.max(task_logits, dim=1)
            task_acc = torch.sum(task_pred == task_labels).item() / len(task_labels) * 100

            # Length
            _, length_pred = torch.max(length_logits, dim=1)
            length_acc = (
                torch.sum(length_pred == length_labels).item() / len(length_labels) * 100
            )

        multiturn_loss = self.binary_class_loss_fn(multiturn_logits, multiturn_labels)
        task_loss = self.multi_class_loss_fn(task_logits, task_labels)
        length_loss = self.multi_class_loss_fn(length_logits, length_labels)
        # squeeze(-1) drops only the trailing logit dimension; a bare squeeze()
        # would also drop the batch dimension when the batch holds one sample.
        multiquestion_loss = self.binary_tag_loss_fn(
            multiquestion_logits.squeeze(-1), multiquestion_labels.float(), token_type_ids
        )
        loss = multiturn_loss + task_loss + length_loss + multiquestion_loss
        return loss, multiturn_acc, task_acc, length_acc, multiquestion_acc, macro_f1

    def _predict_single(self, inputs_tokens, token_type_ids, multiturn_logits,
                        task_logits, length_logits, multiquestion_logits):
        """Decode the heads for one example into labels and token-id windows.

        The ``.item()`` calls below require exactly one example in the batch.
        """
        # Multiturn
        multiturn_pred = int(
            (torch.sigmoid(multiturn_logits) > self._MULTITURN_THRESHOLD)
            .float().view(-1).item()
        )
        multiturn_pred_result = self._MULTITURN_ID_TO_LABEL.get(multiturn_pred)

        # Task
        _, task_pred = torch.max(task_logits, dim=1)
        task_pred_result = self._TASK_ID_TO_LABEL.get(int(task_pred.item()))

        # Length
        _, length_pred = torch.max(length_logits, dim=1)
        length_pred_result = self._LENGTH_ID_TO_LABEL.get(int(length_pred.item()))

        # Multiquestion
        if token_type_ids is None:
            # Without segment ids no token can be selected, so nothing is tagged.
            return multiturn_pred_result, task_pred_result, length_pred_result, []

        selected = token_type_ids == 1
        tagged = torch.sigmoid(multiquestion_logits[selected]) > self._MULTIQUESTION_THRESHOLD
        start_indices = torch.nonzero(tagged)[:, 0].tolist()
        right_input_tokens = inputs_tokens[selected]
        # Slicing past the end of a tensor simply yields a shorter window, so
        # no exception handling is needed here.
        multiquestion_result = [
            right_input_tokens[start:start + self._MULTIQUESTION_SPAN].cpu().tolist()
            for start in start_indices
        ]
        return multiturn_pred_result, task_pred_result, length_pred_result, multiquestion_result