# multi_task/modeling_multitask.py
import os
import warnings

import torch
import torch.nn.functional as F
from torch import nn
from sklearn.metrics import f1_score
from transformers.models.bert import BertPreTrainedModel, BertModel

from configuration_multitask import MultiTaskConfig

warnings.filterwarnings('ignore')
# Synchronous CUDA kernel launches: device-side errors surface at the
# offending call instead of a later one (debugging aid).
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
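
# MultiTaskModel: a shared BERT encoder with four heads trained jointly.
# Three sequence-level classifiers read the pooled [CLS] output (task type,
# answer length, and a binary multiturn flag); a token-level tagger reads the
# last hidden states and marks extra questions inside the second segment
# (token_type_ids == 1).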
class MultiTaskModel(BertPreTrainedModel):
    config_class = MultiTaskConfig

    def __init__(self, config: MultiTaskConfig):
        super().__init__(config)
        self.config = config
        self.bert = BertModel(config)
        # Sequence-level head: task type.
        self.bert_drop_task = nn.Dropout(config.classifier_dropout)
        self.task_classifier = nn.Linear(config.hidden_size, config.num_task_label)
        # Sequence-level head: answer length.
        self.bert_drop_length = nn.Dropout(config.classifier_dropout)
        self.length_classifier = nn.Linear(config.hidden_size, config.num_length_label)
        # Sequence-level head: binary multiturn flag.
        self.bert_drop_multiturn = nn.Dropout(config.classifier_dropout)
        self.multiturn_classifier = nn.Linear(config.hidden_size, 1)
        # Token-level head: binary multi-question tagger.
        self.bert_drop_multiquestion = nn.Dropout(config.classifier_dropout)
        self.multiquestion_tagger = nn.Linear(config.hidden_size, 1)

    def binary_class_loss_fn(self, logits, labels):
        return F.binary_cross_entropy_with_logits(logits.view(-1), labels.float())

    def multi_class_loss_fn(self, logits, labels):
        return F.cross_entropy(logits, labels)

    def binary_tag_loss_fn(self, logits, labels, token_type_ids):
        # Token-level BCE restricted to second-segment tokens
        # (token_type_ids == 1), averaged over the batch.
        binary_class_loss = nn.BCEWithLogitsLoss()
        tagging_loss = 0
        for logit, label, token_type_id in zip(logits, labels, token_type_ids):
            tagging_loss += binary_class_loss(logit[token_type_id == 1], label[token_type_id == 1])
        return tagging_loss / len(logits)
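
    # forward() has two modes:
    #   * training / evaluation: all four label tensors are given; returns the
    #     summed loss plus per-task batch accuracies and the tagging macro-F1.
    #   * local inference: no labels are given; decodes a single example into
    #     human-readable predictions.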
    def forward(self, inputs_tokens=None, attention_mask=None, token_type_ids=None,
                multiturn_labels=None, task_labels=None, length_labels=None,
                multiquestion_labels=None):
        outputs = self.bert(input_ids=inputs_tokens, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        last_hidden_state = outputs.last_hidden_state
        multiturn_logits = self.multiturn_classifier(self.bert_drop_multiturn(pooled_output))
        task_logits = self.task_classifier(self.bert_drop_task(pooled_output))
        length_logits = self.length_classifier(self.bert_drop_length(pooled_output))
        multiquestion_logits = self.multiquestion_tagger(self.bert_drop_multiquestion(last_hidden_state))
        if (multiturn_labels is not None and task_labels is not None
                and length_labels is not None and multiquestion_labels is not None):
            # Multiturn: threshold the sigmoid probability at 0.5.
            multiturn_threshold = 0.5
            multiturn_pred = (torch.sigmoid(multiturn_logits) > multiturn_threshold).float()
            multiturn_pred = multiturn_pred.view(-1)
            multiturn_acc = torch.sum(multiturn_pred == multiturn_labels).item() / len(multiturn_labels) * 100
            # Multiquestion: per-example accuracy and macro-F1, computed only
            # over second-segment tokens.
            multiquestion_threshold = 0.5
            multiquestion_acc = 0
            macro_f1 = 0
            for logit, label, token_type_id in zip(multiquestion_logits, multiquestion_labels, token_type_ids):
                logit, label = logit[token_type_id == 1], label[token_type_id == 1]
                pred = (torch.sigmoid(logit) > multiquestion_threshold).float()
                pred = pred.view(-1)
                multiquestion_acc += torch.sum(pred == label).item() / len(label)
                macro_f1 += f1_score(label.cpu().numpy(), pred.cpu().numpy(), average='macro')
            multiquestion_acc = multiquestion_acc / len(multiquestion_labels) * 100
            macro_f1 = macro_f1 / len(multiquestion_labels) * 100
            # Task
            _, task_pred = torch.max(task_logits, dim=1)
            task_acc = torch.sum(task_pred == task_labels).item() / len(task_labels) * 100
            # Length
            _, length_pred = torch.max(length_logits, dim=1)
            length_acc = torch.sum(length_pred == length_labels).item() / len(length_labels) * 100
            # The four task losses are summed with equal weights.
            multiturn_loss = self.binary_class_loss_fn(multiturn_logits, multiturn_labels)
            task_loss = self.multi_class_loss_fn(task_logits, task_labels)
            length_loss = self.multi_class_loss_fn(length_logits, length_labels)
            # squeeze(-1) drops only the trailing logit dim, so a batch of
            # size 1 keeps its batch dimension.
            multiquestion_loss = self.binary_tag_loss_fn(
                multiquestion_logits.squeeze(-1), multiquestion_labels.float(), token_type_ids)
            loss = multiturn_loss + task_loss + length_loss + multiquestion_loss
            return loss, multiturn_acc, task_acc, length_acc, multiquestion_acc, macro_f1
        elif (multiturn_labels is None and task_labels is None
                and length_labels is None and multiquestion_labels is None):
            # [ For pred in local ]
            label_config = {
                # QA / dialogue summarization / document summarization
                "task_label_dict": {"질의응답": 0, "대화요약": 1, "문서요약": 2},
                # long / short / medium sentence
                "length_label_dict": {"긴문장": 0, "짧은문장": 1, "중간문장": 2},
                "multiturn_label_dict": {"NO": 0, "YES": 1},
            }
            # Multiturn
            multiturn_threshold = 0.5
            multiturn_pred = (torch.sigmoid(multiturn_logits) > multiturn_threshold).float()
            multiturn_pred = int(multiturn_pred.view(-1).item())
            multiturn_pred_result = {v: k for k, v in label_config['multiturn_label_dict'].items()}.get(multiturn_pred)
            # Task
            _, task_pred = torch.max(task_logits, dim=1)
            task_pred = int(task_pred.item())
            task_pred_result = {v: k for k, v in label_config['task_label_dict'].items()}.get(task_pred)
            # Length
            _, length_pred = torch.max(length_logits, dim=1)
            length_pred = int(length_pred.item())
            length_pred_result = {v: k for k, v in label_config['length_label_dict'].items()}.get(length_pred)
            # Multiquestion: tag second-segment tokens, then return a window of
            # up to 10 token ids starting at each positive position (slicing
            # clamps at the sequence end, so no bounds check is needed).
            multiquestion_threshold = 0.5
            multiquestion_logits = multiquestion_logits[token_type_ids == 1]
            multiquestion_pred = (torch.sigmoid(multiquestion_logits) > multiquestion_threshold).float()
            multiquestion_index_list = torch.nonzero(multiquestion_pred == 1)[:, 0].tolist()
            right_input_tokens = inputs_tokens[token_type_ids == 1]
            multiquestion_result = []
            for multiquestion_index in multiquestion_index_list:
                multiquestion_result.append(
                    right_input_tokens[multiquestion_index:multiquestion_index + 10].cpu().tolist())
            return multiturn_pred_result, task_pred_result, length_pred_result, multiquestion_result
        '''
        # [ For pred in torchserve ]
        return multiturn_logits, task_logits, length_logits, multiquestion_logits
        '''
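

if __name__ == "__main__":
    # Minimal local-inference sketch, assuming a trained checkpoint directory
    # with a matching tokenizer. "./checkpoint" and the example texts are
    # placeholders; adjust them to your setup.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("./checkpoint")  # hypothetical path
    model = MultiTaskModel.from_pretrained("./checkpoint")     # hypothetical path
    model.eval()

    # Segment A carries the context, segment B the user request; the tagger
    # only looks at segment B (token_type_ids == 1).
    encoded = tokenizer("context goes here", "request goes here",
                        return_tensors="pt", return_token_type_ids=True)
    with torch.no_grad():
        multiturn, task, length, multiquestion = model(
            inputs_tokens=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            token_type_ids=encoded["token_type_ids"],
        )
    # multiquestion holds windows of token ids; decode them back to text.
    print(multiturn, task, length,
          [tokenizer.decode(ids) for ids in multiquestion])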