Update modeling_mpt.py
In PyTorch 2.7, the unpadding_function() returns five outputs. Assigning its result to four variables, as in (_, indices_q, cu_seqlens_q, max_seqlen_q) = unpadding_function(...), therefore fails with an unpacking error. To resolve this, add a starred variable that captures the extra output: (_, indices_q, cu_seqlens_q, max_seqlen_q, *rest) = unpadding_function(...). This accounts for all returned values and prevents the unpacking error. The same adjustment applies when unpacking the outputs for k and v.
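As a minimal standalone sketch of the pattern (the function below is a hypothetical stand-in for bert_padding.unpad_input_for_concatenated_sequences; only the number of returned values matters here), star-unpacking absorbs any trailing outputs, so the same call site works whether four or five values come back:

# Hypothetical stand-in for the real unpadding function; only the
# number of returned values matters for this illustration.
def fake_unpadding_function(num_outputs):
    return tuple(range(num_outputs))

# Strict four-way unpacking fails as soon as a fifth output appears:
try:
    (_, indices_q, cu_seqlens_q, max_seqlen_q) = fake_unpadding_function(5)
except ValueError as err:
    print(err)  # too many values to unpack (expected 4)

# Star-unpacking captures the extra trailing output(s), so the same
# assignment works for both the four-output and five-output variants:
for n in (4, 5):
    (_, indices_q, cu_seqlens_q, max_seqlen_q, *rest) = fake_unpadding_function(n)
    print(indices_q, cu_seqlens_q, max_seqlen_q, rest)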
modeling_mpt.py CHANGED (+3 -3)
@@ -140,9 +140,9 @@ def gen_flash_attn_padding_info(bsz: int, S: int, past_key_len: int, device: tor
         key_padding_mask = attention_mask_in_length
         query_padding_mask = attention_mask_in_length
         unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
-    (_, indices_q, cu_seqlens_q, max_seqlen_q) = unpadding_function(torch.empty(bsz, S, 1, device=device), query_padding_mask)
-    (_, indices_k, cu_seqlens_k, max_seqlen_k) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
-    (_, indices_v, _, _) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
+    (_, indices_q, cu_seqlens_q, max_seqlen_q, *rest) = unpadding_function(torch.empty(bsz, S, 1, device=device), query_padding_mask)
+    (_, indices_k, cu_seqlens_k, max_seqlen_k, *rest) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
+    (_, indices_v, _, _, *rest) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
     flash_attn_padding_info['indices_q'] = indices_q
     flash_attn_padding_info['indices_k'] = indices_k
     flash_attn_padding_info['indices_v'] = indices_v
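If the same compatibility concern shows up at more than one call site, another option (a hedged sketch, not what this commit does) is to normalize the unpadding output once and always hand back exactly the four values the rest of gen_flash_attn_padding_info uses:

# Hypothetical helper, not part of this commit: tolerate either a
# four-value or a five-value return from the unpadding function and
# always pass exactly four values back to the caller.
def unpad_first_four(unpadding_function, hidden_states, padding_mask):
    output, indices, cu_seqlens, max_seqlen, *_ = unpadding_function(hidden_states, padding_mask)
    return output, indices, cu_seqlens, max_seqlen

The inline *rest pattern used in the commit achieves the same result without touching any downstream code, which keeps the diff to three lines.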