fix bug
Browse files- modeling_yi.py +1 -1
modeling_yi.py
CHANGED
|
@@ -539,7 +539,7 @@ class YiModel(YiPreTrainedModel):
|
|
| 539 |
def _prepare_decoder_attention_mask(
|
| 540 |
self, attention_mask, input_ids, inputs_embeds, past_key_values_length
|
| 541 |
):
|
| 542 |
-
input_shape = input_ids.shape if input_ids else inputs_embeds.shape[:-1]
|
| 543 |
# create causal mask
|
| 544 |
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 545 |
combined_attention_mask = None
|
|
|
|
| 539 |
def _prepare_decoder_attention_mask(
|
| 540 |
self, attention_mask, input_ids, inputs_embeds, past_key_values_length
|
| 541 |
):
|
| 542 |
+
input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape[:-1]
|
| 543 |
# create causal mask
|
| 544 |
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 545 |
combined_attention_mask = None
|