@@ -223,7 +223,6 @@ def compute_blocked_scale_offsets_for_K_groups(
223223def triton_mx_block_rearrange_2d_M_groups (
224224 scales_tensor : torch .Tensor ,
225225 input_group_end_offsets : torch .Tensor ,
226- output_group_start_offsets : torch .Tensor ,
227226) -> torch .Tensor :
228227 """
229228 Rearranges an E8M0 tensor scale to block-scaled swizzle format,
@@ -275,35 +274,58 @@ def triton_mx_block_rearrange_2d_M_groups(
275274 scales_tensor .stride (1 ),
276275 rows ,
277276 cols ,
278- num_groups ,
279277 # Original offsets (to read from)
280278 input_group_end_offsets ,
281279 # Output scales tensor and group offsets after padding (to write to)
282280 output .view (torch .uint8 ),
283281 output .stride (0 ),
284- output_group_start_offsets ,
285282 output_stride_per_block ,
286283 output_stride_per_row_of_blocks ,
284+ num_groups = num_groups ,
287285 BLOCK_ROWS = BLOCK_ROWS ,
288286 BLOCK_COLS = BLOCK_COLS ,
289287 )
290288 return output
291289
292290
@triton.jit
def _blocked_group_start_idx(
    group_pid,
    orig_offsets,
    num_groups: tl.constexpr,
    padding_size: tl.constexpr,
):
    """
    Compute where group `group_pid` starts once every group is padded.

    Each group's size is recovered from consecutive entries of `orig_offsets`,
    rounded up to a multiple of `padding_size`, and the padded sizes of all
    groups with an index lower than `group_pid` are summed.

    Args:
        group_pid: index of the group whose start index is wanted.
        orig_offsets: pointer to `num_groups` per-group end offsets
            (callers pass their input group end offsets here).
        num_groups: compile-time number of groups.
        padding_size: compile-time multiple each group size is rounded up to.

    Returns:
        Sum of the padded sizes of groups `0 .. group_pid - 1`
        (0 when `group_pid` is 0).
    """
    # NOTE(review): tl.arange needs a power-of-two extent, so this presumably
    # restricts num_groups to powers of two - confirm against callers.
    group_ids = tl.arange(0, num_groups)
    has_predecessor = group_ids > 0

    # End offset of every group, and of the group just before it. Lane 0 has
    # no predecessor, so its load is masked off and reads as 0.
    end_offsets = tl.load(orig_offsets + group_ids)
    prev_end_offsets = tl.load(
        orig_offsets + group_ids - 1,
        mask=has_predecessor,
        other=0,
    )

    # Size of a group = its end offset minus the previous group's end offset;
    # the first group's size is just its end offset.
    sizes = tl.where(has_predecessor, end_offsets - prev_end_offsets, end_offsets)

    # Round each size up to the next multiple of padding_size.
    sizes_padded = tl.cdiv(sizes, padding_size) * padding_size

    # Exclusive prefix sum: only groups strictly before group_pid contribute.
    comes_before = group_ids < group_pid
    start_idx = tl.sum(tl.where(comes_before, sizes_padded, 0))
    return start_idx
314+
315+
293316@triton .jit
294317def triton_scale_swizzle_M_groups (
295318 scales_ptr , # (M, K//block_size)
296319 scales_stride_dim0 ,
297320 scales_stride_dim1 ,
298321 scale_rows ,
299322 scale_cols ,
300- num_groups ,
301323 orig_offsets , # (num_groups,)
302324 output_scales_ptr ,
303325 output_scales_stride_dim0 ,
304- output_scales_group_offsets , # (num_groups,)
305326 output_stride_per_block ,
306327 output_stride_per_row_of_blocks ,
328+ num_groups : tl .constexpr ,
307329 BLOCK_ROWS : tl .constexpr ,
308330 BLOCK_COLS : tl .constexpr ,
309331):
@@ -316,10 +338,13 @@ def triton_scale_swizzle_M_groups(
316338 input_group_end_row = tl .load (
317339 orig_offsets + group_pid , mask = group_pid < num_groups , other = 0
318340 )
319- # Output scales start row we will begin writing to
320- output_group_start_row = tl .load (
321- output_scales_group_offsets + group_pid , mask = group_pid < num_groups , other = 0
341+
342+ # Calculate this group's start row after blocked format padding, by doing a prefix sum
343+ # of each previous group's padded size.
344+ output_group_start_row = _blocked_group_start_idx (
345+ group_pid , orig_offsets , num_groups , 128
322346 )
347+
323348 # Calculate destination indices for each row and col in block swizzled layout.
324349 # We can reuse this swizzle transformation on each block of data we read.
325350 row_offs = tl .arange (0 , BLOCK_ROWS )[:, None ]
@@ -489,7 +514,6 @@ def triton_scale_swizzle_per_group_3d(
489514def triton_mx_block_rearrange_2d_K_groups (
490515 scales_tensor : torch .Tensor ,
491516 input_group_end_offsets : torch .Tensor ,
492- output_group_start_offsets : torch .Tensor ,
493517) -> torch .Tensor :
494518 """
495519 Rearranges an E8M0 tensor scale to block-scaled swizzle format on a per group basis,
@@ -538,13 +562,10 @@ def triton_mx_block_rearrange_2d_K_groups(
538562 rows ,
539563 cols ,
540564 padded_rows ,
541- num_groups ,
542- # Original offsets (to read from)
543565 input_group_end_offsets ,
544- # Output scales tensor and group offsets after padding (to write to)
545566 output .view (torch .uint8 ),
546- output_group_start_offsets ,
547567 output_stride_per_block ,
568+ num_groups = num_groups ,
548569 BLOCK_ROWS = BLOCK_ROWS ,
549570 BLOCK_COLS = BLOCK_COLS ,
550571 DEBUG = False ,
@@ -560,11 +581,10 @@ def triton_scale_swizzle_2d_K_groups(
560581 scale_rows ,
561582 scale_cols ,
562583 padded_rows ,
563- num_groups ,
564584 orig_offsets , # (num_groups,)
565585 output_scales_ptr ,
566- output_scales_group_offsets , # (num_groups,)
567586 output_stride_per_block ,
587+ num_groups : tl .constexpr ,
568588 BLOCK_ROWS : tl .constexpr ,
569589 BLOCK_COLS : tl .constexpr ,
570590 DEBUG : tl .constexpr = False ,
@@ -578,8 +598,11 @@ def triton_scale_swizzle_2d_K_groups(
578598 )
579599 input_group_end_col = tl .load (orig_offsets + group_pid )
580600
581- # Output scales start row we will begin writing to
582- output_group_start_col = tl .load (output_scales_group_offsets + group_pid )
601+ # Calculate this group's start column after blocked format padding, by doing a prefix sum
602+ # of each previous group's padded size.
603+ output_group_start_col = _blocked_group_start_idx (
604+ group_pid , orig_offsets , num_groups , 4
605+ )
583606
584607 row_offs = tl .arange (0 , BLOCK_ROWS )[:, None ]
585608 col_offs = tl .arange (0 , BLOCK_COLS )[None , :]
0 commit comments