
Commit 8525785

Update gluon_attention_persistent_forward
The implementation has changed in upstream Triton; update it to match.
1 parent fe646da commit 8525785
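
Background for the diff below: the hand-written f32x2 inline-PTX helpers (_add_f32x2, _mul_f32x2, _fma_f32x2, _reduce_fadd2, _pairwise_fma_f32x2) are replaced by the upstream float2 module, and the use_fadd2_reduce / use_ffma2_scale_rowmax config switches go away. The intended arithmetic is unchanged; as a rough, illustration-only NumPy sketch (not the Gluon float2 API; tile sizes and names here are made up) of the scale-and-row-max step that float2.fma now performs in _softmax_inner_loop:

import numpy as np

def scale_and_subtract_rowmax(qk, qk_scale, m_ij):
    # Per-element reference for the packed step:
    #   rowmax = float2.pack(-m_ij[:, None].broadcast_to(qk.shape), axis=1)
    #   qk = float2.fma(float2.pack(qk, axis=1), float2.full_like(qk, qk_scale), rowmax)
    # i.e. scale the QK tile and subtract the running per-row max.
    return qk * qk_scale - m_ij[:, None]

qk = np.random.rand(4, 8).astype(np.float32)  # stand-in SPLIT_M x BLOCK_N tile
m_ij = (qk * 0.125).max(axis=1)               # stand-in running row max
out = scale_and_subtract_rowmax(qk, 0.125, m_ij)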

File tree

1 file changed: +39 -172 lines changed


tritonbench/kernels/gluon_attention_persistent_forward.py

Lines changed: 39 additions & 172 deletions
@@ -8,6 +8,7 @@
 from triton.experimental.gluon import language as gl
 from triton.experimental.gluon.language.nvidia.blackwell import (
     allocate_tensor_memory,
+    float2,
     get_tmem_32x32b_reg_layout,
     mbarrier,
     tcgen05_commit,
@@ -243,9 +244,7 @@ class AttentionConfig:
     alpha_2d_layout: gl.constexpr

     num_kv_buffers: gl.constexpr
-    use_fadd2_reduce: gl.constexpr
     use_exp2_turnstile: gl.constexpr
-    use_ffma2_scale_rowmax: gl.constexpr

     def __init__(
         self,
@@ -290,13 +289,13 @@ def __init__(
         qk_instr_shape = get_mma_instr_shape(self.qk_shape, gl.float32)
         o_instr_shape = get_mma_instr_shape(self.o_shape, gl.float32)
         self.qk_tmem_layout = gl.constexpr(
-            TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), unpacked=True)
+            TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), col_stride=1)
         )
         self.o_tmem_layout = gl.constexpr(
-            TensorMemoryLayout((o_instr_shape[0], o_instr_shape[1]), unpacked=True)
+            TensorMemoryLayout((o_instr_shape[0], o_instr_shape[1]), col_stride=1)
         )
         self.p_tmem_layout = gl.constexpr(
-            TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), unpacked=False)
+            TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), col_stride=1)
         )

         self.qk_layout = gl.constexpr(
@@ -321,17 +320,13 @@ def __init__(
             gl.BlockedLayout([1, 1], [32, 1], [self.num_warps, 1], [0, 1])
         )

-        is_fp16 = dtype.value in [gl.float16, gl.bfloat16]
+        is_fp16 = self.dtype.value in [gl.float16, gl.bfloat16]
         if is_fp16:
             self.num_kv_buffers = gl.constexpr(3 if HEAD_DIM == 128 else 6)
         else:
             self.num_kv_buffers = gl.constexpr(4 if HEAD_DIM == 128 else 8)

-        self.use_fadd2_reduce = gl.constexpr(HEAD_DIM == 64)
         self.use_exp2_turnstile = gl.constexpr(HEAD_DIM == 64)
-        self.use_ffma2_scale_rowmax = gl.constexpr(
-            HEAD_DIM == 128 or is_fp16 == (STAGE == 3)
-        )

     @gluon.jit
     def get_program(self, pid_m, pid_n):
@@ -421,113 +416,6 @@ def get_loop_bounds(self, STAGE: gl.constexpr):
         return lo, hi


-# ===-----------------------------------------------------------------------===#
-# float2
-# ===-----------------------------------------------------------------------===#
-
-
-@gluon.jit
-def _add_f32x2(a, b):
-    return gl.inline_asm_elementwise(
-        """
-        {
-            .reg .b64 ra, rb, rc;
-            mov.b64 ra, { $2, $3 };
-            mov.b64 rb, { $4, $5 };
-            add.f32x2 rc, ra, rb;
-            mov.b64 { $0, $1 }, rc;
-        }
-        """,
-        "=r,=r,r,r,r,r",
-        [a, b],
-        dtype=gl.float32,
-        is_pure=True,
-        pack=2,
-    )
-
-
-@gluon.jit
-def _mul_f32x2(a, b):
-    return gl.inline_asm_elementwise(
-        """
-        {
-            .reg .b64 ra, rb, rc;
-            mov.b64 ra, { $2, $3 };
-            mov.b64 rb, { $4, $5 };
-            mul.f32x2 rc, ra, rb;
-            mov.b64 { $0, $1 }, rc;
-        }
-        """,
-        "=r,=r,r,r,r,r",
-        [a, b],
-        dtype=gl.float32,
-        is_pure=True,
-        pack=2,
-    )
-
-
-@gluon.jit
-def _fma_f32x2(a, b, c):
-    return gl.inline_asm_elementwise(
-        """
-        {
-            .reg .b64 ra, rb, rc, rd;
-            mov.b64 ra, { $2, $3 };
-            mov.b64 rb, { $4, $5 };
-            mov.b64 rc, { $6, $7 };
-            fma.rn.f32x2 rd, ra, rb, rc;
-            mov.b64 { $0, $1 }, rd;
-        }
-        """,
-        "=r,=r,r,r,r,r,r,r",
-        [a, b, c],
-        dtype=gl.float32,
-        is_pure=True,
-        pack=2,
-    )
-
-
-@gluon.jit
-def _reduce_fadd2(p0a, p1a, p0b, p1b):
-    return gl.inline_asm_elementwise(
-        """
-        {
-            .reg .b64 rc, ra, rb;
-            mov.b64 ra, { $2, $4 };
-            mov.b64 rb, { $3, $5 };
-            add.f32x2 rc, ra, rb;
-            mov.b64 { $0, $1 }, rc;
-        }
-        """,
-        "=r,=r,r,r,r,r",
-        [p0a, p0b, p1a, p1b],
-        dtype=[gl.float32, gl.float32],
-        is_pure=True,
-        pack=1,
-    )
-
-
-@gluon.jit
-def _pairwise_fma_f32x2(a0, b0, c0, a1, b1, c1):
-    return gl.inline_asm_elementwise(
-        """
-        {
-            .reg .b64 rd, ra, rb, rc;
-            mov.b64 ra, { $2, $5 };
-            mov.b64 rb, { $3, $6 };
-            mov.b64 rc, { $4, $7 };
-            fma.rn.f32x2 rd, ra, rb, rc;
-            mov.b64 { $0, $1 }, rd;
-        }
-        """,
-        "=r,=r,r,r,r,r,r,r",
-        [a0, b0, c0, a1, b1, c1],
-        dtype=[gl.float32, gl.float32],
-        is_pure=True,
-        pack=1,
-    )
-
-
 # ===-----------------------------------------------------------------------===#
 # _gluon_attn
 # ===-----------------------------------------------------------------------===#
@@ -542,15 +430,15 @@ def _borrow_s_as_p(config, s_tmem):
 @gluon.jit
 def _borrow_s_as_alpha(config, s_tmem):
     alpha_tmem = s_tmem.slice(config.BLOCK_N // 2, 1)
-    alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=True)
+    alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], col_stride=1)
     return alpha_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], alpha_layout)


 @gluon.jit
 def _borrow_s_for_epilogue(config, s_tmem):
     m_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 1, 1)
     l_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 2, 1)
-    layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=True)
+    layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], col_stride=1)
     m_i_tmem = m_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout)
     l_i_tmem = l_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout)
     return m_i_tmem, l_i_tmem
@@ -798,8 +686,7 @@ def _softmax_inner_loop(
     corr_bar,  #
     offs_m,
     m_i,
-    l_i0,
-    l_i1,
+    l_i,
     STAGE: gl.constexpr,
 ):
     lo, hi = prog.get_loop_bounds(STAGE)
@@ -821,11 +708,10 @@ def _softmax_inner_loop(
     )
     mbarrier.arrive(corr_bar, count=1)

-    if config.use_ffma2_scale_rowmax:
-        qk = _fma_f32x2(qk, gl.full_like(qk, config.qk_scale), -m_ij[:, None])
-    else:
-        qk = _mul_f32x2(qk, gl.full_like(qk, config.qk_scale))
-        qk = _add_f32x2(qk, -m_ij[:, None])
+    rowmax = float2.pack(-m_ij[:, None].broadcast_to(qk.shape), axis=1)
+    qk = float2.pack(qk, axis=1)
+    qk = float2.fma(qk, float2.full_like(qk, config.qk_scale), rowmax)
+    qk = float2.unpack(qk, axis=1)

     # Force the softmax partitions to take turns in the EX2 section. This
     # prevents contention for the EX2 unit and improves utilization.
@@ -844,24 +730,12 @@ def _softmax_inner_loop(
     if config.use_exp2_turnstile:
         mbarrier.arrive(exp_bar, count=1)

-    if config.use_fadd2_reduce:
-        p0, p1 = _split_n(p)
-        l_ij0, l_ij1 = gl.reduce((p0, p1), axis=1, combine_fn=_reduce_fadd2)
-        # This is a difference of 1 SASS instruction but it dramatically
-        # affects instruction scheduling.
-        alpha = gl.convert_layout(alpha, l_i0.type.layout, assert_trivial=True)
-        if config.dtype == gl.float8e5:
-            l_i0, l_i1 = _pairwise_fma_f32x2(l_i0, alpha, l_ij0, l_i1, alpha, l_ij1)
-        else:
-            l_i0 = l_i0 * alpha + l_ij0
-            l_i1 = l_i1 * alpha + l_ij1
-    else:
-        l_ij = gl.sum(p, axis=1)
-        l_i0 = l_i0 * alpha + l_ij
-
+    l_ij = float2.pack2(*_split_n(p)).sum(axis=1)
+    alpha = gl.convert_layout(alpha, l_i.value.type.layout, assert_trivial=True)
+    l_i = float2.fma(l_i, float2.pack2(alpha, alpha), l_ij)
     m_i = m_ij

-    return m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile
+    return m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile


 @gluon.jit
@@ -876,11 +750,7 @@ def _softmax_tile(
     exp_turnstile,
 ):
     qk_slice_dim1: gl.constexpr = gl.SliceLayout(1, config.qk_layout)
-    sum_layout: gl.constexpr = (
-        _get_split_n_layout(config.qk_layout)
-        if config.use_fadd2_reduce
-        else config.qk_layout
-    )
+    sum_layout: gl.constexpr = _get_split_n_layout(config.qk_layout)

     s_consumer = s_chnl.create_consumer()
     corr_producer = corr_chnl.create_producer()
@@ -894,17 +764,12 @@ def _softmax_tile(
     offs_m += gl.arange(tile_id * config.SPLIT_M, (1 + tile_id) * config.SPLIT_M)

     m_i = gl.full([config.SPLIT_M], -float("inf"), gl.float32, qk_slice_dim1)
-    l_i0 = gl.full([config.SPLIT_M], 0.0, gl.float32, gl.SliceLayout(1, sum_layout))
     # Accumulate into 2 row-sums so the reduction can be performed with FADD2.
-    if config.use_fadd2_reduce:
-        l_i1 = gl.full(
-            [config.SPLIT_M], 0.0, gl.float32, gl.SliceLayout(1, sum_layout)
-        )
-    else:
-        l_i1 = 0
+    l_i = gl.full([config.SPLIT_M], 0.0, gl.float32, gl.SliceLayout(1, sum_layout))
+    l_i = float2.pack2(l_i, l_i)

     if STAGE & 1:
-        m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile = (
+        m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile = (
             _softmax_inner_loop(  #
                 tile_id,
                 config,
@@ -915,13 +780,12 @@ def _softmax_tile(
                 corr_bar,  #
                 offs_m,
                 m_i,
-                l_i0,
-                l_i1,
+                l_i,
                 STAGE=4 - STAGE,
             )
         )
     if STAGE & 2:
-        m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile = (
+        m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile = (
             _softmax_inner_loop(  #
                 tile_id,
                 config,
@@ -932,16 +796,12 @@ def _softmax_tile(
                 corr_bar,  #
                 offs_m,
                 m_i,
-                l_i0,
-                l_i1,
+                l_i,
                 STAGE=2,
             )
         )
-
-    if config.use_fadd2_reduce:
-        l_i = l_i0 + l_i1
-    else:
-        l_i = l_i0
+    l_i0, l_i1 = float2.unpack2(l_i)
+    l_i = l_i0 + l_i1

     s_tmem, s_bar, s_consumer = s_consumer.acquire()
     m_i_tmem, l_i_tmem = _borrow_s_for_epilogue(config, s_tmem)
@@ -1039,11 +899,14 @@ def _attn_fwd_correction_rescale(config, s_tmem, corr_consumer, o_consumer):
     mbarrier.arrive(corr_bar, count=1)
     alpha = gl.convert_layout(alpha.reshape([config.SPLIT_M]), alpha_layout)

+    alpha = float2.pack(
+        alpha[:, None].broadcast_to(config.o_shape[0], config.SPLIT_D), axis=1
+    )
     for i in gl.static_range(config.SPLIT_D_FACTOR):
         o_ref = o_tmem.slice(i * config.SPLIT_D, config.SPLIT_D)
-        o = o_ref.load(config.o_splitn_layout)
-        o = _mul_f32x2(o, alpha[:, None])
-        o_ref.store(o)
+        o = float2.pack(o_ref.load(config.o_splitn_layout), axis=1)
+        o = o * alpha
+        o_ref.store(float2.unpack(o, axis=1))
     mbarrier.arrive(o_bar, count=1)
     return corr_consumer, o_consumer

@@ -1081,12 +944,16 @@ def _attn_fwd_correction_epilogue(
     )
     SPLIT_N: gl.constexpr = o_smem.type.shape[1] // SPLIT_N_FACTOR

-    scale = 1 / l_i
+    scale = float2.pack(
+        (1 / l_i)[:, None].broadcast_to(config.o_shape[0], SPLIT_N), axis=1
+    )
     for i in gl.static_range(SPLIT_N_FACTOR):
         o_ref = o_tmem.slice(i * SPLIT_N, SPLIT_N)
-        o = o_ref.load(config.o_splitn_layout)
-        o = _mul_f32x2(o, scale[:, None])
-        o_smem.slice(i * SPLIT_N, SPLIT_N, dim=1).store(o.to(config.dtype))
+        o = float2.pack(o_ref.load(config.o_splitn_layout), axis=1)
+        o = o * scale
+        o_smem.slice(i * SPLIT_N, SPLIT_N, dim=1).store(
+            float2.unpack(o, axis=1).to(config.dtype)
+        )

     fence_async_shared()
     mbarrier.arrive(epi_bar, count=1)
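
Note on the row-sum handling above: the former split accumulators l_i0/l_i1 are now carried as a single float2-packed value, updated with one packed FMA per K/V tile and unpacked once at the end into l_i = l_i0 + l_i1. A plain-Python sketch of that accumulation, as an illustration of the arithmetic rather than of the Gluon float2 API (shapes and values are made up):

import numpy as np

def update_row_sums(p, l_i0, l_i1, alpha):
    # Reference for the packed update in _softmax_inner_loop:
    #   l_ij = float2.pack2(*_split_n(p)).sum(axis=1)
    #   l_i  = float2.fma(l_i, float2.pack2(alpha, alpha), l_ij)
    # Each half of p contributes a row sum; both halves are rescaled by alpha.
    p0, p1 = np.split(p, 2, axis=1)
    return l_i0 * alpha + p0.sum(axis=1), l_i1 * alpha + p1.sum(axis=1)

p = np.random.rand(4, 8).astype(np.float32)
l_i0 = np.zeros(4, dtype=np.float32)
l_i1 = np.zeros(4, dtype=np.float32)
alpha = np.ones(4, dtype=np.float32)
l_i0, l_i1 = update_row_sums(p, l_i0, l_i1, alpha)
l_i = l_i0 + l_i1  # final softmax denominator, as in _softmax_tile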
