Update modeling_baichuan.py
Browse files
- modeling_baichuan.py +1 -0
modeling_baichuan.py
CHANGED
|
@@ -181,6 +181,7 @@ class BaichuanAttention(torch.nn.Module):
|
|
| 181 |
# )
|
| 182 |
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
|
| 183 |
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask)
|
|
|
|
| 184 |
else:
|
| 185 |
attn_weights = torch.matmul(
|
| 186 |
query_states, key_states.transpose(2, 3)
|
|
|
|
| 181 |
# )
|
| 182 |
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
|
| 183 |
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask)
|
| 184 |
+ attn_output = attn_output.transpose(1, 2)
|
| 185 |
else:
|
| 186 |
attn_weights = torch.matmul(
|
| 187 |
query_states, key_states.transpose(2, 3)
|