Upload ProkBertForMaskedLM
Browse files
models.py
CHANGED
@@ -9,7 +9,7 @@ import torch.nn.functional as F
|
|
9 |
from transformers import MegatronBertConfig, MegatronBertModel, MegatronBertForMaskedLM, MegatronBertPreTrainedModel, PreTrainedModel
|
10 |
from transformers.modeling_outputs import SequenceClassifierOutput
|
11 |
from transformers.utils.hub import cached_file
|
12 |
-
from prokbert.training_utils import compute_metrics_eval_prediction
|
13 |
|
14 |
class BertForBinaryClassificationWithPooling(nn.Module):
|
15 |
"""
|
@@ -274,7 +274,7 @@ class ProkBertForSequenceClassification(ProkBertPreTrainedModel):
|
|
274 |
loss = None
|
275 |
if labels is not None:
|
276 |
loss = self.loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
277 |
-
|
278 |
classification_output = SequenceClassifierOutput(
|
279 |
loss=loss,
|
280 |
logits=logits,
|
|
|
9 |
from transformers import MegatronBertConfig, MegatronBertModel, MegatronBertForMaskedLM, MegatronBertPreTrainedModel, PreTrainedModel
|
10 |
from transformers.modeling_outputs import SequenceClassifierOutput
|
11 |
from transformers.utils.hub import cached_file
|
12 |
+
#from prokbert.training_utils import compute_metrics_eval_prediction
|
13 |
|
14 |
class BertForBinaryClassificationWithPooling(nn.Module):
|
15 |
"""
|
|
|
274 |
loss = None
|
275 |
if labels is not None:
|
276 |
loss = self.loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
277 |
+
|
278 |
classification_output = SequenceClassifierOutput(
|
279 |
loss=loss,
|
280 |
logits=logits,
|