Update modeling_baichuan.py
Browse files- modeling_baichuan.py +1 -1
modeling_baichuan.py
CHANGED
|
@@ -177,7 +177,7 @@ class BaichuanAttention(torch.nn.Module):
|
|
| 177 |
key_states = key_states.transpose(1, 2)
|
| 178 |
value_states = value_states.transpose(1, 2)
|
| 179 |
attn_output = xops.memory_efficient_attention(
|
| 180 |
-
query_states, key_states, value_states, attn_bias=attention_mask
|
| 181 |
)
|
| 182 |
else:
|
| 183 |
attn_weights = torch.matmul(
|
|
|
|
| 177 |
key_states = key_states.transpose(1, 2)
|
| 178 |
value_states = value_states.transpose(1, 2)
|
| 179 |
attn_output = xops.memory_efficient_attention(
|
| 180 |
+
query_states, key_states, value_states, attn_bias=attention_mask
|
| 181 |
)
|
| 182 |
else:
|
| 183 |
attn_weights = torch.matmul(
|