Make compatible with recent versions of triton
Browse files
Many people report being unable to run this model with the latest versions of triton, e.g. [here](https://github.com/Zhihan1996/DNABERT_2/issues/19#issuecomment-1641944919). This patch fixes the issue.
- flash_attn_triton.py +2 -2
flash_attn_triton.py
CHANGED
@@ -188,7 +188,7 @@ def _fwd_kernel(
|
|
188 |
(offs_d[None, :] < headdim),
|
189 |
other=0.0)
|
190 |
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
191 |
-
qk += tl.dot(q, k, trans_b=True)
|
192 |
# Trying to combine the two masks seem to make the result wrong
|
193 |
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
194 |
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0,
|
@@ -431,7 +431,7 @@ def _bwd_kernel_one_col_block(
|
|
431 |
(offs_d[None, :] < headdim),
|
432 |
other=0.0)
|
433 |
# recompute p = softmax(qk, dim=-1).T
|
434 |
-
qk = tl.dot(q, k, trans_b=True)
|
435 |
# Trying to combine the two masks seem to make the result wrong
|
436 |
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
437 |
qk = tl.where(offs_n[None, :] < seqlen_k, qk, float('-inf'))
|
|
|
188 |
(offs_d[None, :] < headdim),
|
189 |
other=0.0)
|
190 |
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
191 |
+
qk += tl.dot(q, tl.trans(k))
|
192 |
# Trying to combine the two masks seem to make the result wrong
|
193 |
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
194 |
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0,
|
|
|
431 |
(offs_d[None, :] < headdim),
|
432 |
other=0.0)
|
433 |
# recompute p = softmax(qk, dim=-1).T
|
434 |
+
qk = tl.dot(q, tl.trans(k))
|
435 |
# Trying to combine the two masks seem to make the result wrong
|
436 |
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
437 |
qk = tl.where(offs_n[None, :] < seqlen_k, qk, float('-inf'))
|