88from triton .experimental .gluon import language as gl
99from triton .experimental .gluon .language .nvidia .blackwell import (
1010 allocate_tensor_memory ,
11+ float2 ,
1112 get_tmem_32x32b_reg_layout ,
1213 mbarrier ,
1314 tcgen05_commit ,
@@ -243,9 +244,7 @@ class AttentionConfig:
243244 alpha_2d_layout : gl .constexpr
244245
245246 num_kv_buffers : gl .constexpr
246- use_fadd2_reduce : gl .constexpr
247247 use_exp2_turnstile : gl .constexpr
248- use_ffma2_scale_rowmax : gl .constexpr
249248
250249 def __init__ (
251250 self ,
@@ -290,13 +289,13 @@ def __init__(
290289 qk_instr_shape = get_mma_instr_shape (self .qk_shape , gl .float32 )
291290 o_instr_shape = get_mma_instr_shape (self .o_shape , gl .float32 )
292291 self .qk_tmem_layout = gl .constexpr (
293- TensorMemoryLayout ((qk_instr_shape [0 ], qk_instr_shape [1 ]), unpacked = True )
292+ TensorMemoryLayout ((qk_instr_shape [0 ], qk_instr_shape [1 ]), col_stride = 1 )
294293 )
295294 self .o_tmem_layout = gl .constexpr (
296- TensorMemoryLayout ((o_instr_shape [0 ], o_instr_shape [1 ]), unpacked = True )
295+ TensorMemoryLayout ((o_instr_shape [0 ], o_instr_shape [1 ]), col_stride = 1 )
297296 )
298297 self .p_tmem_layout = gl .constexpr (
299- TensorMemoryLayout ((qk_instr_shape [0 ], qk_instr_shape [1 ]), unpacked = False )
298+ TensorMemoryLayout ((qk_instr_shape [0 ], qk_instr_shape [1 ]), col_stride = 1 )
300299 )
301300
302301 self .qk_layout = gl .constexpr (
@@ -321,17 +320,13 @@ def __init__(
321320 gl .BlockedLayout ([1 , 1 ], [32 , 1 ], [self .num_warps , 1 ], [0 , 1 ])
322321 )
323322
324- is_fp16 = dtype .value in [gl .float16 , gl .bfloat16 ]
323+ is_fp16 = self . dtype .value in [gl .float16 , gl .bfloat16 ]
325324 if is_fp16 :
326325 self .num_kv_buffers = gl .constexpr (3 if HEAD_DIM == 128 else 6 )
327326 else :
328327 self .num_kv_buffers = gl .constexpr (4 if HEAD_DIM == 128 else 8 )
329328
330- self .use_fadd2_reduce = gl .constexpr (HEAD_DIM == 64 )
331329 self .use_exp2_turnstile = gl .constexpr (HEAD_DIM == 64 )
332- self .use_ffma2_scale_rowmax = gl .constexpr (
333- HEAD_DIM == 128 or is_fp16 == (STAGE == 3 )
334- )
335330
336331 @gluon .jit
337332 def get_program (self , pid_m , pid_n ):
@@ -421,113 +416,6 @@ def get_loop_bounds(self, STAGE: gl.constexpr):
421416 return lo , hi
422417
423418
424- # ===-----------------------------------------------------------------------===#
425- # float2
426- # ===-----------------------------------------------------------------------===#
427-
428-
429- @gluon .jit
430- def _add_f32x2 (a , b ):
431- return gl .inline_asm_elementwise (
432- """
433- {
434- .reg .b64 ra, rb, rc;
435- mov.b64 ra, { $2, $3 };
436- mov.b64 rb, { $4, $5 };
437- add.f32x2 rc, ra, rb;
438- mov.b64 { $0, $1 }, rc;
439- }
440- """ ,
441- "=r,=r,r,r,r,r" ,
442- [a , b ],
443- dtype = gl .float32 ,
444- is_pure = True ,
445- pack = 2 ,
446- )
447-
448-
449- @gluon .jit
450- def _mul_f32x2 (a , b ):
451- return gl .inline_asm_elementwise (
452- """
453- {
454- .reg .b64 ra, rb, rc;
455- mov.b64 ra, { $2, $3 };
456- mov.b64 rb, { $4, $5 };
457- mul.f32x2 rc, ra, rb;
458- mov.b64 { $0, $1 }, rc;
459- }
460- """ ,
461- "=r,=r,r,r,r,r" ,
462- [a , b ],
463- dtype = gl .float32 ,
464- is_pure = True ,
465- pack = 2 ,
466- )
467-
468-
469- @gluon .jit
470- def _fma_f32x2 (a , b , c ):
471- return gl .inline_asm_elementwise (
472- """
473- {
474- .reg .b64 ra, rb, rc, rd;
475- mov.b64 ra, { $2, $3 };
476- mov.b64 rb, { $4, $5 };
477- mov.b64 rc, { $6, $7 };
478- fma.rn.f32x2 rd, ra, rb, rc;
479- mov.b64 { $0, $1 }, rd;
480- }
481- """ ,
482- "=r,=r,r,r,r,r,r,r" ,
483- [a , b , c ],
484- dtype = gl .float32 ,
485- is_pure = True ,
486- pack = 2 ,
487- )
488-
489-
490- @gluon .jit
491- def _reduce_fadd2 (p0a , p1a , p0b , p1b ):
492- return gl .inline_asm_elementwise (
493- """
494- {
495- .reg .b64 rc, ra, rb;
496- mov.b64 ra, { $2, $4 };
497- mov.b64 rb, { $3, $5 };
498- add.f32x2 rc, ra, rb;
499- mov.b64 { $0, $1 }, rc;
500- }
501- """ ,
502- "=r,=r,r,r,r,r" ,
503- [p0a , p0b , p1a , p1b ],
504- dtype = [gl .float32 , gl .float32 ],
505- is_pure = True ,
506- pack = 1 ,
507- )
508-
509-
510- @gluon .jit
511- def _pairwise_fma_f32x2 (a0 , b0 , c0 , a1 , b1 , c1 ):
512- return gl .inline_asm_elementwise (
513- """
514- {
515- .reg .b64 rd, ra, rb, rc;
516- mov.b64 ra, { $2, $5 };
517- mov.b64 rb, { $3, $6 };
518- mov.b64 rc, { $4, $7 };
519- fma.rn.f32x2 rd, ra, rb, rc;
520- mov.b64 { $0, $1 }, rd;
521- }
522- """ ,
523- "=r,=r,r,r,r,r,r,r" ,
524- [a0 , b0 , c0 , a1 , b1 , c1 ],
525- dtype = [gl .float32 , gl .float32 ],
526- is_pure = True ,
527- pack = 1 ,
528- )
529-
530-
531419# ===-----------------------------------------------------------------------===#
532420# _gluon_attn
533421# ===-----------------------------------------------------------------------===#
@@ -542,15 +430,15 @@ def _borrow_s_as_p(config, s_tmem):
542430@gluon .jit
543431def _borrow_s_as_alpha (config , s_tmem ):
544432 alpha_tmem = s_tmem .slice (config .BLOCK_N // 2 , 1 )
545- alpha_layout : gl .constexpr = TensorMemoryLayout ([config .SPLIT_M , 1 ], unpacked = True )
433+ alpha_layout : gl .constexpr = TensorMemoryLayout ([config .SPLIT_M , 1 ], col_stride = 1 )
546434 return alpha_tmem ._reinterpret (gl .float32 , [config .SPLIT_M , 1 ], alpha_layout )
547435
548436
549437@gluon .jit
550438def _borrow_s_for_epilogue (config , s_tmem ):
551439 m_i_tmem = s_tmem .slice (config .BLOCK_N // 2 + 1 , 1 )
552440 l_i_tmem = s_tmem .slice (config .BLOCK_N // 2 + 2 , 1 )
553- layout : gl .constexpr = TensorMemoryLayout ([config .SPLIT_M , 1 ], unpacked = True )
441+ layout : gl .constexpr = TensorMemoryLayout ([config .SPLIT_M , 1 ], col_stride = 1 )
554442 m_i_tmem = m_i_tmem ._reinterpret (gl .float32 , [config .SPLIT_M , 1 ], layout )
555443 l_i_tmem = l_i_tmem ._reinterpret (gl .float32 , [config .SPLIT_M , 1 ], layout )
556444 return m_i_tmem , l_i_tmem
@@ -798,8 +686,7 @@ def _softmax_inner_loop(
798686 corr_bar , #
799687 offs_m ,
800688 m_i ,
801- l_i0 ,
802- l_i1 ,
689+ l_i ,
803690 STAGE : gl .constexpr ,
804691):
805692 lo , hi = prog .get_loop_bounds (STAGE )
@@ -821,11 +708,10 @@ def _softmax_inner_loop(
821708 )
822709 mbarrier .arrive (corr_bar , count = 1 )
823710
824- if config .use_ffma2_scale_rowmax :
825- qk = _fma_f32x2 (qk , gl .full_like (qk , config .qk_scale ), - m_ij [:, None ])
826- else :
827- qk = _mul_f32x2 (qk , gl .full_like (qk , config .qk_scale ))
828- qk = _add_f32x2 (qk , - m_ij [:, None ])
711+ rowmax = float2 .pack (- m_ij [:, None ].broadcast_to (qk .shape ), axis = 1 )
712+ qk = float2 .pack (qk , axis = 1 )
713+ qk = float2 .fma (qk , float2 .full_like (qk , config .qk_scale ), rowmax )
714+ qk = float2 .unpack (qk , axis = 1 )
829715
830716 # Force the softmax partitions to take turns in the EX2 section. This
831717 # prevents contention for the EX2 unit and improves utilization.
@@ -844,24 +730,12 @@ def _softmax_inner_loop(
844730 if config .use_exp2_turnstile :
845731 mbarrier .arrive (exp_bar , count = 1 )
846732
847- if config .use_fadd2_reduce :
848- p0 , p1 = _split_n (p )
849- l_ij0 , l_ij1 = gl .reduce ((p0 , p1 ), axis = 1 , combine_fn = _reduce_fadd2 )
850- # This is a difference of 1 SASS instruction but it dramatically
851- # affects instruction scheduling.
852- alpha = gl .convert_layout (alpha , l_i0 .type .layout , assert_trivial = True )
853- if config .dtype == gl .float8e5 :
854- l_i0 , l_i1 = _pairwise_fma_f32x2 (l_i0 , alpha , l_ij0 , l_i1 , alpha , l_ij1 )
855- else :
856- l_i0 = l_i0 * alpha + l_ij0
857- l_i1 = l_i1 * alpha + l_ij1
858- else :
859- l_ij = gl .sum (p , axis = 1 )
860- l_i0 = l_i0 * alpha + l_ij
861-
733+ l_ij = float2 .pack2 (* _split_n (p )).sum (axis = 1 )
734+ alpha = gl .convert_layout (alpha , l_i .value .type .layout , assert_trivial = True )
735+ l_i = float2 .fma (l_i , float2 .pack2 (alpha , alpha ), l_ij )
862736 m_i = m_ij
863737
864- return m_i , l_i0 , l_i1 , corr_bar , s_consumer , corr_producer , exp_turnstile
738+ return m_i , l_i , corr_bar , s_consumer , corr_producer , exp_turnstile
865739
866740
867741@gluon .jit
@@ -876,11 +750,7 @@ def _softmax_tile(
876750 exp_turnstile ,
877751):
878752 qk_slice_dim1 : gl .constexpr = gl .SliceLayout (1 , config .qk_layout )
879- sum_layout : gl .constexpr = (
880- _get_split_n_layout (config .qk_layout )
881- if config .use_fadd2_reduce
882- else config .qk_layout
883- )
753+ sum_layout : gl .constexpr = _get_split_n_layout (config .qk_layout )
884754
885755 s_consumer = s_chnl .create_consumer ()
886756 corr_producer = corr_chnl .create_producer ()
@@ -894,17 +764,12 @@ def _softmax_tile(
894764 offs_m += gl .arange (tile_id * config .SPLIT_M , (1 + tile_id ) * config .SPLIT_M )
895765
896766 m_i = gl .full ([config .SPLIT_M ], - float ("inf" ), gl .float32 , qk_slice_dim1 )
897- l_i0 = gl .full ([config .SPLIT_M ], 0.0 , gl .float32 , gl .SliceLayout (1 , sum_layout ))
898767 # Accumulate into 2 row-sums so the reduction can be performed with FADD2.
899- if config .use_fadd2_reduce :
900- l_i1 = gl .full (
901- [config .SPLIT_M ], 0.0 , gl .float32 , gl .SliceLayout (1 , sum_layout )
902- )
903- else :
904- l_i1 = 0
768+ l_i = gl .full ([config .SPLIT_M ], 0.0 , gl .float32 , gl .SliceLayout (1 , sum_layout ))
769+ l_i = float2 .pack2 (l_i , l_i )
905770
906771 if STAGE & 1 :
907- m_i , l_i0 , l_i1 , corr_bar , s_consumer , corr_producer , exp_turnstile = (
772+ m_i , l_i , corr_bar , s_consumer , corr_producer , exp_turnstile = (
908773 _softmax_inner_loop ( #
909774 tile_id ,
910775 config ,
@@ -915,13 +780,12 @@ def _softmax_tile(
915780 corr_bar , #
916781 offs_m ,
917782 m_i ,
918- l_i0 ,
919- l_i1 ,
783+ l_i ,
920784 STAGE = 4 - STAGE ,
921785 )
922786 )
923787 if STAGE & 2 :
924- m_i , l_i0 , l_i1 , corr_bar , s_consumer , corr_producer , exp_turnstile = (
788+ m_i , l_i , corr_bar , s_consumer , corr_producer , exp_turnstile = (
925789 _softmax_inner_loop ( #
926790 tile_id ,
927791 config ,
@@ -932,16 +796,12 @@ def _softmax_tile(
932796 corr_bar , #
933797 offs_m ,
934798 m_i ,
935- l_i0 ,
936- l_i1 ,
799+ l_i ,
937800 STAGE = 2 ,
938801 )
939802 )
940-
941- if config .use_fadd2_reduce :
942- l_i = l_i0 + l_i1
943- else :
944- l_i = l_i0
803+ l_i0 , l_i1 = float2 .unpack2 (l_i )
804+ l_i = l_i0 + l_i1
945805
946806 s_tmem , s_bar , s_consumer = s_consumer .acquire ()
947807 m_i_tmem , l_i_tmem = _borrow_s_for_epilogue (config , s_tmem )
@@ -1039,11 +899,14 @@ def _attn_fwd_correction_rescale(config, s_tmem, corr_consumer, o_consumer):
1039899 mbarrier .arrive (corr_bar , count = 1 )
1040900 alpha = gl .convert_layout (alpha .reshape ([config .SPLIT_M ]), alpha_layout )
1041901
902+ alpha = float2 .pack (
903+ alpha [:, None ].broadcast_to (config .o_shape [0 ], config .SPLIT_D ), axis = 1
904+ )
1042905 for i in gl .static_range (config .SPLIT_D_FACTOR ):
1043906 o_ref = o_tmem .slice (i * config .SPLIT_D , config .SPLIT_D )
1044- o = o_ref .load (config .o_splitn_layout )
1045- o = _mul_f32x2 ( o , alpha [:, None ])
1046- o_ref .store (o )
907+ o = float2 . pack ( o_ref .load (config .o_splitn_layout ), axis = 1 )
908+ o = o * alpha
909+ o_ref .store (float2 . unpack ( o , axis = 1 ) )
1047910 mbarrier .arrive (o_bar , count = 1 )
1048911 return corr_consumer , o_consumer
1049912
@@ -1081,12 +944,16 @@ def _attn_fwd_correction_epilogue(
1081944 )
1082945 SPLIT_N : gl .constexpr = o_smem .type .shape [1 ] // SPLIT_N_FACTOR
1083946
1084- scale = 1 / l_i
947+ scale = float2 .pack (
948+ (1 / l_i )[:, None ].broadcast_to (config .o_shape [0 ], SPLIT_N ), axis = 1
949+ )
1085950 for i in gl .static_range (SPLIT_N_FACTOR ):
1086951 o_ref = o_tmem .slice (i * SPLIT_N , SPLIT_N )
1087- o = o_ref .load (config .o_splitn_layout )
1088- o = _mul_f32x2 (o , scale [:, None ])
1089- o_smem .slice (i * SPLIT_N , SPLIT_N , dim = 1 ).store (o .to (config .dtype ))
952+ o = float2 .pack (o_ref .load (config .o_splitn_layout ), axis = 1 )
953+ o = o * scale
954+ o_smem .slice (i * SPLIT_N , SPLIT_N , dim = 1 ).store (
955+ float2 .unpack (o , axis = 1 ).to (config .dtype )
956+ )
1090957
1091958 fence_async_shared ()
1092959 mbarrier .arrive (epi_bar , count = 1 )
0 commit comments