|
26 | 26 |
|
27 | 27 | SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd'] |
28 | 28 |
|
| 29 | + |
29 | 30 | class MetaData: |
30 | 31 | cu_seqlens_q = None |
31 | 32 | cu_seqlens_k = None |
@@ -261,7 +262,7 @@ def _attn_fwd_inner( |
261 | 262 | # We start from end of seqlen_k so only the first iteration would need |
262 | 263 | # to be checked for padding if it is not a multiple of block_n |
263 | 264 | # TODO: This can be optimized to only be true for the padded block. |
264 | | - if MASK_STEPS: # noqa SIM102 |
| 265 | + if MASK_STEPS: # noqa SIM102 |
265 | 266 | # If this is the last block / iteration, we want to |
266 | 267 | # mask if the sequence length is not a multiple of block size |
267 | 268 | # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if |
@@ -633,8 +634,8 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, |
633 | 634 | # we subtract this from qk which makes it -inf, such that |
634 | 635 | # exp(qk - inf) = 0 for these masked blocks. |
635 | 636 | l_value = tl.full([BLOCK_M], |
636 | | - value=float("inf"), |
637 | | - dtype=tl.float32) |
| 637 | + value=float("inf"), |
| 638 | + dtype=tl.float32) |
638 | 639 | l_ptrs_mask = offs_m < MAX_SEQLENS_Q |
639 | 640 | tl.store(l_ptrs, l_value, mask=l_ptrs_mask) |
640 | 641 | # TODO: Should dropout and return encoded softmax be |
@@ -855,7 +856,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, |
855 | 856 | start_m_idx = start_m * BLOCK_M |
856 | 857 | causal_start_idx = seqlen_q - seqlen_k |
857 | 858 | acc = acc.to(Out.type.element_ty) |
858 | | - if IS_CAUSAL: # noqa: SIM102 |
| 859 | + if IS_CAUSAL: # noqa: SIM102 |
859 | 860 | if (causal_start_idx > start_m_idx |
860 | 861 | and causal_start_idx < end_m_idx): |
861 | 862 | out_mask_boundary = tl.full((BLOCK_DMODEL, ), |
@@ -1358,6 +1359,7 @@ def _attn_bwd( |
1358 | 1359 |
|
1359 | 1360 | SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd'] |
1360 | 1361 |
|
| 1362 | + |
1361 | 1363 | def get_shape_from_layout(q, k, metadata): |
1362 | 1364 | assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout." |
1363 | 1365 | if metadata.layout == 'thd': |
|
0 commit comments