Sifal committed · verified
Commit 9187183 · Parent(s): d064dec

add autoconfig
Files changed (1): bert_layers.py (+8 -0)
bert_layers.py CHANGED

@@ -25,6 +25,9 @@ from .bert_padding import (index_first_axis,
                            index_put_first_axis, pad_input,
                            unpad_input, unpad_input_only)
 
+from .configuration_bert import BertConfig
+
+
 try:
     from .flash_attn_triton import flash_attn_qkvpacked_func
 except ImportError as e:
@@ -683,6 +686,8 @@ class BertOnlyNSPHead(nn.Module):
 
 
 class BertForMaskedLM(BertPreTrainedModel):
+
+    config_class = BertConfig
 
     def __init__(self, config):
         super().__init__(config)
@@ -812,6 +817,9 @@ class BertForMaskedLM(BertPreTrainedModel):
 
 
 class BertForSequenceClassification(BertPreTrainedModel):
+
+    config_class = BertConfig
+
     """Bert Model transformer with a sequence classification/regression head.
 
     This head is just a linear layer on top of the pooled output. Used for,
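Setting config_class on both model heads tells Transformers which configuration class to instantiate when these remote-code models are loaded through the Auto classes. A minimal usage sketch follows; the repo id is hypothetical, and it assumes the repo's config.json carries the matching auto_map entries pointing at configuration_bert.py and bert_layers.py:

# Minimal usage sketch, assuming a Hub repo that ships this bert_layers.py
# alongside configuration_bert.py; "user/custom-bert" is a hypothetical repo id.
from transformers import AutoConfig, AutoModelForMaskedLM

config = AutoConfig.from_pretrained(
    "user/custom-bert",      # hypothetical repo id
    trust_remote_code=True,  # allow loading the repo's custom config/model code
)
model = AutoModelForMaskedLM.from_pretrained(
    "user/custom-bert",
    trust_remote_code=True,
)

# Thanks to `config_class = BertConfig`, from_pretrained validates and builds
# the model against the custom BertConfig rather than a generic PretrainedConfig.
assert type(model).config_class.__name__ == "BertConfig"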