from transformers.models.bert import BertPreTrainedModel, BertModel
from configuration_multitask import MultiTaskConfig
import torch
import torch.nn.functional as F
from torch import nn
from sklearn.metrics import f1_score
import warnings
import os
warnings.filterwarnings('ignore')
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

class MultiTaskModel(BertPreTrainedModel):
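    """BERT encoder shared by four lightweight heads: multiturn (binary, pooled output),
    task and length (multi-class, pooled output), and multiquestion (per-token binary
    tagging over the last hidden state, restricted to the token_type_id == 1 segment)."""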
    config_class = MultiTaskConfig

    def __init__(self, config: MultiTaskConfig):
        super().__init__(config)
        self.config = config

        self.bert = BertModel(config)
        self.bert_drop_task = nn.Dropout(config.classifier_dropout)
        self.task_classifier = nn.Linear(config.hidden_size, config.num_task_label)

        self.bert_drop_length = nn.Dropout(config.classifier_dropout)
        self.length_classifier = nn.Linear(config.hidden_size, config.num_length_label)

        self.bert_drop_multiturn = nn.Dropout(config.classifier_dropout)
        self.multiturn_classifier = nn.Linear(config.hidden_size, 1)

        self.bert_drop_multiquestion = nn.Dropout(config.classifier_dropout)
        self.multiquestion_tagger = nn.Linear(config.hidden_size, 1)

    def binary_class_loss_fn(self, logits, labels):
        binary_class_loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.float())
        return binary_class_loss

    def multi_class_loss_fn(self, logits, labels):
        multi_class_loss = F.cross_entropy(logits, labels)
        return multi_class_loss

    def binary_tag_loss_fn(self, logits, labels, token_type_ids):
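        # Per-example BCE computed only over tokens whose token_type_id is 1, then averaged across the batch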
        binary_class_loss = nn.BCEWithLogitsLoss()
        tagging_loss = 0

        for logit, label, token_type_id in zip(logits, labels, token_type_ids):

            tagging_loss += binary_class_loss(logit[token_type_id == 1], label[token_type_id == 1])

        average_loss = tagging_loss / len(logits)

        return average_loss

    def forward(self, inputs_tokens=None, attention_mask=None, token_type_ids=None,
                multiturn_labels=None, task_labels=None, length_labels=None,
                multiquestion_labels=None):

        outputs = self.bert(input_ids=inputs_tokens, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        last_hidden_state = outputs.last_hidden_state

        multiturn_logits = self.multiturn_classifier(self.bert_drop_multiturn(pooled_output))
        task_logits = self.task_classifier(self.bert_drop_task(pooled_output))
        length_logits = self.length_classifier(self.bert_drop_length(pooled_output))
        multiquestion_logits = self.multiquestion_tagger(self.bert_drop_multiquestion(last_hidden_state))

        if multiturn_labels is not None and task_labels is not None and length_labels is not None and multiquestion_labels is not None:

            # Multiturn: binary head, so convert logits to probabilities before applying the 0.5 threshold
            multiturn_threshold = 0.5
            multiturn_pred = (torch.sigmoid(multiturn_logits) > multiturn_threshold).float()
            multiturn_pred = multiturn_pred.view(-1)
            multiturn_acc = torch.sum(multiturn_pred == multiturn_labels).item() / len(multiturn_labels) * 100

            # Multiquestion: per-token binary tagging, evaluated only on the segment where token_type_id == 1
            multiquestion_threshold = 0.5
            multiquestion_acc = 0
            macro_f1 = 0
            for logit, label, token_type_id in zip(multiquestion_logits, multiquestion_labels, token_type_ids):
                logit, label = logit[token_type_id == 1], label[token_type_id == 1]
                pred = (torch.sigmoid(logit) > multiquestion_threshold).float()
                pred = pred.view(-1)
                multiquestion_acc += (torch.sum(pred == label).item() / len(label))
                macro_f1 += f1_score(label.cpu().numpy(), pred.cpu().numpy(), average='macro')
            multiquestion_acc = multiquestion_acc / len(multiquestion_labels) * 100
            macro_f1 = macro_f1 / len(multiquestion_labels) * 100

            # Task
            _, task_pred = torch.max(task_logits, dim=1)
            task_pred_correct = torch.sum(task_pred == task_labels).item()
            task_acc = task_pred_correct / len(task_labels) * 100

            # Length
            _, length_pred = torch.max(length_logits, dim=1)
            length_pred_correct = torch.sum(length_pred == length_labels).item()
            length_acc = length_pred_correct / len(length_labels) * 100

            multiturn_loss = self.binary_class_loss_fn(multiturn_logits, multiturn_labels)
            task_loss = self.multi_class_loss_fn(task_logits, task_labels)
            length_loss = self.multi_class_loss_fn(length_logits, length_labels)
            multiquestion_loss = self.binary_tag_loss_fn(multiquestion_logits.squeeze(-1), multiquestion_labels.float(), token_type_ids)

            loss = multiturn_loss + task_loss + length_loss + multiquestion_loss

            return loss, multiturn_acc, task_acc, length_acc, multiquestion_acc, macro_f1
        elif multiturn_labels is None and task_labels is None and length_labels is None and multiquestion_labels is None:

            # [ For pred in local ]
            # Label names are Korean: 질의응답 = QA, 대화요약 = dialogue summarization, 문서요약 = document summarization;
            # 긴문장 / 중간문장 / 짧은문장 = long / medium / short sentence.
            label_config = {"task_label_dict": {"질의응답": 0, "대화요약": 1, "문서요약": 2},
                            "length_label_dict": {"긴문장": 0, "짧은문장": 1, "중간문장": 2},
                            "multiturn_label_dict": {"NO": 0, "YES": 1}}

            # Multiturn: sigmoid + 0.5 threshold, then map the class index back to its label name
            multiturn_threshold = 0.5
            multiturn_pred = (torch.sigmoid(multiturn_logits) > multiturn_threshold).float()
            multiturn_pred = multiturn_pred.view(-1)
            multiturn_pred = int(multiturn_pred.item())
            multiturn_pred_result = dict(map(reversed, label_config['multiturn_label_dict'].items())).get(multiturn_pred)

            # Task
            _, task_pred = torch.max(task_logits, dim=1)
            task_pred = int(task_pred.item())
            task_pred_result = dict(map(reversed, label_config['task_label_dict'].items())).get(
                task_pred)

            # Length
            _, length_pred = torch.max(length_logits, dim=1)
            length_pred = int(length_pred.item())
            length_pred_result = dict(map(reversed, label_config['length_label_dict'].items())).get(
                length_pred)

            # Multiquestion: keep only the tokens in the second segment (token_type_id == 1),
            # tag them, and collect a 10-token window starting at each predicted position
            multiquestion_threshold = 0.5
            multiquestion_logits = multiquestion_logits[token_type_ids == 1]
            multiquestion_pred = (torch.sigmoid(multiquestion_logits) > multiquestion_threshold).float()
            multiquestion_index_list = torch.nonzero(multiquestion_pred == 1)
            multiquestion_index_list = multiquestion_index_list[:, 0].tolist()
            right_input_tokens = inputs_tokens[token_type_ids == 1]
            multiquestion_result = []
            for multiquestion_index in multiquestion_index_list:
                # Tensor slicing never raises on out-of-range end indices, so a fixed window is safe
                multiquestion_result.append(right_input_tokens[multiquestion_index:multiquestion_index + 10].cpu().tolist())

            return multiturn_pred_result, task_pred_result, length_pred_result, multiquestion_result
            # [ For pred in torchserve ]
            # return multiturn_logits, task_logits, length_logits, multiquestion_logits
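
if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original pipeline). It assumes
    # MultiTaskConfig extends BertConfig and accepts the extra fields used above
    # (classifier_dropout, num_task_label, num_length_label); the values below are
    # illustrative only.
    config = MultiTaskConfig(
        classifier_dropout=0.1,
        num_task_label=3,
        num_length_label=3,
    )
    model = MultiTaskModel(config)
    model.eval()

    batch_size, seq_len = 2, 32
    inputs_tokens = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    token_type_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
    token_type_ids[:, seq_len // 2:] = 1  # treat the second half as the tagged segment

    with torch.no_grad():
        loss, multiturn_acc, task_acc, length_acc, multiquestion_acc, macro_f1 = model(
            inputs_tokens=inputs_tokens,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            multiturn_labels=torch.zeros(batch_size),
            task_labels=torch.zeros(batch_size, dtype=torch.long),
            length_labels=torch.zeros(batch_size, dtype=torch.long),
            multiquestion_labels=torch.zeros(batch_size, seq_len),
        )
    print(loss.item(), multiturn_acc, task_acc, length_acc, multiquestion_acc, macro_f1)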