@@ -85,7 +85,12 @@ def triton_fp8_per_group_rowwise_scales(
8585 n_groups = offsets .numel ()
8686
8787 # allocate on-device buffers for output and scales
88- output_buffer = torch .empty ((m , k ), dtype = output_dtype , device = hp_tensor .device )
88+ output_buffer = torch .empty (
89+ (m , k ), dtype = output_dtype , device = hp_tensor .device
90+ ).as_strided (
91+ (m , k ), # shape
92+ (1 , m ), # stride
93+ )
8994 scales_buffer = torch .empty (
9095 (m * n_groups ), dtype = torch .float32 , device = hp_tensor .device
9196 )
@@ -114,7 +119,7 @@ def triton_fp8_per_group_rowwise_scales(
114119 round_scales_to_power_of_2 ,
115120 EPS = EPS ,
116121 )
117- return output_buffer , scales_buffer
122+ return output_buffer . transpose ( - 2 , - 1 ). contiguous (). transpose ( - 2 , - 1 ) , scales_buffer
118123
119124
120125@triton_fp8_per_group_rowwise_scales .register_fake
@@ -336,8 +341,8 @@ def _triton_fp8_per_group_colwise_scales_kernel(
336341 offsets_ptr ,
337342 out_ptr ,
338343 scales_ptr ,
344+ M : int ,
339345 K : int ,
340- N : int ,
341346 stride_input_row : int ,
342347 stride_input_col : int ,
343348 stride_output_row : int ,
@@ -372,7 +377,7 @@ def _triton_fp8_per_group_colwise_scales_kernel(
372377 + block_col_offs [None , :] * stride_input_col
373378 )
374379 block_mask = (block_row_offs [:, None ] < group_row_end_idx ) & (
375- block_col_offs [None , :] < N
380+ block_col_offs [None , :] < K
376381 )
377382 data = tl .load (input_ptr + block_offs , mask = block_mask , other = 0.0 ).to (
378383 input_dtype
@@ -394,8 +399,8 @@ def _triton_fp8_per_group_colwise_scales_kernel(
394399 # store colwise scales for each group in contiguous memory:
395400 # [group0_col0, group_0_col1, ..., group2_col0, group2_col1]
396401 # note: input tensor is in col-major memory layout.
397- scales_offs = block_col_offs + (N * offset_idx )
398- scales_mask = tl .arange (0 , BLOCK_SIZE ) < N
402+ scales_offs = block_col_offs + (K * offset_idx )
403+ scales_mask = tl .arange (0 , BLOCK_SIZE ) < K
399404 tl .store (scales_ptr + scales_offs , scales , mask = scales_mask )
400405
401406 # perform float8 conversion for this group
@@ -406,7 +411,7 @@ def _triton_fp8_per_group_colwise_scales_kernel(
406411 + block_col_offs [None , :] * stride_input_col
407412 )
408413 block_mask = (block_row_offs [:, None ] < group_row_end_idx ) & (
409- block_col_offs [None , :] < N
414+ block_col_offs [None , :] < K
410415 )
411416 data = tl .load (input_ptr + block_offs , mask = block_mask , other = 0.0 ).to (
412417 input_dtype
0 commit comments