ligeti committed · verified
Commit 7fc1e3e · 1 Parent(s): c2af699

Upload ProkBertForMaskedLM

Files changed (2):
  1. config.json +2 -2
  2. models.py +5 -15
config.json CHANGED
@@ -5,9 +5,9 @@
   ],
   "attention_probs_dropout_prob": 0.1,
   "auto_map": {
-    "AutoConfig": "neuralbioinfo/prokbert-mini-long--models.ProkBertConfig",
+    "AutoConfig": "models.ProkBertConfig",
     "AutoModel": "neuralbioinfo/prokbert-mini-long--models.ProkBertModel",
-    "AutoModelForMaskedLM": "neuralbioinfo/prokbert-mini-long--models.ProkBertForMaskedLM",
+    "AutoModelForMaskedLM": "models.ProkBertForMaskedLM",
     "AutoModelForSequenceClassification": "neuralbioinfo/prokbert-mini-long--models.ProkBertForSequenceClassification"
   },
   "classification_dropout_rate": 0.1,
models.py CHANGED
@@ -9,7 +9,7 @@ import torch.nn.functional as F
 from transformers import MegatronBertConfig, MegatronBertModel, MegatronBertForMaskedLM, MegatronBertPreTrainedModel, PreTrainedModel
 from transformers.modeling_outputs import SequenceClassifierOutput
 from transformers.utils.hub import cached_file
-
+from prokbert.training_utils import compute_metrics_eval_prediction
 
 class BertForBinaryClassificationWithPooling(nn.Module):
     """
@@ -130,18 +130,6 @@ class BertForBinaryClassificationWithPooling(nn.Module):
 
 
 
-class OldProkBertConfig(MegatronBertConfig):
-
-    model_type = "prokbert"
-    def __init__(
-        self,
-        kmer: int = 6,
-        shift: int = 1,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        self.kmer=kmer
-        self.shift=shift
 
 class ProkBertConfig(MegatronBertConfig):
     model_type = "prokbert"
@@ -283,8 +271,10 @@ class ProkBertForSequenceClassification(ProkBertPreTrainedModel):
         # Classification head
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
-        loss = self.loss_fct(logits.view(-1, 2), labels.view(-1))
-
+        loss = None
+        if labels is not None:
+            loss = self.loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
         classification_output = SequenceClassifierOutput(
             loss=loss,
             logits=logits,
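
The last hunk also stops hard-coding two classes in the logit reshape and computes the loss only when labels are supplied, so inference-only calls no longer fail on labels.view(-1). A hypothetical usage sketch (assumes the forward signature follows the usual Hugging Face pattern; the random input_ids stand in for real ProkBert tokenization):

    import torch
    from transformers import AutoModelForSequenceClassification

    model = AutoModelForSequenceClassification.from_pretrained(
        "neuralbioinfo/prokbert-mini-long", trust_remote_code=True
    )
    # Toy ids only; real inputs come from the ProkBert tokenizer.
    input_ids = torch.randint(0, model.config.vocab_size, (1, 32))
    with torch.no_grad():
        out = model(input_ids=input_ids)  # no labels kwarg
    assert out.loss is None                    # loss path skipped without labels
    probs = torch.softmax(out.logits, dim=-1)  # shape (1, num_labels)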