Upload ProkBertForMaskedLM
- config.json +2 -2
- models.py +5 -15
config.json
CHANGED
@@ -5,9 +5,9 @@
     ],
     "attention_probs_dropout_prob": 0.1,
     "auto_map": {
-        "AutoConfig": "
+        "AutoConfig": "models.ProkBertConfig",
         "AutoModel": "neuralbioinfo/prokbert-mini-long--models.ProkBertModel",
-        "AutoModelForMaskedLM": "
+        "AutoModelForMaskedLM": "models.ProkBertForMaskedLM",
         "AutoModelForSequenceClassification": "neuralbioinfo/prokbert-mini-long--models.ProkBertForSequenceClassification"
     },
     "classification_dropout_rate": 0.1,
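With auto_map now routing AutoConfig and AutoModelForMaskedLM to the classes in this repo's models.py, the checkpoint loads through the standard transformers entry points. A minimal usage sketch; trust_remote_code=True is required because the classes live in repository code rather than in transformers itself:

from transformers import AutoConfig, AutoModelForMaskedLM

# trust_remote_code=True lets transformers follow auto_map and import
# ProkBertConfig / ProkBertForMaskedLM from models.py in this repository.
config = AutoConfig.from_pretrained(
    "neuralbioinfo/prokbert-mini-long", trust_remote_code=True
)
model = AutoModelForMaskedLM.from_pretrained(
    "neuralbioinfo/prokbert-mini-long", trust_remote_code=True
)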
models.py
CHANGED
@@ -9,7 +9,7 @@ import torch.nn.functional as F
 from transformers import MegatronBertConfig, MegatronBertModel, MegatronBertForMaskedLM, MegatronBertPreTrainedModel, PreTrainedModel
 from transformers.modeling_outputs import SequenceClassifierOutput
 from transformers.utils.hub import cached_file
-
+from prokbert.training_utils import compute_metrics_eval_prediction
 
 class BertForBinaryClassificationWithPooling(nn.Module):
     """
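The newly imported compute_metrics_eval_prediction comes from the prokbert package. Assuming it follows the usual compute_metrics contract (a transformers.EvalPrediction in, a dict of metrics out), it would plug into a Trainer roughly like this sketch, where model and the datasets are placeholders:

from transformers import Trainer, TrainingArguments
from prokbert.training_utils import compute_metrics_eval_prediction

# Hypothetical wiring; assumes the helper takes an EvalPrediction and
# returns a metrics dict, as Trainer's compute_metrics hook expects.
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out"),
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics_eval_prediction,
)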
@@ -130,18 +130,6 @@ class BertForBinaryClassificationWithPooling(nn.Module):
 
 
 
-class OldProkBertConfig(MegatronBertConfig):
-
-    model_type = "prokbert"
-    def __init__(
-        self,
-        kmer: int = 6,
-        shift: int = 1,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        self.kmer=kmer
-        self.shift=shift
 
 class ProkBertConfig(MegatronBertConfig):
     model_type = "prokbert"
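The removed OldProkBertConfig was a dead duplicate of the surviving ProkBertConfig; both declare model_type = "prokbert". A short sketch of obtaining the surviving config, with one labeled assumption: that ProkBertConfig keeps the kmer/shift attributes the deleted class defined.

from transformers import AutoConfig

# Per the auto_map above, this resolves to models.ProkBertConfig.
config = AutoConfig.from_pretrained(
    "neuralbioinfo/prokbert-mini-long", trust_remote_code=True
)
print(config.model_type)  # "prokbert"
# Assumption: the surviving class exposes the k-mer tokenization parameters
# that OldProkBertConfig defined (defaults kmer=6, shift=1).
print(config.kmer, config.shift)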
@@ -283,8 +271,10 @@ class ProkBertForSequenceClassification(ProkBertPreTrainedModel):
         # Classification head
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
-        loss = 
-
+        loss = None
+        if labels is not None:
+            loss = self.loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
         classification_output = SequenceClassifierOutput(
             loss=loss,
             logits=logits,
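The added branch computes the classification loss only when labels are supplied, the standard transformers convention, so pure inference calls get loss=None. A minimal sketch of both call paths, with random token ids standing in for real tokenized DNA:

import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "neuralbioinfo/prokbert-mini-long", trust_remote_code=True
)

# Hypothetical batch: 2 token-id sequences of length 128, plus binary labels.
input_ids = torch.randint(0, model.config.vocab_size, (2, 128))
labels = torch.tensor([0, 1])

out = model(input_ids=input_ids)                 # no labels -> out.loss is None
out = model(input_ids=input_ids, labels=labels)  # labels -> loss_fct over logits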