from typing import Optional

from transformers import Qwen2Config
from transformers.configuration_utils import PretrainedConfig


class StepAudio2EncoderConfig(PretrainedConfig):
    """Configuration holder for the Step-Audio-2 audio encoder.

    The constructor only stores the hyper-parameters below on the instance and
    forwards any remaining keyword arguments to ``PretrainedConfig``.

    NOTE(review): the per-field meanings are inferred from the names only and
    should be confirmed against the encoder implementation.

    Args:
        n_mels: presumably the number of mel-spectrogram bins.
        n_audio_ctx: presumably the maximum audio context length.
        n_audio_state: presumably the encoder hidden width.
        n_audio_head: presumably the number of attention heads.
        n_audio_layer: presumably the number of encoder layers.
        llm_dim: presumably the hidden size of the language model that the
            encoder output is projected to.
        kernel_size: presumably the adapter convolution kernel size.
        adapter_stride: presumably the adapter convolution stride.
        **kwargs: forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "step_audio_2_encoder"

    def __init__(
        self,
        n_mels=128,
        n_audio_ctx=1500,
        n_audio_state=512,
        n_audio_head=8,
        n_audio_layer=6,
        llm_dim=4096,
        kernel_size=3,
        adapter_stride=2,
        **kwargs,
    ):
        self.n_mels = n_mels
        self.n_audio_ctx = n_audio_ctx
        self.n_audio_state = n_audio_state
        self.n_audio_head = n_audio_head
        self.n_audio_layer = n_audio_layer
        self.llm_dim = llm_dim
        self.kernel_size = kernel_size
        self.adapter_stride = adapter_stride
        super().__init__(**kwargs)


class StepAudio2Config(PretrainedConfig):
    """Top-level Step-Audio-2 configuration.

    Bundles a ``Qwen2Config`` (stored as ``text_config``) with a
    ``StepAudio2EncoderConfig`` (stored as ``audio_encoder_config``).

    Args:
        audio_encoder_config: ``None`` (use encoder defaults), a ``dict`` of
            ``StepAudio2EncoderConfig`` keyword arguments, or an existing
            ``StepAudio2EncoderConfig`` instance.
        use_sliding_window: default injected into ``kwargs`` under the same
            key when the caller did not supply it there.
        sliding_window: default injected into ``kwargs`` under the same key
            when the caller did not supply it there.
        max_window_layers: when ``None``, falls back to
            ``kwargs["num_hidden_layers"]`` if present (otherwise stays
            ``None``) before being injected into ``kwargs``.
        **kwargs: forwarded both to ``PretrainedConfig.__init__`` and to
            ``Qwen2Config``.

    Raises:
        TypeError: if ``audio_encoder_config`` is not ``None``, a ``dict`` or
            a ``StepAudio2EncoderConfig``.
    """

    model_type = "step_audio_2"
    architectures = ["StepAudio2ForCausalLM"]

    def __init__(
        self,
        audio_encoder_config=None,
        use_sliding_window: bool = False,
        sliding_window: Optional[int] = 2048,
        max_window_layers: Optional[int] = None,
        **kwargs,
    ):
        # Fold the explicit sliding-window arguments into kwargs so that the
        # base config and the nested text config below see the same values.
        kwargs.setdefault("use_sliding_window", use_sliding_window)
        kwargs.setdefault("sliding_window", sliding_window)
        if max_window_layers is None:
            max_window_layers = kwargs.get("num_hidden_layers", None)
        kwargs.setdefault("max_window_layers", max_window_layers)

        super().__init__(**kwargs)

        # The very same kwargs also build the nested text-model config.
        self.text_config = Qwen2Config(**kwargs)

        if audio_encoder_config is None:
            self.audio_encoder_config = StepAudio2EncoderConfig()
        elif isinstance(audio_encoder_config, dict):
            self.audio_encoder_config = StepAudio2EncoderConfig(**audio_encoder_config)
        elif isinstance(audio_encoder_config, StepAudio2EncoderConfig):
            # An already-built encoder config used to be dropped silently,
            # leaving the attribute unset; keep the caller's instance instead.
            self.audio_encoder_config = audio_encoder_config
        else:
            # Fail at construction time rather than with a confusing
            # AttributeError on first access of ``audio_encoder_config``.
            raise TypeError(
                "audio_encoder_config must be None, a dict, or a "
                f"StepAudio2EncoderConfig, got {type(audio_encoder_config).__name__}"
            )