KingRei committed
Commit 030db8a · verified · 1 Parent(s): 669d633

Update modeling_mpt.py

With PyTorch 2.7 (and the correspondingly recent flash-attn release that provides bert_padding), the unpadding function returns five outputs. Assigning its result to four variables, as in (_, indices_q, cu_seqlens_q, max_seqlen_q) = unpadding_function(...), therefore fails with a value-unpacking error. To resolve this, add a starred variable that captures the extra output: (_, indices_q, cu_seqlens_q, max_seqlen_q, *rest) = unpadding_function(...). This accounts for every returned value and prevents the unpacking error. The same adjustment applies when unpacking the outputs for k and v.
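
For illustration, a minimal self-contained sketch of the failure and the fix; the stub below is hypothetical and only mimics the newer helper's five-output signature, it is not part of the model code:

# Hypothetical stand-in for bert_padding.unpad_input_for_concatenated_sequences,
# mimicking a newer flash-attn release that returns five values.
def unpadding_function_stub(hidden_states, attention_mask_in_length):
    return ('x_unpad', 'indices', 'cu_seqlens', 'max_seqlen', 'used_seqlens')

# The old four-name unpacking now raises:
#   ValueError: too many values to unpack (expected 4)
# (_, indices_q, cu_seqlens_q, max_seqlen_q) = unpadding_function_stub(None, None)

# Starred unpacking absorbs any trailing extras, so the same line works
# whether the helper returns four or five outputs.
(_, indices_q, cu_seqlens_q, max_seqlen_q, *rest) = unpadding_function_stub(None, None)
assert rest == ['used_seqlens']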

Files changed (1)
  1. modeling_mpt.py +3 -3
modeling_mpt.py CHANGED
@@ -140,9 +140,9 @@ def gen_flash_attn_padding_info(bsz: int, S: int, past_key_len: int, device: tor
     key_padding_mask = attention_mask_in_length
     query_padding_mask = attention_mask_in_length
     unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
-    (_, indices_q, cu_seqlens_q, max_seqlen_q) = unpadding_function(torch.empty(bsz, S, 1, device=device), query_padding_mask)
-    (_, indices_k, cu_seqlens_k, max_seqlen_k) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
-    (_, indices_v, _, _) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
+    (_, indices_q, cu_seqlens_q, max_seqlen_q, *rest) = unpadding_function(torch.empty(bsz, S, 1, device=device), query_padding_mask)
+    (_, indices_k, cu_seqlens_k, max_seqlen_k, *rest) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
+    (_, indices_v, _, _, *rest) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
     flash_attn_padding_info['indices_q'] = indices_q
     flash_attn_padding_info['indices_k'] = indices_k
     flash_attn_padding_info['indices_v'] = indices_v