from transformers.models.bert import BertPreTrainedModel, BertModel
from configuration_multitask import MultiTaskConfig
import torch
import torch.nn.functional as F
from torch import nn
from sklearn.metrics import f1_score
import warnings
import os

warnings.filterwarnings('ignore')
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"  # synchronous CUDA kernels for easier debugging


class MultiTaskModel(BertPreTrainedModel):
    config_class = MultiTaskConfig

    def __init__(self, config: MultiTaskConfig):
        super().__init__(config)
        self.config = config
        self.bert = BertModel(config)

        # One dropout + linear head per task on top of the shared BERT encoder.
        self.bert_drop_task = nn.Dropout(config.classifier_dropout)
        self.task_classifier = nn.Linear(config.hidden_size, config.num_task_label)

        self.bert_drop_length = nn.Dropout(config.classifier_dropout)
        self.length_classifier = nn.Linear(config.hidden_size, config.num_length_label)

        self.bert_drop_multiturn = nn.Dropout(config.classifier_dropout)
        self.multiturn_classifier = nn.Linear(config.hidden_size, 1)

        self.bert_drop_multiquestion = nn.Dropout(config.classifier_dropout)
        self.multiquestion_tagger = nn.Linear(config.hidden_size, 1)

    def binary_class_loss_fn(self, logits, labels):
        # Sentence-level binary head (multiturn): BCE with logits.
        binary_class_loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.float())
        return binary_class_loss

    def multi_class_loss_fn(self, logits, labels):
        # Sentence-level multi-class heads (task, length): standard cross entropy.
        multi_class_loss = F.cross_entropy(logits, labels)
        return multi_class_loss

    def binary_tag_loss_fn(self, logits, labels, token_type_ids):
        # Token-level binary tagging head (multiquestion): BCE over the tokens of the
        # second segment (token_type_id == 1) only, averaged over the batch.
        binary_class_loss = nn.BCEWithLogitsLoss()
        tagging_loss = 0
        for logit, label, token_type_id in zip(logits, labels, token_type_ids):
            tagging_loss += binary_class_loss(logit[token_type_id == 1], label[token_type_id == 1])
        average_loss = tagging_loss / len(logits)
        return average_loss

    def forward(self, inputs_tokens=None, attention_mask=None, token_type_ids=None,
                multiturn_labels=None, task_labels=None, length_labels=None,
                multiquestion_labels=None):
        outputs = self.bert(input_ids=inputs_tokens, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        last_hidden_state = outputs.last_hidden_state

        # Sentence-level heads read the pooled [CLS] representation; the token-level
        # multiquestion tagger reads the full last hidden state.
        multiturn_logits = self.multiturn_classifier(self.bert_drop_multiturn(pooled_output))
        task_logits = self.task_classifier(self.bert_drop_task(pooled_output))
        length_logits = self.length_classifier(self.bert_drop_length(pooled_output))
        multiquestion_logits = self.multiquestion_tagger(self.bert_drop_multiquestion(last_hidden_state))

        if multiturn_labels is not None and task_labels is not None \
                and length_labels is not None and multiquestion_labels is not None:
            # [ Training / evaluation: report accuracies and return the summed loss ]

            # Multiturn (note: the 0.5 threshold is applied to the raw logits, not to probabilities)
            multiturn_threshold = 0.5
            multiturn_pred = (multiturn_logits > multiturn_threshold).float()
            multiturn_pred = multiturn_pred.view(-1)
            multiturn_acc = torch.sum(multiturn_pred == multiturn_labels).item() / len(multiturn_labels) * 100

            # Multiquestion: per-example accuracy and macro F1 over second-segment tokens
            multiquestion_threshold = 0.5
            multiquestion_acc = 0
            macro_f1 = 0
            for logit, label, token_type_id in zip(multiquestion_logits, multiquestion_labels, token_type_ids):
                logit, label = logit[token_type_id == 1], label[token_type_id == 1]
                pred = (logit > multiquestion_threshold).float()
                pred = pred.view(-1)
                multiquestion_acc += torch.sum(pred == label).item() / len(label)
                macro_f1 += f1_score(label.cpu().numpy(), pred.cpu().numpy(), average='macro')
            multiquestion_acc = multiquestion_acc / len(multiquestion_labels) * 100
            macro_f1 = macro_f1 / len(multiquestion_labels) * 100

            # Task
            _, task_pred = torch.max(task_logits, dim=1)
            task_pred_correct = torch.sum(task_pred == task_labels).item()
            task_acc = task_pred_correct / len(task_labels) * 100

            # Length
            _, length_pred = torch.max(length_logits, dim=1)
            length_pred_correct = torch.sum(length_pred == length_labels).item()
            length_acc = length_pred_correct / len(length_labels) * 100

            # Total loss is the unweighted sum of the four task losses.
            multiturn_loss = self.binary_class_loss_fn(multiturn_logits, multiturn_labels)
            task_loss = self.multi_class_loss_fn(task_logits, task_labels)
            length_loss = self.multi_class_loss_fn(length_logits, length_labels)
            multiquestion_loss = self.binary_tag_loss_fn(multiquestion_logits.squeeze(),
                                                         multiquestion_labels.float(), token_type_ids)
            loss = multiturn_loss + task_loss + length_loss + multiquestion_loss

            return loss, multiturn_acc, task_acc, length_acc, multiquestion_acc, macro_f1

        elif multiturn_labels is None and task_labels is None \
                and length_labels is None and multiquestion_labels is None:
            # [ For pred in local ]
            # Label maps: 질의응답=QA, 대화요약=dialogue summarization, 문서요약=document summarization;
            # 긴문장=long, 짧은문장=short, 중간문장=medium-length sentence.
            label_config = {"task_label_dict": {"질의응답": 0, "대화요약": 1, "문서요약": 2},
                            "length_label_dict": {"긴문장": 0, "짧은문장": 1, "중간문장": 2},
                            "multiturn_label_dict": {"NO": 0, "YES": 1}}

            # Multiturn: threshold the logit, then map the index back to its label string.
            multiturn_threshold = 0.5
            multiturn_pred = (multiturn_logits > multiturn_threshold).float()
            multiturn_pred = multiturn_pred.view(-1)
            multiturn_pred = int(multiturn_pred.item())
            multiturn_pred_result = dict(map(reversed, label_config['multiturn_label_dict'].items())).get(multiturn_pred)

            # Task
            _, task_pred = torch.max(task_logits, dim=1)
            task_pred = int(task_pred.item())
            task_pred_result = dict(map(reversed, label_config['task_label_dict'].items())).get(task_pred)

            # Length
            _, length_pred = torch.max(length_logits, dim=1)
            length_pred = int(length_pred.item())
            length_pred_result = dict(map(reversed, label_config['length_label_dict'].items())).get(length_pred)

            # Multiquestion: tag tokens in the second segment and return, for each
            # positively tagged position, a window of up to 10 following token ids.
            multiquestion_threshold = 0.5
            multiquestion_logits = multiquestion_logits[token_type_ids == 1]
            multiquestion_pred = (multiquestion_logits > multiquestion_threshold).float()
            multiquestion_index_list = torch.nonzero(multiquestion_pred == 1)
            multiquestion_index_list = multiquestion_index_list[:, 0].tolist()
            right_input_tokens = inputs_tokens[token_type_ids == 1]
            multiquestion_result = []
            for multiquestion_index in multiquestion_index_list:
                # Slicing clamps at the end of the sequence, so windows near the end are simply shorter.
                multiquestion_result.append(
                    right_input_tokens[multiquestion_index:multiquestion_index + 10].cpu().tolist())

            return multiturn_pred_result, task_pred_result, length_pred_result, multiquestion_result

            # [ For pred in torchserve ]
            # return multiturn_logits, task_logits, length_logits, multiquestion_logits
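

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). Assumptions: MultiTaskConfig
# subclasses BertConfig, inherits vocab_size, and accepts the extra fields referenced
# above (num_task_label, num_length_label, classifier_dropout); the real signature lives
# in configuration_multitask.py. This only illustrates the inference path of forward(),
# where all label arguments are omitted and decoded prediction strings are returned.
if __name__ == '__main__':
    config = MultiTaskConfig(num_task_label=3, num_length_label=3, classifier_dropout=0.1)  # hypothetical kwargs
    model = MultiTaskModel(config)
    model.eval()

    # Dummy single-example batch; the second half of the sequence acts as segment 1,
    # which is where the multiquestion tagger looks for question tokens.
    batch_size, seq_len = 1, 32
    inputs_tokens = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    token_type_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
    token_type_ids[:, seq_len // 2:] = 1

    with torch.no_grad():
        multiturn, task, length, multiquestion = model(
            inputs_tokens=inputs_tokens,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
    print(multiturn, task, length, multiquestion)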