add autoconfig
Browse files- bert_layers.py +8 -0
bert_layers.py
CHANGED
@@ -25,6 +25,9 @@ from .bert_padding import (index_first_axis,
|
|
25 |
index_put_first_axis, pad_input,
|
26 |
unpad_input, unpad_input_only)
|
27 |
|
|
|
|
|
|
|
28 |
try:
|
29 |
from .flash_attn_triton import flash_attn_qkvpacked_func
|
30 |
except ImportError as e:
|
@@ -683,6 +686,8 @@ class BertOnlyNSPHead(nn.Module):
|
|
683 |
|
684 |
|
685 |
class BertForMaskedLM(BertPreTrainedModel):
|
|
|
|
|
686 |
|
687 |
def __init__(self, config):
|
688 |
super().__init__(config)
|
@@ -812,6 +817,9 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|
812 |
|
813 |
|
814 |
class BertForSequenceClassification(BertPreTrainedModel):
|
|
|
|
|
|
|
815 |
"""Bert Model transformer with a sequence classification/regression head.
|
816 |
|
817 |
This head is just a linear layer on top of the pooled output. Used for,
|
|
|
25 |
index_put_first_axis, pad_input,
|
26 |
unpad_input, unpad_input_only)
|
27 |
|
28 |
+
from .configuration_bert import BertConfig
|
29 |
+
|
30 |
+
|
31 |
try:
|
32 |
from .flash_attn_triton import flash_attn_qkvpacked_func
|
33 |
except ImportError as e:
|
|
|
686 |
|
687 |
|
688 |
class BertForMaskedLM(BertPreTrainedModel):
|
689 |
+
|
690 |
+
config_class = BertConfig
|
691 |
|
692 |
def __init__(self, config):
|
693 |
super().__init__(config)
|
|
|
817 |
|
818 |
|
819 |
class BertForSequenceClassification(BertPreTrainedModel):
|
820 |
+
|
821 |
+
config_class = BertConfig
|
822 |
+
|
823 |
"""Bert Model transformer with a sequence classification/regression head.
|
824 |
|
825 |
This head is just a linear layer on top of the pooled output. Used for,
|