
Commit 71f89c5

yapf

Signed-off-by: Randall Smith <[email protected]>

1 parent 4ef102d
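
The commit message "yapf" indicates a pure autoformatting pass: every hunk below changes only whitespace (spacing before inline comments, continuation-line alignment, and blank lines around top-level definitions). As a hedged illustration, such a pass could be reproduced with yapf's Python API; the call below is an assumption about how the formatter might be invoked, not a command taken from this commit.

```python
# Hypothetical reproduction of a yapf pass over the touched file; yapf
# discovers the project's style configuration (e.g. setup.cfg) on its own.
from yapf.yapflib.yapf_api import FormatFile

_src, _encoding, changed = FormatFile(
    'vllm/attention/ops/triton_flash_attention.py',
    in_place=True)  # rewrites the file; `changed` reports if anything moved
print('reformatted' if changed else 'already clean')
```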

File tree

1 file changed: +6, -4 lines


vllm/attention/ops/triton_flash_attention.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -26,6 +26,7 @@
 
 SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']
 
+
 class MetaData:
     cu_seqlens_q = None
     cu_seqlens_k = None
```
```diff
@@ -261,7 +262,7 @@ def _attn_fwd_inner(
         # We start from end of seqlen_k so only the first iteration would need
         # to be checked for padding if it is not a multiple of block_n
         # TODO: This can be optimized to only be true for the padded block.
-        if MASK_STEPS: # noqa SIM102
+        if MASK_STEPS:  # noqa SIM102
             # If this is the last block / iteration, we want to
             # mask if the sequence length is not a multiple of block size
             # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if
```
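For context on the guarded block: when seqlen_k is not a multiple of BLOCK_N, the first visited K/V tile contains padding columns, which are set to -inf before the softmax so they contribute exp(-inf) = 0. A minimal NumPy sketch of that boundary masking (standalone illustration, not the Triton kernel; the tile sizes are assumptions):

```python
# Minimal NumPy sketch of boundary masking for a padded K block, assuming
# BLOCK_N = 4 and only 2 valid key columns; not the actual Triton kernel.
import numpy as np

BLOCK_N = 4
valid_cols = 2                    # seqlen_k % BLOCK_N for the padded block
qk = np.random.rand(3, BLOCK_N)   # scores for one (BLOCK_M x BLOCK_N) tile

col = np.arange(BLOCK_N)
qk = np.where(col < valid_cols, qk, float("-inf"))  # mask padding columns

p = np.exp(qk - qk.max(axis=1, keepdims=True))      # exp(-inf) == 0 here
p /= p.sum(axis=1, keepdims=True)
print(p[:, valid_cols:])          # masked columns contribute exactly 0
```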
```diff
@@ -633,8 +634,8 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz,
         # we subtract this from qk which makes it -inf, such that
         # exp(qk - inf) = 0 for these masked blocks.
         l_value = tl.full([BLOCK_M],
-                           value=float("inf"),
-                           dtype=tl.float32)
+                          value=float("inf"),
+                          dtype=tl.float32)
         l_ptrs_mask = offs_m < MAX_SEQLENS_Q
         tl.store(l_ptrs, l_value, mask=l_ptrs_mask)
         # TODO: Should dropout and return encoded softmax be
```
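The surrounding code (only realigned here) writes float("inf") into the log-sum-exp buffer L for rows whose blocks are entirely masked, so any later exp(qk - L) evaluates to zero. A quick scalar check of why +inf is the right sentinel (pure Python identity, not kernel code):

```python
# Why storing L = +inf zeroes out fully masked rows: for any finite qk,
# qk - inf == -inf, and math.exp(float('-inf')) returns exactly 0.0.
import math

l_value = float("inf")          # sentinel stored for fully masked rows
for qk in (-3.0, 0.0, 42.0):    # any finite score
    assert math.exp(qk - l_value) == 0.0
print("exp(qk - inf) == 0 for all finite qk")
```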
```diff
@@ -855,7 +856,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz,
     start_m_idx = start_m * BLOCK_M
     causal_start_idx = seqlen_q - seqlen_k
     acc = acc.to(Out.type.element_ty)
-    if IS_CAUSAL: # noqa: SIM102
+    if IS_CAUSAL:  # noqa: SIM102
         if (causal_start_idx > start_m_idx
                 and causal_start_idx < end_m_idx):
             out_mask_boundary = tl.full((BLOCK_DMODEL, ),
```
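For the causal branch touched above: with unequal sequence lengths, causality begins at causal_start_idx = seqlen_q - seqlen_k, and only the output block that straddles that index needs per-row masking; blocks entirely before it are fully masked, blocks after it fully visible. A NumPy sketch of that straddle test, reusing the kernel's index names but with hypothetical tile geometry:

```python
# Sketch: which rows of one output block need zeroing when seqlen_q exceeds
# seqlen_k. Rows m < causal_start_idx see no keys, so the kernel zeroes them;
# only the block straddling the boundary needs per-row treatment.
import numpy as np

seqlen_q, seqlen_k = 12, 6
BLOCK_M, start_m = 4, 1                 # hypothetical tile geometry
causal_start_idx = seqlen_q - seqlen_k  # == 6 here

start_m_idx = start_m * BLOCK_M         # first row of this block: 4
end_m_idx = start_m_idx + BLOCK_M       # one past the last row: 8

if start_m_idx < causal_start_idx < end_m_idx:
    rows = np.arange(start_m_idx, end_m_idx)
    keep = rows >= causal_start_idx     # True where output survives
    print(keep)                         # [False False  True  True]
```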
```diff
@@ -1358,6 +1359,7 @@ def _attn_bwd(
 
 SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']
 
+
 def get_shape_from_layout(q, k, metadata):
     assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout."
     if metadata.layout == 'thd':
```
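get_shape_from_layout (visible only as context above) dispatches on the three SUPPORTED_LAYOUTS. Since the diff shows just its first lines, the bodies below are a plausible reconstruction of such a helper, not vLLM's actual implementation:

```python
# Plausible sketch of layout dispatch, NOT the vLLM implementation:
# 'thd'  - packed (total_tokens, heads, head_dim) with cu_seqlens metadata
# 'bhsd' - (batch, heads, seqlen, head_dim)
# 'bshd' - (batch, seqlen, heads, head_dim)
SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']

def get_shape_from_layout_sketch(q, k, metadata):
    assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout."
    if metadata.layout == 'thd':
        nheads_q, nheads_k = q.shape[1], k.shape[1]
        head_size = q.shape[-1]
        batch = len(metadata.cu_seqlens_q) - 1  # one boundary per sequence
    elif metadata.layout == 'bhsd':
        batch, nheads_q, _, head_size = q.shape
        nheads_k = k.shape[1]
    else:  # 'bshd'
        batch, _, nheads_q, head_size = q.shape
        nheads_k = k.shape[2]
    return batch, nheads_q, nheads_k, head_size
```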
