@@ -1040,7 +1040,7 @@ def to_mxfp8_dim1_kernel(
10401040
10411041 @triton .autotune (
10421042 configs = _get_mxfp8_dim1_kernel_autotune_configs (),
1043- key = ["n_rows" , " n_cols" , "INNER_BLOCK_SIZE" ],
1043+ key = ["n_cols" , "INNER_BLOCK_SIZE" ],
10441044 )
10451045 @triton .jit
10461046 def to_mxfp8_dim0_kernel (
@@ -1118,33 +1118,31 @@ def to_mxfp8_dim0_kernel(
11181118 # Store the row-normalized result in row-major format
11191119 tl .store (output_ptr + row_major_offsets , row_normalized , mask = mask )
11201120
1121- # reshape row_scale_e8m0_r for proper storage
1122- # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
1123- row_scale_e8m0 = row_scale_e8m0_r . reshape ( ROW_TILE_SIZE * BLOCKS_PER_COL_TILE )
1121+ # For rowwise quantization, scale tensor has shape (n_rows, n_cols // INNER_BLOCK_SIZE)
1122+ # Calculate base offset for this tile's scales
1123+ scales_per_row = n_cols // INNER_BLOCK_SIZE
11241124
1125- row_scale_start_offsets = (
1126- (pid_row * ROW_TILE_SIZE * (n_cols // COL_TILE_SIZE ))
1127- * BLOCKS_PER_COL_TILE # number of blocks seen so far
1128- + pid_col * BLOCKS_PER_COL_TILE # increment BLOCKS_PER_COL_TILE
1125+ # Create row and column indices for scale storage
1126+ scale_row_indices = tl .arange (0 , ROW_TILE_SIZE )[:, None ] + (
1127+ pid_row * ROW_TILE_SIZE
1128+ )
1129+ scale_col_indices = tl .arange (0 , BLOCKS_PER_COL_TILE )[None , :] + (
1130+ pid_col * BLOCKS_PER_COL_TILE
11291131 )
11301132
1131- row_scale_start_ptr = row_scale_ptr + row_scale_start_offsets
1132-
1133- # calculate row_scale_indices
1134- row_scale_indices = tl .arange (0 , ROW_TILE_SIZE * BLOCKS_PER_COL_TILE )
1133+ # Calculate linear indices into scale tensor
1134+ scale_offsets = scale_row_indices * scales_per_row + scale_col_indices
11351135
1136- # How many values are in all the other rows for this col_pid, need to jump
1137- # over them for every BLOCKS_PER_COL_TILE values
1138- jump_vals_per_row = (n_cols - COL_TILE_SIZE ) // INNER_BLOCK_SIZE
1136+ # Create masks for valid scale indices
1137+ scale_row_mask = scale_row_indices < n_rows
1138+ scale_col_mask = scale_col_indices < scales_per_row
1139+ scale_mask = scale_row_mask & scale_col_mask
11391140
1140- # example transformation (specifics depend on tile sizes):
1141- # [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13]
1142- row_scale_indices = row_scale_indices + (
1143- (row_scale_indices // BLOCKS_PER_COL_TILE ) * jump_vals_per_row
1144- )
1141+ # Reshape scale values to 2D (ROW_TILE_SIZE, BLOCKS_PER_COL_TILE) to match scale_offsets and scale_mask
1142+ row_scale_e8m0_2d = row_scale_e8m0_r .reshape (ROW_TILE_SIZE , BLOCKS_PER_COL_TILE )
11451143
1146- # Store the scales
1147- tl .store (row_scale_start_ptr + row_scale_indices , row_scale_e8m0 )
1144+ # Store the scales with proper masking
1145+ tl .store (row_scale_ptr + scale_offsets , row_scale_e8m0_2d , mask = scale_mask )
11481146
11491147 @triton_op ("torchao::triton_to_mxfp8_dim0" , mutates_args = {})
11501148 def triton_to_mxfp8_dim0 (
@@ -1167,14 +1165,9 @@ def triton_to_mxfp8_dim0(
11671165 x = x .reshape (- 1 , x .shape [- 1 ])
11681166 n_rows , n_cols = x .shape
11691167
1170- # Masking of loads and stores is not well tested yet, so for now enforce
1171- # shapes which do not need masking. Note that this condition depends on max values of
1172- # ROW_TILE_SIZE and COL_TILE_SIZE, which are autotuned above.
1173- # TODO(future): implement and test masking and remove this restriction
1174- max_row_tile_size = 128
1175- max_col_tile_size = 128
1176- assert n_rows % max_row_tile_size == 0 , "unsupported"
1177- assert n_cols % max_col_tile_size == 0 , "unsupported"
1168+ assert n_cols % inner_block_size == 0 , (
1169+ "columns must be divisible by inner block size"
1170+ )
11781171
11791172 # Create output tensors
11801173 output = torch .empty (
0 commit comments