File size: 1,810 Bytes
69690cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from typing import Optional

from transformers import Qwen2Config
from transformers.configuration_utils import PretrainedConfig


class StepAudio2EncoderConfig(PretrainedConfig):
    """Hyper-parameter container for the Step-Audio-2 audio encoder.

    Every named constructor argument is stored unchanged as an instance
    attribute of the same name. Any additional keyword arguments are
    forwarded to ``PretrainedConfig.__init__``.

    NOTE(review): the argument names suggest encoder hyper-parameters (mel
    bins, audio context length, hidden width, attention heads, layer count,
    LLM embedding width, and adapter convolution kernel/stride), but this
    class only stores the values -- confirm their exact meaning against the
    encoder implementation.

    Args:
        n_mels: Stored as ``self.n_mels``. Defaults to 128.
        n_audio_ctx: Stored as ``self.n_audio_ctx``. Defaults to 1500.
        n_audio_state: Stored as ``self.n_audio_state``. Defaults to 512.
        n_audio_head: Stored as ``self.n_audio_head``. Defaults to 8.
        n_audio_layer: Stored as ``self.n_audio_layer``. Defaults to 6.
        llm_dim: Stored as ``self.llm_dim``. Defaults to 4096.
        kernel_size: Stored as ``self.kernel_size``. Defaults to 3.
        adapter_stride: Stored as ``self.adapter_stride``. Defaults to 2.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "step_audio_2_encoder"

    def __init__(
        self,
        n_mels=128,
        n_audio_ctx=1500,
        n_audio_state=512,
        n_audio_head=8,
        n_audio_layer=6,
        llm_dim=4096,
        kernel_size=3,
        adapter_stride=2,
        **kwargs,
    ):
        # Insertion order follows the signature so the attributes are created
        # in the same order as a sequence of plain assignments would create
        # them.
        encoder_fields = {
            "n_mels": n_mels,
            "n_audio_ctx": n_audio_ctx,
            "n_audio_state": n_audio_state,
            "n_audio_head": n_audio_head,
            "n_audio_layer": n_audio_layer,
            "llm_dim": llm_dim,
            "kernel_size": kernel_size,
            "adapter_stride": adapter_stride,
        }
        for field_name, field_value in encoder_fields.items():
            setattr(self, field_name, field_value)

        # The base-class initialiser runs last, after the encoder fields exist.
        super().__init__(**kwargs)

class StepAudio2Config(PretrainedConfig):
    """Top-level Step-Audio-2 configuration.

    Holds two nested configurations:

    * ``text_config`` -- a ``Qwen2Config`` built from the keyword arguments
      given to this constructor.
    * ``audio_encoder_config`` -- a ``StepAudio2EncoderConfig``.

    The same keyword arguments (including the sliding-window settings
    injected below) are forwarded both to ``PretrainedConfig.__init__`` and
    to ``Qwen2Config``.

    Args:
        audio_encoder_config: ``None`` to use a default
            ``StepAudio2EncoderConfig``, a ``dict`` of keyword arguments for
            ``StepAudio2EncoderConfig``, or an already constructed
            ``StepAudio2EncoderConfig`` instance (stored as-is).
        use_sliding_window: Forwarded as ``use_sliding_window``.
        sliding_window: Forwarded as ``sliding_window``.
        max_window_layers: Forwarded as ``max_window_layers``. When ``None``
            it falls back to the ``num_hidden_layers`` keyword argument,
            which may itself be absent, in which case ``None`` is forwarded.
        **kwargs: Remaining options, forwarded to both
            ``PretrainedConfig.__init__`` and ``Qwen2Config``.

    Raises:
        TypeError: If ``audio_encoder_config`` is not ``None``, a ``dict``
            or a ``StepAudio2EncoderConfig``.
    """

    model_type = "step_audio_2"
    architectures = ["StepAudio2ForCausalLM"]

    def __init__(
        self,
        audio_encoder_config=None,
        use_sliding_window: bool = False,
        sliding_window: Optional[int] = 2048,
        max_window_layers: Optional[int] = None,
        **kwargs,
    ):
        # Fold the explicit sliding-window parameters into kwargs so that the
        # base class and the text config below both receive them.
        kwargs.setdefault("use_sliding_window", use_sliding_window)
        kwargs.setdefault("sliding_window", sliding_window)
        if max_window_layers is None:
            # Fall back to the layer count when the caller supplied one.
            max_window_layers = kwargs.get("num_hidden_layers", None)
        kwargs.setdefault("max_window_layers", max_window_layers)
        super().__init__(**kwargs)

        self.text_config = Qwen2Config(**kwargs)

        # Normalise the encoder config to a StepAudio2EncoderConfig instance.
        # The attribute is always assigned on success, so a ready-made
        # instance no longer leaves ``audio_encoder_config`` undefined.
        if audio_encoder_config is None:
            audio_encoder_config = StepAudio2EncoderConfig()
        elif isinstance(audio_encoder_config, dict):
            audio_encoder_config = StepAudio2EncoderConfig(**audio_encoder_config)
        elif not isinstance(audio_encoder_config, StepAudio2EncoderConfig):
            raise TypeError(
                "audio_encoder_config must be None, a dict, or a "
                "StepAudio2EncoderConfig, got "
                f"{type(audio_encoder_config).__name__}"
            )
        self.audio_encoder_config = audio_encoder_config