@@ -223,7 +223,6 @@ def compute_blocked_scale_offsets_for_K_groups(
223223def triton_mx_block_rearrange_2d_M_groups (
224224 scales_tensor : torch .Tensor ,
225225 input_group_end_offsets : torch .Tensor ,
226- output_group_start_offsets : torch .Tensor ,
227226) -> torch .Tensor :
228227 """
229228 Rearranges an E8M0 tensor scale to block-scaled swizzle format,
@@ -275,15 +274,14 @@ def triton_mx_block_rearrange_2d_M_groups(
275274 scales_tensor .stride (1 ),
276275 rows ,
277276 cols ,
278- num_groups ,
279277 # Original offsets (to read from)
280278 input_group_end_offsets ,
281279 # Output scales tensor and group offsets after padding (to write to)
282280 output .view (torch .uint8 ),
283281 output .stride (0 ),
284- output_group_start_offsets ,
285282 output_stride_per_block ,
286283 output_stride_per_row_of_blocks ,
284+ num_groups = num_groups ,
287285 BLOCK_ROWS = BLOCK_ROWS ,
288286 BLOCK_COLS = BLOCK_COLS ,
289287 )
@@ -297,13 +295,12 @@ def triton_scale_swizzle_M_groups(
297295 scales_stride_dim1 ,
298296 scale_rows ,
299297 scale_cols ,
300- num_groups ,
301298 orig_offsets , # (num_groups,)
302299 output_scales_ptr ,
303300 output_scales_stride_dim0 ,
304- output_scales_group_offsets , # (num_groups,)
305301 output_stride_per_block ,
306302 output_stride_per_row_of_blocks ,
303+ num_groups : tl .constexpr ,
307304 BLOCK_ROWS : tl .constexpr ,
308305 BLOCK_COLS : tl .constexpr ,
309306):
@@ -316,10 +313,13 @@ def triton_scale_swizzle_M_groups(
316313 input_group_end_row = tl .load (
317314 orig_offsets + group_pid , mask = group_pid < num_groups , other = 0
318315 )
319- # Output scales start row we will begin writing to
320- output_group_start_row = tl .load (
321- output_scales_group_offsets + group_pid , mask = group_pid < num_groups , other = 0
316+
317+ # Calculate this group's start row after blocked format padding, by doing a prefix sum
318+ # of each previous group's padded size.
319+ output_group_start_row = _blocked_group_start_idx (
320+ group_pid , orig_offsets , num_groups , 128
322321 )
322+
323323 # Calculate destination indices for each row and col in block swizzled layout.
324324 # We can reuse this swizzle transformation on each block of data we read.
325325 row_offs = tl .arange (0 , BLOCK_ROWS )[:, None ]
@@ -489,7 +489,6 @@ def triton_scale_swizzle_per_group_3d(
489489def triton_mx_block_rearrange_2d_K_groups (
490490 scales_tensor : torch .Tensor ,
491491 input_group_end_offsets : torch .Tensor ,
492- output_group_start_offsets : torch .Tensor ,
493492) -> torch .Tensor :
494493 """
495494 Rearranges an E8M0 tensor scale to block-scaled swizzle format on a per group basis,
@@ -538,13 +537,10 @@ def triton_mx_block_rearrange_2d_K_groups(
538537 rows ,
539538 cols ,
540539 padded_rows ,
541- num_groups ,
542- # Original offsets (to read from)
543540 input_group_end_offsets ,
544- # Output scales tensor and group offsets after padding (to write to)
545541 output .view (torch .uint8 ),
546- output_group_start_offsets ,
547542 output_stride_per_block ,
543+ num_groups = num_groups ,
548544 BLOCK_ROWS = BLOCK_ROWS ,
549545 BLOCK_COLS = BLOCK_COLS ,
550546 DEBUG = False ,
@@ -560,11 +556,10 @@ def triton_scale_swizzle_2d_K_groups(
560556 scale_rows ,
561557 scale_cols ,
562558 padded_rows ,
563- num_groups ,
564559 orig_offsets , # (num_groups,)
565560 output_scales_ptr ,
566- output_scales_group_offsets , # (num_groups,)
567561 output_stride_per_block ,
562+ num_groups : tl .constexpr ,
568563 BLOCK_ROWS : tl .constexpr ,
569564 BLOCK_COLS : tl .constexpr ,
570565 DEBUG : tl .constexpr = False ,
@@ -578,8 +573,11 @@ def triton_scale_swizzle_2d_K_groups(
578573 )
579574 input_group_end_col = tl .load (orig_offsets + group_pid )
580575
581- # Output scales start row we will begin writing to
582- output_group_start_col = tl .load (output_scales_group_offsets + group_pid )
576+ # Calculate this group's start column after blocked format padding, by doing a prefix sum
577+ # of each previous group's padded size.
578+ output_group_start_col = _blocked_group_start_idx (
579+ group_pid , orig_offsets , num_groups , 4
580+ )
583581
584582 row_offs = tl .arange (0 , BLOCK_ROWS )[:, None ]
585583 col_offs = tl .arange (0 , BLOCK_COLS )[None , :]
@@ -651,6 +649,31 @@ def _dest_indices_for_block(
651649 return dest_indices_flat
652650
653651
@triton.jit
def _blocked_group_start_idx(
    group_pid,
    orig_offsets,
    num_groups: tl.constexpr,
    padding_size: tl.constexpr,
):
    """Return the start index of group ``group_pid`` once every group is padded.

    ``orig_offsets`` is read as ``num_groups`` cumulative end offsets: the size
    of group ``i`` is ``orig_offsets[i] - orig_offsets[i - 1]`` (with an implicit
    0 before the first group). Each size is rounded up to a multiple of
    ``padding_size`` and the rounded sizes of all groups that precede
    ``group_pid`` are summed.

    Args:
        group_pid: index of the group whose start index is wanted.
        orig_offsets: pointer to the per-group end offsets.
        num_groups: compile-time number of entries read from ``orig_offsets``.
        padding_size: compile-time multiple that every group size is rounded up to.

    Returns:
        Sum of the padded sizes of groups ``0 .. group_pid - 1``.

    NOTE(review): ``tl.arange`` is documented to need a power-of-two extent, so
    this presumably only compiles when ``num_groups`` is a power of two -
    confirm against the callers.
    """
    # One index per group; reused for every load and mask below.
    group_ids = tl.arange(0, num_groups)
    # The first group has no predecessor, so its previous end offset is taken as 0.
    has_prev = group_ids > 0

    group_ends = tl.load(orig_offsets + group_ids)
    prev_group_ends = tl.load(
        orig_offsets + group_ids - 1,
        mask=has_prev,
        other=0,
    )

    # Un-padded size of every group, then rounded up to the padding multiple.
    sizes = tl.where(has_prev, group_ends - prev_group_ends, group_ends)
    padded = tl.cdiv(sizes, padding_size) * padding_size

    # Exclusive prefix sum evaluated at group_pid: keep only the earlier groups.
    before_this_group = group_ids < group_pid
    return tl.sum(tl.where(before_this_group, padded, 0))
675+
676+
654677mxfp8_cuda_extension_available = False
655678if is_sm_at_least_100 ():
656679 try :
0 commit comments