from triton.experimental.gluon import language as gl
from triton.experimental.gluon.language.nvidia.blackwell import (
    allocate_tensor_memory,
+    float2,
    get_tmem_32x32b_reg_layout,
    mbarrier,
    tcgen05_commit,
@@ -69,6 +70,7 @@ def increment(self):
def Channel(T, alloc_fn):
+
    @aggregate
    class ChannelType:
        mem: T
@@ -243,9 +245,7 @@ class AttentionConfig:
    alpha_2d_layout: gl.constexpr

    num_kv_buffers: gl.constexpr
-    use_fadd2_reduce: gl.constexpr
    use_exp2_turnstile: gl.constexpr
-    use_ffma2_scale_rowmax: gl.constexpr

    def __init__(
        self,
@@ -290,13 +290,13 @@ def __init__(
        qk_instr_shape = get_mma_instr_shape(self.qk_shape, gl.float32)
        o_instr_shape = get_mma_instr_shape(self.o_shape, gl.float32)
        self.qk_tmem_layout = gl.constexpr(
-            TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), unpacked=True)
+            TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), col_stride=1)
        )
        self.o_tmem_layout = gl.constexpr(
-            TensorMemoryLayout((o_instr_shape[0], o_instr_shape[1]), unpacked=True)
+            TensorMemoryLayout((o_instr_shape[0], o_instr_shape[1]), col_stride=1)
        )
        self.p_tmem_layout = gl.constexpr(
-            TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), unpacked=False)
+            TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), col_stride=1)
        )

        self.qk_layout = gl.constexpr(
@@ -321,17 +321,13 @@ def __init__(
            gl.BlockedLayout([1, 1], [32, 1], [self.num_warps, 1], [0, 1])
        )

-        is_fp16 = dtype.value in [gl.float16, gl.bfloat16]
+        is_fp16 = self.dtype.value in [gl.float16, gl.bfloat16]
        if is_fp16:
            self.num_kv_buffers = gl.constexpr(3 if HEAD_DIM == 128 else 6)
        else:
            self.num_kv_buffers = gl.constexpr(4 if HEAD_DIM == 128 else 8)

-        self.use_fadd2_reduce = gl.constexpr(HEAD_DIM == 64)
        self.use_exp2_turnstile = gl.constexpr(HEAD_DIM == 64)
-        self.use_ffma2_scale_rowmax = gl.constexpr(
-            HEAD_DIM == 128 or is_fp16 == (STAGE == 3)
-        )

    @gluon.jit
    def get_program(self, pid_m, pid_n):
@@ -421,113 +417,6 @@ def get_loop_bounds(self, STAGE: gl.constexpr):
        return lo, hi


-# ===-----------------------------------------------------------------------===#
-# float2
-# ===-----------------------------------------------------------------------===#
-
-
-@gluon.jit
-def _add_f32x2(a, b):
-    return gl.inline_asm_elementwise(
-        """
-        {
-            .reg .b64 ra, rb, rc;
-            mov.b64 ra, { $2, $3 };
-            mov.b64 rb, { $4, $5 };
-            add.f32x2 rc, ra, rb;
-            mov.b64 { $0, $1 }, rc;
-        }
-        """,
-        "=r,=r,r,r,r,r",
-        [a, b],
-        dtype=gl.float32,
-        is_pure=True,
-        pack=2,
-    )
-
-
-@gluon.jit
-def _mul_f32x2(a, b):
-    return gl.inline_asm_elementwise(
-        """
-        {
-            .reg .b64 ra, rb, rc;
-            mov.b64 ra, { $2, $3 };
-            mov.b64 rb, { $4, $5 };
-            mul.f32x2 rc, ra, rb;
-            mov.b64 { $0, $1 }, rc;
-        }
-        """,
-        "=r,=r,r,r,r,r",
-        [a, b],
-        dtype=gl.float32,
-        is_pure=True,
-        pack=2,
-    )
-
-
-@gluon.jit
-def _fma_f32x2(a, b, c):
-    return gl.inline_asm_elementwise(
-        """
-        {
-            .reg .b64 ra, rb, rc, rd;
-            mov.b64 ra, { $2, $3 };
-            mov.b64 rb, { $4, $5 };
-            mov.b64 rc, { $6, $7 };
-            fma.rn.f32x2 rd, ra, rb, rc;
-            mov.b64 { $0, $1 }, rd;
-        }
-        """,
-        "=r,=r,r,r,r,r,r,r",
-        [a, b, c],
-        dtype=gl.float32,
-        is_pure=True,
-        pack=2,
-    )
-
-
-@gluon.jit
-def _reduce_fadd2(p0a, p1a, p0b, p1b):
-    return gl.inline_asm_elementwise(
-        """
-        {
-            .reg .b64 rc, ra, rb;
-            mov.b64 ra, { $2, $4 };
-            mov.b64 rb, { $3, $5 };
-            add.f32x2 rc, ra, rb;
-            mov.b64 { $0, $1 }, rc;
-        }
-        """,
-        "=r,=r,r,r,r,r",
-        [p0a, p0b, p1a, p1b],
-        dtype=[gl.float32, gl.float32],
-        is_pure=True,
-        pack=1,
-    )
-
-
-@gluon.jit
-def _pairwise_fma_f32x2(a0, b0, c0, a1, b1, c1):
-    return gl.inline_asm_elementwise(
-        """
-        {
-            .reg .b64 rd, ra, rb, rc;
-            mov.b64 ra, { $2, $5 };
-            mov.b64 rb, { $3, $6 };
-            mov.b64 rc, { $4, $7 };
-            fma.rn.f32x2 rd, ra, rb, rc;
-            mov.b64 { $0, $1 }, rd;
-        }
-        """,
-        "=r,=r,r,r,r,r,r,r",
-        [a0, b0, c0, a1, b1, c1],
-        dtype=[gl.float32, gl.float32],
-        is_pure=True,
-        pack=1,
-    )
-
-
# ===-----------------------------------------------------------------------===#
# _gluon_attn
# ===-----------------------------------------------------------------------===#
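# --- Aside (not part of the diff): a minimal sketch of the packed row-sum
# pattern that replaces the deleted inline-asm helpers above. Two partial
# row-sums travel as one float2 pair, are updated with a single packed FMA per
# iteration, and are only collapsed to a plain fp32 vector at the end. The
# function names and arguments here are illustrative; the float2 calls
# (pack2, fma, unpack2 and the pair's .sum) mirror the usage introduced later
# in this diff.
from triton.experimental import gluon
from triton.experimental.gluon.language.nvidia.blackwell import float2


@gluon.jit
def _accumulate_rowsum_sketch(l_i_pair, p_lo, p_hi, alpha):
    # Reduce both halves of the probability tile as one packed pair.
    l_ij = float2.pack2(p_lo, p_hi).sum(axis=1)
    # l_i = l_i * alpha + l_ij, updating two row-sums per packed FMA.
    return float2.fma(l_i_pair, float2.pack2(alpha, alpha), l_ij)


@gluon.jit
def _finalize_rowsum_sketch(l_i_pair):
    # Collapse the packed pair back into a single fp32 row-sum.
    l_i0, l_i1 = float2.unpack2(l_i_pair)
    return l_i0 + l_i1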
@@ -542,15 +431,15 @@ def _borrow_s_as_p(config, s_tmem):
@gluon.jit
def _borrow_s_as_alpha(config, s_tmem):
    alpha_tmem = s_tmem.slice(config.BLOCK_N // 2, 1)
-    alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=True)
+    alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], col_stride=1)
    return alpha_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], alpha_layout)


@gluon.jit
def _borrow_s_for_epilogue(config, s_tmem):
    m_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 1, 1)
    l_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 2, 1)
-    layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=True)
+    layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], col_stride=1)
    m_i_tmem = m_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout)
    l_i_tmem = l_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout)
    return m_i_tmem, l_i_tmem
@@ -798,8 +687,7 @@ def _softmax_inner_loop(
    corr_bar,  #
    offs_m,
    m_i,
-    l_i0,
-    l_i1,
+    l_i,
    STAGE: gl.constexpr,
):
    lo, hi = prog.get_loop_bounds(STAGE)
@@ -821,11 +709,10 @@ def _softmax_inner_loop(
        )
        mbarrier.arrive(corr_bar, count=1)

-        if config.use_ffma2_scale_rowmax:
-            qk = _fma_f32x2(qk, gl.full_like(qk, config.qk_scale), -m_ij[:, None])
-        else:
-            qk = _mul_f32x2(qk, gl.full_like(qk, config.qk_scale))
-            qk = _add_f32x2(qk, -m_ij[:, None])
+        rowmax = float2.pack(-m_ij[:, None].broadcast_to(qk.shape), axis=1)
+        qk = float2.pack(qk, axis=1)
+        qk = float2.fma(qk, float2.full_like(qk, config.qk_scale), rowmax)
+        qk = float2.unpack(qk, axis=1)

        # Force the softmax partitions to take turns in the EX2 section. This
        # prevents contention for the EX2 unit and improves utilization.
@@ -844,24 +731,12 @@ def _softmax_inner_loop(
        if config.use_exp2_turnstile:
            mbarrier.arrive(exp_bar, count=1)

-        if config.use_fadd2_reduce:
-            p0, p1 = _split_n(p)
-            l_ij0, l_ij1 = gl.reduce((p0, p1), axis=1, combine_fn=_reduce_fadd2)
-            # This is a difference of 1 SASS instruction but it dramatically
-            # affects instruction scheduling.
-            alpha = gl.convert_layout(alpha, l_i0.type.layout, assert_trivial=True)
-            if config.dtype == gl.float8e5:
-                l_i0, l_i1 = _pairwise_fma_f32x2(l_i0, alpha, l_ij0, l_i1, alpha, l_ij1)
-            else:
-                l_i0 = l_i0 * alpha + l_ij0
-                l_i1 = l_i1 * alpha + l_ij1
-        else:
-            l_ij = gl.sum(p, axis=1)
-            l_i0 = l_i0 * alpha + l_ij
-
+        l_ij = float2.pack2(*_split_n(p)).sum(axis=1)
+        alpha = gl.convert_layout(alpha, l_i.value.type.layout, assert_trivial=True)
+        l_i = float2.fma(l_i, float2.pack2(alpha, alpha), l_ij)
        m_i = m_ij

-    return m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile
+    return m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile


@gluon.jit
@@ -876,11 +751,7 @@ def _softmax_tile(
    exp_turnstile,
):
    qk_slice_dim1: gl.constexpr = gl.SliceLayout(1, config.qk_layout)
-    sum_layout: gl.constexpr = (
-        _get_split_n_layout(config.qk_layout)
-        if config.use_fadd2_reduce
-        else config.qk_layout
-    )
+    sum_layout: gl.constexpr = _get_split_n_layout(config.qk_layout)

    s_consumer = s_chnl.create_consumer()
    corr_producer = corr_chnl.create_producer()
@@ -894,17 +765,12 @@ def _softmax_tile(
    offs_m += gl.arange(tile_id * config.SPLIT_M, (1 + tile_id) * config.SPLIT_M)

    m_i = gl.full([config.SPLIT_M], -float("inf"), gl.float32, qk_slice_dim1)
-    l_i0 = gl.full([config.SPLIT_M], 0.0, gl.float32, gl.SliceLayout(1, sum_layout))
    # Accumulate into 2 row-sums so the reduction can be performed with FADD2.
-    if config.use_fadd2_reduce:
-        l_i1 = gl.full(
-            [config.SPLIT_M], 0.0, gl.float32, gl.SliceLayout(1, sum_layout)
-        )
-    else:
-        l_i1 = 0
+    l_i = gl.full([config.SPLIT_M], 0.0, gl.float32, gl.SliceLayout(1, sum_layout))
+    l_i = float2.pack2(l_i, l_i)

    if STAGE & 1:
-        m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile = (
+        m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile = (
            _softmax_inner_loop(  #
                tile_id,
                config,
@@ -915,13 +781,12 @@ def _softmax_tile(
                corr_bar,  #
                offs_m,
                m_i,
-                l_i0,
-                l_i1,
+                l_i,
                STAGE=4 - STAGE,
            )
        )
    if STAGE & 2:
-        m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile = (
+        m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile = (
            _softmax_inner_loop(  #
                tile_id,
                config,
@@ -932,16 +797,12 @@ def _softmax_tile(
                corr_bar,  #
                offs_m,
                m_i,
-                l_i0,
-                l_i1,
+                l_i,
                STAGE=2,
            )
        )
-
-    if config.use_fadd2_reduce:
-        l_i = l_i0 + l_i1
-    else:
-        l_i = l_i0
+    l_i0, l_i1 = float2.unpack2(l_i)
+    l_i = l_i0 + l_i1

    s_tmem, s_bar, s_consumer = s_consumer.acquire()
    m_i_tmem, l_i_tmem = _borrow_s_for_epilogue(config, s_tmem)
@@ -1039,11 +900,14 @@ def _attn_fwd_correction_rescale(config, s_tmem, corr_consumer, o_consumer):
    mbarrier.arrive(corr_bar, count=1)
    alpha = gl.convert_layout(alpha.reshape([config.SPLIT_M]), alpha_layout)

+    alpha = float2.pack(
+        alpha[:, None].broadcast_to(config.o_shape[0], config.SPLIT_D), axis=1
+    )
    for i in gl.static_range(config.SPLIT_D_FACTOR):
        o_ref = o_tmem.slice(i * config.SPLIT_D, config.SPLIT_D)
-        o = o_ref.load(config.o_splitn_layout)
-        o = _mul_f32x2(o, alpha[:, None])
-        o_ref.store(o)
+        o = float2.pack(o_ref.load(config.o_splitn_layout), axis=1)
+        o = o * alpha
+        o_ref.store(float2.unpack(o, axis=1))
    mbarrier.arrive(o_bar, count=1)
    return corr_consumer, o_consumer

@@ -1081,12 +945,16 @@ def _attn_fwd_correction_epilogue(
    )
    SPLIT_N: gl.constexpr = o_smem.type.shape[1] // SPLIT_N_FACTOR

-    scale = 1 / l_i
+    scale = float2.pack(
+        (1 / l_i)[:, None].broadcast_to(config.o_shape[0], SPLIT_N), axis=1
+    )
    for i in gl.static_range(SPLIT_N_FACTOR):
        o_ref = o_tmem.slice(i * SPLIT_N, SPLIT_N)
-        o = o_ref.load(config.o_splitn_layout)
-        o = _mul_f32x2(o, scale[:, None])
-        o_smem.slice(i * SPLIT_N, SPLIT_N, dim=1).store(o.to(config.dtype))
+        o = float2.pack(o_ref.load(config.o_splitn_layout), axis=1)
+        o = o * scale
+        o_smem.slice(i * SPLIT_N, SPLIT_N, dim=1).store(
+            float2.unpack(o, axis=1).to(config.dtype)
+        )

    fence_async_shared()
    mbarrier.arrive(epi_bar, count=1)
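# --- Aside (not part of the diff): a minimal sketch of the elementwise
# pack/compute/unpack pattern used above in the softmax and correction
# partitions, assuming Triton's experimental Gluon Blackwell float2 module is
# available. The function name and signature below are hypothetical; the
# float2 calls (pack, full_like, fma, unpack) mirror the ones added by this
# change.
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.language.nvidia.blackwell import float2


@gluon.jit
def _scale_and_subtract_rowmax_sketch(qk, m_ij, qk_scale: gl.constexpr):
    # Broadcast -m_ij across each row and pack adjacent fp32 lanes pairwise.
    rowmax = float2.pack(-m_ij[:, None].broadcast_to(qk.shape), axis=1)
    qk2 = float2.pack(qk, axis=1)
    # One packed FMA computes qk * qk_scale - m_ij for two lanes at a time.
    qk2 = float2.fma(qk2, float2.full_like(qk2, qk_scale), rowmax)
    return float2.unpack(qk2, axis=1)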