farzadab committed (verified)
Commit 0cde4b8 · Parent(s): 9c92efc

Update ultravox_config.py

Files changed (1): ultravox_config.py (+22 −8)

ultravox_config.py CHANGED
@@ -19,6 +19,8 @@ class LoraConfigSimplified:
     target_modules: Optional[List[str]] = dataclasses.field(
         default_factory=lambda: ["k_proj", "q_proj", "linear_k", "linear_q"]
     )
+    # A list of module name regex patterns to unfreeze. Only used if r == 0.
+    unfreeze_layers: Optional[List[str]] = None
 
 
 class LossFunction(str, Enum):
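
The new `unfreeze_layers` field lets full fine-tuning (LoRA rank `r == 0`) selectively unfreeze modules by name. A minimal sketch of how such patterns could be applied — the `apply_unfreeze` helper is an illustrative assumption, not code from this commit:

```python
import re
from typing import List, Optional

import torch.nn as nn


def apply_unfreeze(model: nn.Module, unfreeze_layers: Optional[List[str]]) -> None:
    # Hypothetical helper: freeze every parameter, then unfreeze those whose
    # names match any of the configured regex patterns.
    for name, param in model.named_parameters():
        param.requires_grad = bool(
            unfreeze_layers and any(re.search(p, name) for p in unfreeze_layers)
        )
```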
@@ -28,8 +30,10 @@ class LossFunction(str, Enum):
 
 @dataclasses.dataclass
 class LossConfig:
-    loss_function: LossFunction = LossFunction.KL_Divergence
+    loss_function: LossFunction = LossFunction.CrossEntropy
     kl_temperature: float = 2.0
+    # Number of tokens to ignore from the beginning of the sequence. Only used in LSM
+    initial_tokens_to_ignore: int = 0
 
     @property
     def requires_alt_fields(self):
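
The default `loss_function` moves from KL divergence to cross-entropy; `kl_temperature` still controls the distillation loss when KL is selected. A sketch of the standard temperature-scaled formulation (this commit does not show the repo's actual loss code):

```python
import torch
import torch.nn.functional as F


def kl_distillation_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    kl_temperature: float = 2.0,
) -> torch.Tensor:
    # KL(teacher || student) on temperature-softened distributions.
    return F.kl_div(
        F.log_softmax(student_logits / kl_temperature, dim=-1),
        F.softmax(teacher_logits / kl_temperature, dim=-1),
        reduction="batchmean",
    )
```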
@@ -45,7 +49,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
     documentation from [`PretrainedConfig`] for more information.
 
     Args:
-        audio_config (`Wav2Vec2Config`, *optional*):
+        audio_config (`WhisperConfig`, *optional*):
             Custom audio config or dict
         text_config (`Union[AutoConfig, dict]`, *optional*):
             The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
@@ -63,15 +67,17 @@ class UltravoxConfig(transformers.PretrainedConfig):
             The LoRA configuration for finetuning the text model.
         audio_model_lora_config (`LoraConfigSimplified`, *optional*):
             The LoRA configuration for finetuning the audio model.
+        audio_latency_block_size (`int`, *optional*, defaults to `None`):
+            The latency block size for simulating audio streaming.
 
 
     Example:
 
     ```python
-    >>> from transformers import UltravoxForConditionalGeneration, Wav2Vec2Config, UltravoxConfig, LlamaConfig
+    >>> from transformers import UltravoxModel, WhisperConfig, UltravoxConfig, LlamaConfig
 
     >>> # Initializing an audio encoder config
-    >>> audio_config = Wav2Vec2Config()
+    >>> audio_config = WhisperConfig()
 
     >>> # Initializing a Llama config
     >>> text_config = LlamaConfig()
@@ -80,13 +86,13 @@ class UltravoxConfig(transformers.PretrainedConfig):
     >>> configuration = UltravoxConfig(audio_config, text_config)
 
     >>> # Initializing a completely untrained model from the configuration
-    >>> model = UltravoxForConditionalGeneration(configuration)
+    >>> model = UltravoxModel(configuration)
 
     >>> # Accessing the model configuration
     >>> configuration = model.config
 
     >>> # Initialize a model from pretrained checkpoints and random projector weights
-    >>> config = UltravoxConfig(audio_model_id="facebook/wav2vec2-base-960h", text_model_id="meta-llama/Llama-2-7b-chat-hf")
+    >>> config = UltravoxConfig(audio_model_id="openai/whisper-tiny", text_model_id="meta-llama/Llama-2-7b-chat-hf")
     ```"""
 
     model_type = "ultravox"
@@ -103,8 +109,10 @@ class UltravoxConfig(transformers.PretrainedConfig):
         stack_factor: int = 8,
         norm_init: float = 0.4,
         projector_act: str = "swiglu",
+        projector_ln_mid: bool = False,  # defaults to False for compatibility with v0.4.1 and below
         text_model_lora_config: Optional[LoraConfigSimplified] = None,
         audio_model_lora_config: Optional[LoraConfigSimplified] = None,
+        audio_latency_block_size: Optional[int] = None,
         **kwargs,
     ):
         self.ignore_index = ignore_index
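
Taken together, the two new constructor arguments can be exercised as below. This assumes the file is importable as `ultravox_config`; the block size of 100 is an arbitrary illustrative value:

```python
from ultravox_config import UltravoxConfig

config = UltravoxConfig(
    audio_model_id="openai/whisper-tiny",
    text_model_id="meta-llama/Llama-2-7b-chat-hf",
    projector_ln_mid=True,  # new LayerNorm placement; False matches v0.4.1 and below
    audio_latency_block_size=100,  # simulate streaming audio in fixed-size blocks
)
```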
@@ -116,7 +124,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
         self.stack_factor = stack_factor
         self.norm_init = norm_init
         self.projector_act = projector_act
-
+        self.projector_ln_mid = projector_ln_mid
         if text_model_id is not None:
             self.text_config: transformers.LlamaConfig = (
                 transformers.AutoConfig.from_pretrained(text_model_id)
@@ -134,7 +142,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
         else:
             audio_config = audio_config or {}
             self.audio_config = transformers.CONFIG_MAPPING[
-                audio_config.get("model_type", "wav2vec2")
+                audio_config.get("model_type", "whisper")
             ](**audio_config)
 
         self.text_model_lora_config = (
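
`transformers.CONFIG_MAPPING` maps a `model_type` string to its config class, so with the new default a dict without `model_type` now yields a `WhisperConfig` instead of a `Wav2Vec2Config`:

```python
import transformers

audio_config: dict = {}  # no "model_type" key supplied
cfg = transformers.CONFIG_MAPPING[audio_config.get("model_type", "whisper")](**audio_config)
assert isinstance(cfg, transformers.WhisperConfig)
```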
@@ -147,6 +155,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
             if isinstance(audio_model_lora_config, dict)
             else dataclasses.asdict(audio_model_lora_config or LoraConfigSimplified())
         )
+        self.audio_latency_block_size = audio_latency_block_size
 
         self.vocab_size = self.text_config.vocab_size
 
@@ -160,7 +169,12 @@ class UltravoxConfig(transformers.PretrainedConfig):
         # remove text_config and audio_config if text_model_id and audio_model_id are present
         if self.text_model_id is not None:
             diff_dict.pop("text_config", None)
+        elif "text_config" in diff_dict:
+            diff_dict["text_config"].pop("_attn_implementation_autoset", None)
+
         if self.audio_model_id is not None:
             diff_dict.pop("audio_config", None)
+        elif "audio_config" in diff_dict:
+            diff_dict["audio_config"].pop("_attn_implementation_autoset", None)
 
         return diff_dict
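
Recent `transformers` releases set a private `_attn_implementation_autoset` flag on configs when an attention implementation is auto-selected; stripping it from inlined sub-configs keeps serialized configs stable across environments. Assuming the method above overrides `PretrainedConfig.to_diff_dict` and that `UltravoxConfig()` builds default sub-configs without model IDs, the effect would be:

```python
config = UltravoxConfig()  # default whisper/llama sub-configs, no model IDs
diff = config.to_diff_dict()
# the bookkeeping flag never leaks into the serialized sub-configs
assert "_attn_implementation_autoset" not in diff.get("audio_config", {})
assert "_attn_implementation_autoset" not in diff.get("text_config", {})
```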
 