@@ -891,13 +891,13 @@ def _get_mxfp8_dim1_kernel_autotune_configs():
 
     @triton.autotune(
         configs=_get_mxfp8_dim1_kernel_autotune_configs(),
-        key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"],
+        key=["n_cols", "INNER_BLOCK_SIZE"],
     )
     @triton.jit
     def to_mxfp8_dim1_kernel(
         x_ptr,  # pointer to input tensor
         output_col_major_ptr,  # pointer to column-major output tensor (column-normalized)
-        col_scale_ptr,  # pointer to store column-wise maximum absolute values
+        col_scale_ptr,  # pointer to store scales
         n_rows,  # number of rows in the tensor
         n_cols,  # number of columns in the tensor
         ROW_TILE_SIZE: tl.constexpr,
@@ -1038,6 +1038,174 @@ def to_mxfp8_dim1_kernel(
         # TODO(future): mask this store
         tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)
 
+    @triton.autotune(
+        configs=_get_mxfp8_dim1_kernel_autotune_configs(),
+        key=["n_cols", "INNER_BLOCK_SIZE"],
+    )
+    @triton.jit
+    def to_mxfp8_dim0_kernel(
+        x_ptr,  # pointer to input tensor
+        output_ptr,  # pointer to output tensor (row-normalized)
+        row_scale_ptr,  # pointer to store row-wise scales
+        n_rows,  # number of rows in the tensor
+        n_cols,  # number of columns in the tensor
+        ROW_TILE_SIZE: tl.constexpr,
+        COL_TILE_SIZE: tl.constexpr,
+        INNER_BLOCK_SIZE: tl.constexpr,  # should be 32 for MX
+    ):
1056+ """
1057+ Quantizes a high precision tensor to mxfp8 rowwise (1x32 scaling granularity).
1058+
1059+ This is the counterpart to to_mxfp8_dim1_kernel which does columnwise quantization.
1060+ Instead of transposing and scaling across columns, this kernel scales across rows.
1061+ """
+
+        BLOCKS_PER_COL_TILE: tl.constexpr = COL_TILE_SIZE // INNER_BLOCK_SIZE
+
+        # Get program ID
+        pid_row = tl.program_id(0)
+        pid_col = tl.program_id(1)
+
+        # Calculate starting row and column for this tile
+        start_row = pid_row * ROW_TILE_SIZE
+        start_col = pid_col * COL_TILE_SIZE
+
+        # Create offsets for the block
+        row_offsets = tl.arange(0, ROW_TILE_SIZE)
+        col_offsets = tl.arange(0, COL_TILE_SIZE)
+
+        # Compute global row/col positions
+        rows = start_row + row_offsets[:, None]
+        cols = start_col + col_offsets[None, :]
+
+        # Create masks for out-of-bounds accesses
+        row_mask = rows < n_rows
+        col_mask = cols < n_cols
+        mask = row_mask & col_mask
+
+        # Compute memory offsets for row-major layout (rows, cols)
+        row_major_offsets = (rows * n_cols + cols).to(tl.int32)
+
+        # Load the entire block in a single operation
+        # shape: (ROW_TILE_SIZE, COL_TILE_SIZE)
+        x_block = tl.load(x_ptr + row_major_offsets, mask=mask)
+
+        # Reshape to inner tile size for rowwise scaling
+        # shape: (ROW_TILE_SIZE, COL_TILE_SIZE) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
+        x_block_r = x_block.reshape(
+            ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE
+        )
+
+        # Calculate the absolute values of elements in the block
+        x_block_abs_r = tl.abs(x_block_r)
+
+        # Find the maximum absolute value for each row (across columns)
+        # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
+        row_scale_r, row_scale_e8m0_r = _triton_calculate_scale(x_block_abs_r, axis=1)
+
+        # Divide each row by scale
+        # Broadcasting row_scale to match x_block's shape
+        # x_block_r shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
+        # row_scale shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, 1)
+        row_normalized_r = x_block_r / row_scale_r[:, None]
+
+        # Reshape back to original tile size
+        row_normalized = tl.reshape(row_normalized_r, ROW_TILE_SIZE, COL_TILE_SIZE)
+
+        # Quantize to float8
+        row_normalized = row_normalized.to(tl.float8e4nv)
+
+        # Store the row-normalized result in row-major format
+        tl.store(output_ptr + row_major_offsets, row_normalized, mask=mask)
+
+        # For rowwise quantization, scale tensor has shape (n_rows, n_cols // INNER_BLOCK_SIZE)
+        # Calculate base offset for this tile's scales
+        scales_per_row = n_cols // INNER_BLOCK_SIZE
+
+        # Create row and column indices for scale storage
+        scale_row_indices = tl.arange(0, ROW_TILE_SIZE)[:, None] + (
+            pid_row * ROW_TILE_SIZE
+        )
+        scale_col_indices = tl.arange(0, BLOCKS_PER_COL_TILE)[None, :] + (
+            pid_col * BLOCKS_PER_COL_TILE
+        )
+
+        # Calculate linear indices into scale tensor
+        scale_offsets = scale_row_indices * scales_per_row + scale_col_indices
+
+        # Create masks for valid scale indices
+        scale_row_mask = scale_row_indices < n_rows
+        scale_col_mask = scale_col_indices < scales_per_row
+        scale_mask = scale_row_mask & scale_col_mask
+
+        # Reshape scale values to 2d to match the scale tensor layout
+        row_scale_e8m0_2d = row_scale_e8m0_r.reshape(ROW_TILE_SIZE, BLOCKS_PER_COL_TILE)
+
+        # Store the scales with proper masking
+        tl.store(row_scale_ptr + scale_offsets, row_scale_e8m0_2d, mask=scale_mask)
+
+    @triton_op("torchao::triton_to_mxfp8_dim0", mutates_args={})
+    def triton_to_mxfp8_dim0(
+        x: torch.Tensor, inner_block_size: int = 32
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Input:
+        * `x` - input tensor, in row major memory layout
+        * `inner_block_size` - size of tiles to scale across, default is 32 for MX recipes
+
+        Output:
+        * `output`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 (rowwise)
+        * `row_scale`: the `e8m0` scale values used to cast `x` to mxfp8 across dim0
+        """
+        assert x.is_contiguous(), "`x` must be contiguous"
+        assert inner_block_size <= 32
+
+        # Reshape tensor to 2d if necessary and get shape
+        x_orig_shape = x.shape
+        x = x.reshape(-1, x.shape[-1])
+        n_rows, n_cols = x.shape
+
+        assert n_cols % inner_block_size == 0, (
+            "columns must be divisible by inner block size"
+        )
+
+        # Create output tensors
+        output = torch.empty(
+            (n_rows, n_cols), dtype=torch.float8_e4m3fn, device=x.device
+        )
+
+        # Create scale tensors for rowwise scaling
+        row_scale = torch.empty(
+            (n_rows, n_cols // inner_block_size),
+            dtype=torch.uint8,
+            device=x.device,
+        )
+
+        # Calculate grid dimensions based on tile size
+        grid = lambda META: (
+            triton.cdiv(n_rows, META["ROW_TILE_SIZE"]),
+            triton.cdiv(n_cols, META["COL_TILE_SIZE"]),
+        )
+
+        # Launch the kernel
+        wrap_triton(to_mxfp8_dim0_kernel)[grid](
+            x_ptr=x,
+            output_ptr=output,
+            row_scale_ptr=row_scale,
+            n_rows=n_rows,
+            n_cols=n_cols,
+            INNER_BLOCK_SIZE=inner_block_size,
+        )
+
+        # Reshape output back to original shape
+        output = output.reshape(x_orig_shape)
+        row_scale = row_scale.reshape(*x_orig_shape[:-1], row_scale.shape[-1])
+
+        return (
+            output,
+            row_scale.view(torch.float8_e8m0fnu),
+        )
+
     @triton_op("torchao::triton_to_mxfp8_dim1", mutates_args={})
     def triton_to_mxfp8_dim1(
         x: torch.Tensor, inner_block_size: int = 32
@@ -1467,6 +1635,12 @@ def _(scale_tensor):
         return scale_tensor.new_empty((padded_rows, padded_cols))
 else:
 
+    def triton_to_mxfp8_dim0(
+        x: torch.Tensor,
+        inner_block_size=32,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise AssertionError("needs torch version 2.8+ and triton")
+
     def triton_to_mxfp8_dim1(
         x, inner_block_size=32
     ) -> Tuple[torch.Tensor, torch.Tensor]:
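
For reference, below is a minimal usage sketch of the new rowwise op; it is not part of the diff above. It assumes this change lands in torchao/prototype/mx_formats/kernels.py and that a CUDA device, torch 2.8+, and triton are available; the expected shapes and dtypes follow from the wrapper's docstring and its output allocation.

import torch

# Assumed import path for the new op (mirrors triton_to_mxfp8_dim1).
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

# Input must be contiguous and its last dim divisible by inner_block_size.
x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")

# Rowwise (dim0) mxfp8 quantization with 1x32 scaling granularity.
data, scale = triton_to_mxfp8_dim0(x, inner_block_size=32)

# Quantized values keep the input shape; one e8m0 scale is produced per
# contiguous group of 32 elements along the last dimension.
assert data.shape == (128, 256) and data.dtype == torch.float8_e4m3fn
assert scale.shape == (128, 8) and scale.dtype == torch.float8_e8m0fnu

Because the autotuner supplies ROW_TILE_SIZE and COL_TILE_SIZE, the launch grid in the wrapper is simply cdiv(n_rows, ROW_TILE_SIZE) by cdiv(n_cols, COL_TILE_SIZE) tiles over the 2d view of the input.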