|
from typing import Optional, Tuple, Union, List
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from transformers.modeling_outputs import TokenClassifierOutput |
|
from transformers.models.bert import ( |
|
BertConfig, BertModel, BertPreTrainedModel |
|
) |
|
from transformers.models.roberta import ( |
|
RobertaConfig, RobertaModel, RobertaPreTrainedModel |
|
) |
|
from transformers.models.deberta_v2 import ( |
|
DebertaV2Config, DebertaV2Model, DebertaV2PreTrainedModel |
|
) |
|
from transformers.models.modernbert.modeling_modernbert import ( |
|
ModernBertConfig, ModernBertModel, ModernBertPreTrainedModel, ModernBertPredictionHead |
|
) |
|
|
|
|
|
def fixed_cross_entropy( |
|
source: torch.Tensor, |
|
target: torch.Tensor, |
|
num_items_in_batch: Optional[int] = None, |
|
ignore_index: int = -100, |
|
**kwargs, |
|
) -> torch.Tensor: |
|
    """Cross-entropy that sums over tokens and normalizes by `num_items_in_batch` when it is
    provided; otherwise a plain mean reduction is used."""
    reduction = "sum" if num_items_in_batch is not None else "mean"
|
loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction) |
|
if reduction == "sum": |
|
if not isinstance(num_items_in_batch, torch.Tensor): |
|
num_items_in_batch = torch.tensor(num_items_in_batch, device=loss.device, dtype=loss.dtype) |
|
elif num_items_in_batch.device != loss.device: |
|
num_items_in_batch = num_items_in_batch.to(loss.device) |
|
loss = loss / num_items_in_batch |
|
return loss |
|
|
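# Example sketch for `fixed_cross_entropy` (illustrative shapes only; kept as a comment so it
# does not run on import):
#
#   logits = torch.randn(16, 9)             # (num_tokens, num_labels)
#   targets = torch.randint(0, 9, (16,))    # gold label ids
#   targets[:2] = -100                      # padded/special tokens are ignored
#   # With `num_items_in_batch`, the summed loss is divided by that count, which keeps the
#   # gradient scale consistent under gradient accumulation.
#   loss = fixed_cross_entropy(logits, targets, num_items_in_batch=14)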
|
|
|
class BertForTokenClassification(BertPreTrainedModel): |
|
|
|
def __init__(self, config: BertConfig): |
|
super().__init__(config) |
|
self.num_labels = config.num_labels |
|
|
|
self.bert = BertModel(config, add_pooling_layer=False) |
|
classifier_dropout = ( |
|
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob |
|
) |
|
self.dropout = nn.Dropout(classifier_dropout) |
|
self.classifier = nn.Linear(config.hidden_size, config.num_labels) |
|
|
|
|
|
self.post_init() |
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
token_type_ids: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.Tensor] = None, |
|
head_mask: Optional[torch.Tensor] = None, |
|
inputs_embeds: Optional[torch.Tensor] = None, |
|
labels: Optional[torch.Tensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
num_items_in_batch: Optional[torch.Tensor] = None, |
|
ignore_index: int = -100, |
|
**kwargs, |
|
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: |
|
r""" |
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. |
|
""" |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
outputs = self.bert( |
|
input_ids, |
|
attention_mask=attention_mask, |
|
token_type_ids=token_type_ids, |
|
position_ids=position_ids, |
|
head_mask=head_mask, |
|
inputs_embeds=inputs_embeds, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
sequence_output = outputs[0] |
|
|
|
sequence_output = self.dropout(sequence_output) |
|
logits = self.classifier(sequence_output) |
|
|
|
loss = None |
|
if labels is not None: |
|
            # Flatten a copy for the loss so the returned logits keep their (batch, seq_len, num_labels) shape.

            flat_logits = logits.view(-1, self.num_labels).float()

            flat_labels = labels.view(-1).to(flat_logits.device)

            loss = fixed_cross_entropy(flat_logits, flat_labels, num_items_in_batch, ignore_index)
|
|
|
if not return_dict: |
|
output = (logits,) + outputs[2:] |
|
return ((loss,) + output) if loss is not None else output |
|
|
|
return TokenClassifierOutput( |
|
loss=loss, |
|
logits=logits, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
|
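# Usage sketch for `BertForTokenClassification` (the checkpoint name and label count are
# illustrative assumptions; tokenizer usage follows the standard transformers API):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
#   model = BertForTokenClassification.from_pretrained("bert-base-cased", num_labels=9)
#   inputs = tokenizer("HuggingFace is based in NYC", return_tensors="pt")
#   outputs = model(**inputs)                      # outputs.logits: (1, seq_len, 9)
#   predictions = outputs.logits.argmax(dim=-1)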
|
class CRF(nn.Module): |
|
"""條件隨機場(CRF)層,基於更穩定的實現""" |
|
|
|
def __init__(self, num_labels: int): |
|
super().__init__() |
|
self.num_labels = num_labels |
|
|
|
|
|
self.start_transitions = nn.Parameter(torch.empty(num_labels)) |
|
self.end_transitions = nn.Parameter(torch.empty(num_labels)) |
|
self.transitions = nn.Parameter(torch.empty(num_labels, num_labels)) |
|
|
|
|
|
nn.init.uniform_(self.start_transitions, -0.1, 0.1) |
|
nn.init.uniform_(self.end_transitions, -0.1, 0.1) |
|
nn.init.uniform_(self.transitions, -0.1, 0.1) |
|
|
|
def _compute_log_denominator(self, features: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: |
|
"""計算配分函數的對數(log of the partition function)""" |
|
seq_len, batch_size, _ = features.shape |
|
|
|
|
|
log_score = self.start_transitions + features[0] |
|
|
|
|
|
for i in range(1, seq_len): |
|
|
|
|
|
|
|
next_score = ( |
|
log_score.unsqueeze(2) + |
|
self.transitions + |
|
features[i].unsqueeze(1) |
|
) |
|
|
|
|
|
next_score = torch.logsumexp(next_score, dim=1) |
|
|
|
|
|
log_score = torch.where(mask[i].unsqueeze(1), next_score, log_score) |
|
|
|
|
|
log_score += self.end_transitions |
|
|
|
|
|
return torch.logsumexp(log_score, dim=1) |
|
|
|
def _compute_log_numerator(self, features: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: |
|
"""計算給定標籤序列的得分""" |
|
seq_len, batch_size, _ = features.shape |
|
|
|
|
|
score = self.start_transitions[labels[0]] + features[0, torch.arange(batch_size), labels[0]] |
|
|
|
|
|
for i in range(1, seq_len): |
|
|
|
score += ( |
|
self.transitions[labels[i-1], labels[i]] + |
|
features[i, torch.arange(batch_size), labels[i]] |
|
) * mask[i] |
|
|
|
|
|
seq_lens = mask.sum(dim=0) - 1 |
|
|
|
|
|
last_tags = labels[seq_lens.long(), torch.arange(batch_size)] |
|
|
|
|
|
score += self.end_transitions[last_tags] |
|
|
|
return score |
|
|
|
def forward(self, emissions: torch.Tensor, tags: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: |
|
""" |
|
計算CRF負對數似然損失 |
|
|
|
參數: |
|
emissions: (seq_len, batch_size, num_labels) 發射得分 |
|
tags: (seq_len, batch_size) 真實標籤 |
|
mask: (seq_len, batch_size) 用於處理變長序列的遮罩 |
|
|
|
返回: |
|
CRF負對數似然損失 |
|
""" |
|
|
|
log_numerator = self._compute_log_numerator(emissions, tags, mask) |
|
log_denominator = self._compute_log_denominator(emissions, mask) |
|
|
|
|
|
loss = torch.mean(log_denominator - log_numerator) |
|
|
|
return loss |
|
|
|
def _viterbi_decode(self, features: torch.Tensor, mask: torch.Tensor) -> List[List[int]]: |
|
"""Viterbi算法解碼,找出最可能的標籤序列""" |
|
seq_len, batch_size, _ = features.shape |
|
|
|
|
|
log_score = self.start_transitions + features[0] |
|
backpointers = torch.zeros((seq_len, batch_size, self.num_labels), dtype=torch.long, device=features.device) |
|
|
|
|
|
for i in range(1, seq_len): |
|
|
|
next_score = log_score.unsqueeze(2) + self.transitions + features[i].unsqueeze(1) |
|
|
|
|
|
next_score, indices = next_score.max(dim=1) |
|
|
|
|
|
backpointers[i] = indices |
|
|
|
|
|
log_score = torch.where(mask[i].unsqueeze(1), next_score, log_score) |
|
|
|
|
|
log_score += self.end_transitions |
|
|
|
|
|
seq_lens = mask.sum(dim=0).long() - 1 |
|
|
|
|
|
best_paths = [] |
|
for seq_idx in range(batch_size): |
|
|
|
best_label = torch.argmax(log_score[seq_idx]).item() |
|
best_path = [best_label] |
|
|
|
|
|
for i in range(seq_lens[seq_idx], 0, -1): |
|
best_label = backpointers[i, seq_idx, best_label].item() |
|
best_path.insert(0, best_label) |
|
|
|
best_paths.append(best_path) |
|
|
|
return best_paths |
|
|
|
def decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> List[List[int]]: |
|
"""使用Viterbi解碼找出最可能的標籤序列""" |
|
|
|
if mask.dtype != torch.bool: |
|
mask = mask.bool() |
|
|
|
with torch.no_grad(): |
|
return self._viterbi_decode(emissions, mask) |
|
|
|
|
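# Standalone CRF sketch (random emissions, time-major layout as expected by `CRF`; commented out
# so the module import stays side-effect free):
#
#   crf = CRF(num_labels=5)
#   emissions = torch.randn(7, 2, 5)              # (seq_len, batch, num_labels)
#   tags = torch.randint(0, 5, (7, 2))            # (seq_len, batch)
#   mask = torch.ones(7, 2, dtype=torch.bool)     # all positions valid
#   nll = crf(emissions, tags, mask)              # scalar training loss
#   best_paths = crf.decode(emissions, mask)      # List[List[int]], one path per sequence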
|
class BertCRFForTokenClassification(BertPreTrainedModel): |
|
"""BERT模型與CRF層結合用於token分類""" |
|
|
|
def __init__(self, config: BertConfig): |
|
super().__init__(config) |
|
self.num_labels = config.num_labels |
|
|
|
|
|
self.bert = BertModel(config, add_pooling_layer=False) |
|
|
|
|
|
classifier_dropout = ( |
|
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob |
|
) |
|
self.dropout = nn.Dropout(classifier_dropout) |
|
self.classifier = nn.Linear(config.hidden_size, config.num_labels) |
|
|
|
|
|
self.crf = CRF(config.num_labels) |
|
|
|
|
|
self.post_init() |
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
token_type_ids: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.Tensor] = None, |
|
head_mask: Optional[torch.Tensor] = None, |
|
inputs_embeds: Optional[torch.Tensor] = None, |
|
labels: Optional[torch.Tensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
ignore_index: int = -100, |
|
**kwargs, |
|
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
|
""" |
|
使用CRF進行序列標注的前向傳播 |
|
|
|
參數: |
|
labels: 標籤序列,用於計算損失 |
|
ignore_index: 忽略的標籤值,通常為-100 |
|
""" |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
outputs = self.bert( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
token_type_ids=token_type_ids, |
|
position_ids=position_ids, |
|
head_mask=head_mask, |
|
inputs_embeds=inputs_embeds, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
sequence_output = outputs[0] |
|
sequence_output = self.dropout(sequence_output) |
|
|
|
|
|
logits = self.classifier(sequence_output) |
|
|
|
loss = None |
|
if labels is not None: |
|
|
|
|
|
emissions = logits.transpose(0, 1) |
|
|
|
|
|
if attention_mask is not None: |
|
attention_mask_t = attention_mask.transpose(0, 1).bool() |
|
else: |
|
attention_mask_t = torch.ones(emissions.shape[:2], dtype=torch.bool, device=emissions.device) |
|
|
|
|
|
if ignore_index is not None: |
|
labels_mask = (labels != ignore_index) |
|
attention_mask_t = attention_mask_t & labels_mask.transpose(0, 1) |
|
|
|
|
|
crf_labels = labels.clone() |
|
crf_labels[~labels_mask] = 0 |
|
crf_labels_t = crf_labels.transpose(0, 1) |
|
else: |
|
crf_labels_t = labels.transpose(0, 1) |
|
|
|
|
|
loss = self.crf( |
|
emissions=emissions, |
|
tags=crf_labels_t, |
|
mask=attention_mask_t |
|
) |
|
|
|
if not return_dict: |
|
output = (logits,) + outputs[2:] |
|
return ((loss,) + output) if loss is not None else output |
|
|
|
return TokenClassifierOutput( |
|
loss=loss, |
|
logits=logits, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
def decode( |
|
self, |
|
input_ids: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
token_type_ids: Optional[torch.Tensor] = None, |
|
**kwargs, |
|
) -> List[List[int]]: |
|
""" |
|
解碼最可能的標籤序列 |
|
""" |
|
|
|
with torch.no_grad(): |
|
|
|
outputs = self.bert( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
token_type_ids=token_type_ids, |
|
return_dict=True, |
|
**kwargs, |
|
) |
|
|
|
sequence_output = outputs[0] |
|
sequence_output = self.dropout(sequence_output) |
|
|
|
|
|
logits = self.classifier(sequence_output) |
|
|
|
|
|
emissions = logits.transpose(0, 1) |
|
|
|
|
|
if attention_mask is not None: |
|
mask = attention_mask.transpose(0, 1).bool() |
|
else: |
|
mask = torch.ones(emissions.shape[:2], dtype=torch.bool, device=emissions.device) |
|
|
|
|
|
best_tags = self.crf.decode(emissions, mask) |
|
|
|
return best_tags |
|
|
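# Usage sketch for `BertCRFForTokenClassification` (checkpoint name, `tokenizer`, and `label_ids`
# are assumed for illustration):
#
#   model = BertCRFForTokenClassification.from_pretrained("bert-base-cased", num_labels=9)
#   batch = tokenizer(["HuggingFace is based in NYC"], return_tensors="pt")
#   # Training: forward() returns the CRF negative log-likelihood as `loss`.
#   out = model(**batch, labels=label_ids)        # label_ids: (1, seq_len), -100 for ignored positions
#   out.loss.backward()
#   # Inference: decode() runs Viterbi and returns one tag path per sequence.
#   paths = model.decode(**batch)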
|
|
|
class RobertaForTokenClassification(RobertaPreTrainedModel): |
|
|
|
def __init__(self, config: RobertaConfig): |
|
super().__init__(config) |
|
self.num_labels = config.num_labels |
|
|
|
self.roberta = RobertaModel(config, add_pooling_layer=False) |
|
classifier_dropout = ( |
|
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob |
|
) |
|
self.dropout = nn.Dropout(classifier_dropout) |
|
self.classifier = nn.Linear(config.hidden_size, config.num_labels) |
|
|
|
|
|
self.post_init() |
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
token_type_ids: Optional[torch.LongTensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
head_mask: Optional[torch.FloatTensor] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
num_items_in_batch: Optional[torch.Tensor] = None, |
|
ignore_index: int = -100, |
|
**kwargs, |
|
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: |
|
r""" |
|
token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`: |
|
|
|
- 0 corresponds to a *sentence A* token, |
|
- 1 corresponds to a *sentence B* token. |
|
            This parameter can only be used when the model is initialized with a `type_vocab_size` >= 2. All values
            in this tensor should always be smaller than `type_vocab_size`.
|
|
|
[What are token type IDs?](../glossary#token-type-ids) |
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. |
|
""" |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
outputs = self.roberta( |
|
input_ids, |
|
attention_mask=attention_mask, |
|
token_type_ids=token_type_ids, |
|
position_ids=position_ids, |
|
head_mask=head_mask, |
|
inputs_embeds=inputs_embeds, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
sequence_output = outputs[0] |
|
|
|
sequence_output = self.dropout(sequence_output) |
|
logits = self.classifier(sequence_output) |
|
|
|
loss = None |
|
if labels is not None: |
|
|
|
            # Flatten a copy for the loss so the returned logits keep their (batch, seq_len, num_labels) shape.

            flat_logits = logits.view(-1, self.num_labels).float()

            flat_labels = labels.view(-1).to(flat_logits.device)

            loss = fixed_cross_entropy(flat_logits, flat_labels, num_items_in_batch, ignore_index)
|
|
|
if not return_dict: |
|
output = (logits,) + outputs[2:] |
|
return ((loss,) + output) if loss is not None else output |
|
|
|
return TokenClassifierOutput( |
|
loss=loss, |
|
logits=logits, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
|
|
class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel): |
|
|
|
def __init__(self, config: DebertaV2Config): |
|
super().__init__(config) |
|
self.num_labels = config.num_labels |
|
|
|
self.deberta = DebertaV2Model(config) |
|
self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
self.classifier = nn.Linear(config.hidden_size, config.num_labels) |
|
|
|
|
|
self.post_init() |
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
token_type_ids: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.Tensor] = None, |
|
inputs_embeds: Optional[torch.Tensor] = None, |
|
labels: Optional[torch.Tensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
num_items_in_batch: Optional[torch.Tensor] = None, |
|
ignore_index: int = -100, |
|
**kwargs, |
|
) -> Union[Tuple, TokenClassifierOutput]: |
|
r""" |
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. |
|
""" |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
outputs = self.deberta( |
|
input_ids, |
|
attention_mask=attention_mask, |
|
token_type_ids=token_type_ids, |
|
position_ids=position_ids, |
|
inputs_embeds=inputs_embeds, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
sequence_output = outputs[0] |
|
|
|
sequence_output = self.dropout(sequence_output) |
|
logits = self.classifier(sequence_output) |
|
|
|
loss = None |
|
if labels is not None: |
|
|
|
            # Flatten a copy for the loss so the returned logits keep their (batch, seq_len, num_labels) shape.

            flat_logits = logits.view(-1, self.num_labels).float()

            flat_labels = labels.view(-1).to(flat_logits.device)

            loss = fixed_cross_entropy(flat_logits, flat_labels, num_items_in_batch, ignore_index)
|
|
|
if not return_dict: |
|
output = (logits,) + outputs[1:] |
|
return ((loss,) + output) if loss is not None else output |
|
|
|
return TokenClassifierOutput( |
|
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions |
|
) |
|
|
|
|
|
class ModernBertForTokenClassification(ModernBertPreTrainedModel): |
|
|
|
def __init__(self, config: ModernBertConfig): |
|
super().__init__(config) |
|
self.num_labels = config.num_labels |
|
|
|
self.model = ModernBertModel(config) |
|
self.head = ModernBertPredictionHead(config) |
|
        self.drop = nn.Dropout(config.classifier_dropout)
|
self.classifier = nn.Linear(config.hidden_size, config.num_labels) |
|
|
|
|
|
self.post_init() |
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
sliding_window_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.Tensor] = None, |
|
inputs_embeds: Optional[torch.Tensor] = None, |
|
labels: Optional[torch.Tensor] = None, |
|
indices: Optional[torch.Tensor] = None, |
|
cu_seqlens: Optional[torch.Tensor] = None, |
|
max_seqlen: Optional[int] = None, |
|
batch_size: Optional[int] = None, |
|
seq_len: Optional[int] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
num_items_in_batch: Optional[torch.Tensor] = None, |
|
ignore_index: int = -100, |
|
**kwargs, |
|
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: |
|
r""" |
|
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers |
|
perform global attention, while the rest perform local attention. This mask is used to avoid attending to |
|
far-away tokens in the local attention layers when not using Flash Attention. |
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. |
|
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): |
|
Indices of the non-padding tokens in the input sequence. Used for unpadding the output. |
|
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): |
|
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. |
|
max_seqlen (`int`, *optional*): |
|
Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors. |
|
batch_size (`int`, *optional*): |
|
Batch size of the input sequences. Used to pad the output tensors. |
|
seq_len (`int`, *optional*): |
|
Sequence length of the input sequences including padding tokens. Used to pad the output tensors. |
|
""" |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
self._maybe_set_compile() |
|
|
|
outputs = self.model( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
sliding_window_mask=sliding_window_mask, |
|
position_ids=position_ids, |
|
inputs_embeds=inputs_embeds, |
|
indices=indices, |
|
cu_seqlens=cu_seqlens, |
|
max_seqlen=max_seqlen, |
|
batch_size=batch_size, |
|
seq_len=seq_len, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
last_hidden_state = outputs[0] |
|
|
|
last_hidden_state = self.head(last_hidden_state) |
|
last_hidden_state = self.drop(last_hidden_state) |
|
logits = self.classifier(last_hidden_state) |
|
|
|
loss = None |
|
if labels is not None: |
|
|
|
            # Flatten a copy for the loss so the returned logits keep their (batch, seq_len, num_labels) shape.

            flat_logits = logits.view(-1, self.num_labels).float()

            flat_labels = labels.view(-1).to(flat_logits.device)

            loss = fixed_cross_entropy(flat_logits, flat_labels, num_items_in_batch, ignore_index)
|
|
|
if not return_dict: |
|
output = (logits,) + outputs[1:] |
|
return ((loss,) + output) if loss is not None else output |
|
|
|
return TokenClassifierOutput( |
|
loss=loss, |
|
logits=logits, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
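

# Usage sketch for `ModernBertForTokenClassification` (the checkpoint name is an assumption;
# `tokenizer` is assumed to be loaded with the matching AutoTokenizer):
#
#   model = ModernBertForTokenClassification.from_pretrained(
#       "answerdotai/ModernBERT-base", num_labels=9
#   )
#   inputs = tokenizer("HuggingFace is based in NYC", return_tensors="pt")
#   out = model(**inputs)
#   # Unlike the BERT head above, logits here are produced after ModernBertPredictionHead + dropout.
#   predictions = out.logits.argmax(dim=-1)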