petronny committed on
Commit 6da8a9a · verified · 1 Parent(s): 9502fdb

updated config.json

config.json CHANGED
@@ -7,21 +7,23 @@
     "AutoModelForCausalLM": "modeling_step_audio_2.StepAudio2ForCausalLM"
   },
   "model_type": "step_audio_2",
-  "hidden_size": 3584,
-  "intermediate_size": 18944,
-  "num_attention_heads": 28,
-  "num_attention_groups": 4,
-  "num_key_value_heads": 4,
-  "num_hidden_layers": 28,
-  "max_seq_len": 16384,
-  "vocab_size": 158720,
-  "rms_norm_eps": 1e-06,
-  "eos_token_id": 151643,
-  "pad_token_id": 151643,
-  "rope_theta": 1000000.0,
-  "max_position_embeddings": 16384,
-  "rope_scaling": null,
-  "torch_dtype": "bfloat16",
+  "text_config": {
+    "hidden_size": 3584,
+    "intermediate_size": 18944,
+    "num_attention_heads": 28,
+    "num_attention_groups": 4,
+    "num_key_value_heads": 4,
+    "num_hidden_layers": 28,
+    "max_seq_len": 16384,
+    "vocab_size": 158720,
+    "rms_norm_eps": 1e-06,
+    "eos_token_id": 151643,
+    "pad_token_id": 151643,
+    "rope_theta": 1000000.0,
+    "max_position_embeddings": 16384,
+    "rope_scaling": null,
+    "torch_dtype": "bfloat16"
+  },
   "audio_encoder_config": {
     "n_mels": 128,
     "n_audio_ctx": 1500,
configuration_step_audio_2.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
 
 from transformers import Qwen2Config
 from transformers.configuration_utils import PretrainedConfig
@@ -29,13 +29,80 @@ class StepAudio2EncoderConfig(PretrainedConfig):
         self.adapter_stride = adapter_stride
         super().__init__(**kwargs)
 
+class StepAudio2TextConfig(PretrainedConfig):
+    model_type = "step_audio_2_text"
+
+    def __init__(
+        self,
+        vocab_size=64012,
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_hidden_layers=48,
+        num_attention_heads=32,
+        num_attention_groups=4,
+        num_key_value_heads=4,
+        hidden_act="silu",
+        max_position_embeddings=8192,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        rope_theta=1000000.0,
+        rope_scaling=None,
+        eos_token_id=None,
+        **kwargs
+    ):
+
+        if eos_token_id is not None:
+            if isinstance(eos_token_id, list):
+                eos_token_id = list(set([151643, 151645, 151665] + eos_token_id))
+            else:
+                eos_token_id = [151643, 151645, 151665, eos_token_id]
+        else:
+            eos_token_id = [151643, 151645, 151665]
+
+        super().__init__(
+            eos_token_id=eos_token_id,
+            **kwargs)
+
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_attention_groups = num_attention_groups
+        self.num_key_value_heads = num_key_value_heads
+        assert self.num_attention_groups == self.num_key_value_heads, "num_attention_groups must be equal to num_key_value_heads"
+        self.hidden_act = hidden_act
+        self.max_position_embeddings = max_position_embeddings
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+
+        self.text_config = Qwen2Config(
+            vocab_size=vocab_size,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_hidden_layers=num_hidden_layers,
+            num_attention_heads=num_attention_heads,
+            num_key_value_heads=num_key_value_heads,
+            hidden_act=hidden_act,
+            max_position_embeddings=max_position_embeddings,
+            initializer_range=initializer_range,
+            rms_norm_eps=rms_norm_eps,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            architectures=["Qwen2ForCausalLM"],
+            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
+        )
+
 class StepAudio2Config(PretrainedConfig):
     model_type = "step_audio_2"
     architectures = ["StepAudio2ForCausalLM"]
 
     def __init__(
         self,
-        audio_encoder_config=None,
+        audio_encoder_config :Optional[Union[dict, StepAudio2EncoderConfig]] = None,
+        text_config: Optional[Union[dict, StepAudio2TextConfig]] = None,
         use_sliding_window: bool = False,
         sliding_window: Optional[int] = 2048,
         max_window_layers: Optional[int] = None,
@@ -48,7 +115,12 @@ class StepAudio2Config(PretrainedConfig):
         kwargs.setdefault("max_window_layers", max_window_layers)
         super().__init__(**kwargs)
 
-        self.text_config = Qwen2Config(**kwargs)
+        if text_config is None:
+            text_config = StepAudio2TextConfig().text_config
+        elif isinstance(text_config, dict):
+            text_config = StepAudio2TextConfig(**text_config).text_config
+
+        self.text_config = text_config
 
         if audio_encoder_config is None:
             self.audio_encoder_config = StepAudio2EncoderConfig()
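
Two behaviors of the new code are worth noting: StepAudio2TextConfig always folds the fixed ids 151643, 151645, and 151665 into eos_token_id, and StepAudio2Config now accepts either a dict or a StepAudio2TextConfig for text_config, storing the wrapped Qwen2Config on self.text_config. A small usage sketch, assuming the module above is importable from the repository checkout and transformers is installed; the argument values are illustrative:

from configuration_step_audio_2 import StepAudio2Config, StepAudio2TextConfig

# A user-supplied eos_token_id is merged with the three fixed ids.
text_cfg = StepAudio2TextConfig(eos_token_id=[151643])
print(sorted(text_cfg.eos_token_id))        # [151643, 151645, 151665]
print(type(text_cfg.text_config).__name__)  # Qwen2Config

# text_config may be passed as a plain dict (e.g. parsed from config.json);
# StepAudio2Config stores the resulting Qwen2Config as .text_config.
cfg = StepAudio2Config(text_config={"hidden_size": 3584, "vocab_size": 158720})
print(cfg.text_config.hidden_size)          # 3584
print(cfg.text_config.vocab_size)           # 158720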
modeling_step_audio_2.py CHANGED
@@ -328,8 +328,8 @@ class StepAudio2ForCausalLM(PreTrainedModel, GenerationMixin):
         self.encoder = self.encoder.bfloat16()
         self.adapter = self.adapter.bfloat16()
         self.lm_head = torch.nn.Linear(
-            config.hidden_size,
-            config.vocab_size,
+            config.text_config.hidden_size,
+            config.text_config.vocab_size,
             bias=False,
             dtype=dtype
         )
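
With the nesting in place, the language-model head is sized from config.text_config rather than from attributes on the top-level config. A standalone sketch of the resulting projection shape, using the hidden_size and vocab_size values from the config.json diff above; this is illustrative only, not the model's actual loading path:

import torch

hidden_size, vocab_size = 3584, 158720  # values from config.json's text_config
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False, dtype=torch.bfloat16)

# The head projects final hidden states to vocabulary logits.
hidden_states = torch.randn(1, 8, hidden_size, dtype=torch.bfloat16)
print(lm_head(hidden_states).shape)  # torch.Size([1, 8, 158720])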