@@ -829,6 +829,8 @@ def _(uint8_data):
     import triton.language as tl
     from torch.library import triton_op, wrap_triton
 
+    print("importing triton ops")
+
     @triton.jit
     def _triton_calculate_scale(x, axis):
         # There is no good support for accessing globals from a jit'ed triton
@@ -891,13 +893,13 @@ def _get_mxfp8_dim1_kernel_autotune_configs():
 
     @triton.autotune(
         configs=_get_mxfp8_dim1_kernel_autotune_configs(),
-        key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"],
+        key=["n_cols", "INNER_BLOCK_SIZE"],
     )
     @triton.jit
     def to_mxfp8_dim1_kernel(
         x_ptr,  # pointer to input tensor
         output_col_major_ptr,  # pointer to column-major output tensor (column-normalized)
-        col_scale_ptr,  # pointer to store column-wise maximum absolute values
+        col_scale_ptr,  # pointer to store scales
         n_rows,  # number of rows in the tensor
         n_cols,  # number of columns in the tensor
         ROW_TILE_SIZE: tl.constexpr,
@@ -1038,6 +1040,174 @@ def to_mxfp8_dim1_kernel(
         # TODO(future): mask this store
         tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)
 
+    @triton.autotune(
+        configs=_get_mxfp8_dim1_kernel_autotune_configs(),
+        key=["n_cols", "INNER_BLOCK_SIZE"],
+    )
+    @triton.jit
+    def to_mxfp8_dim0_kernel(
+        x_ptr,  # pointer to input tensor
+        output_ptr,  # pointer to output tensor (row-normalized)
+        row_scale_ptr,  # pointer to store row-wise scales
+        n_rows,  # number of rows in the tensor
+        n_cols,  # number of columns in the tensor
+        ROW_TILE_SIZE: tl.constexpr,
+        COL_TILE_SIZE: tl.constexpr,
+        INNER_BLOCK_SIZE: tl.constexpr,  # should be 32 for MX
+    ):
+        """
+        Quantizes a high precision tensor to mxfp8 rowwise (1x32 scaling granularity).
+
+        This is the counterpart to to_mxfp8_dim1_kernel which does columnwise quantization.
+        Instead of transposing and scaling across columns, this kernel scales across rows.
+        """
+
+        BLOCKS_PER_COL_TILE: tl.constexpr = COL_TILE_SIZE // INNER_BLOCK_SIZE
+
+        # Get program ID
+        pid_row = tl.program_id(0)
+        pid_col = tl.program_id(1)
+
+        # Calculate starting row and column for this tile
+        start_row = pid_row * ROW_TILE_SIZE
+        start_col = pid_col * COL_TILE_SIZE
+
+        # Create offsets for the block
+        row_offsets = tl.arange(0, ROW_TILE_SIZE)
+        col_offsets = tl.arange(0, COL_TILE_SIZE)
+
+        # Compute global row/col positions
+        rows = start_row + row_offsets[:, None]
+        cols = start_col + col_offsets[None, :]
+
+        # Create masks for out-of-bounds accesses
+        row_mask = rows < n_rows
+        col_mask = cols < n_cols
+        mask = row_mask & col_mask
+
+        # Compute memory offsets for row-major layout (rows, cols)
+        row_major_offsets = (rows * n_cols + cols).to(tl.int32)
+
+        # Load the entire block in a single operation
+        # shape: (ROW_TILE_SIZE, COL_TILE_SIZE)
+        x_block = tl.load(x_ptr + row_major_offsets, mask=mask)
+
+        # Reshape to inner tile size for rowwise scaling
+        # shape: (ROW_TILE_SIZE, COL_TILE_SIZE) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
+        x_block_r = x_block.reshape(
+            ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE
+        )
+
+        # Calculate the absolute values of elements in the block
+        x_block_abs_r = tl.abs(x_block_r)
+
+        # Find the maximum absolute value for each row (across columns)
+        # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
+        row_scale_r, row_scale_e8m0_r = _triton_calculate_scale(x_block_abs_r, axis=1)
+
+        # Divide each row by scale
+        # Broadcasting row_scale to match x_block's shape
+        # x_block_r shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
+        # row_scale shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, 1)
+        row_normalized_r = x_block_r / row_scale_r[:, None]
+
+        # Reshape back to original tile size
+        row_normalized = tl.reshape(row_normalized_r, ROW_TILE_SIZE, COL_TILE_SIZE)
+
+        # Quantize to float8
+        row_normalized = row_normalized.to(tl.float8e4nv)
+
+        # Store the row-normalized result in row-major format
+        tl.store(output_ptr + row_major_offsets, row_normalized, mask=mask)
+
+        # For rowwise quantization, scale tensor has shape (n_rows, n_cols // INNER_BLOCK_SIZE)
+        # Calculate base offset for this tile's scales
+        scales_per_row = n_cols // INNER_BLOCK_SIZE
+
+        # Create row and column indices for scale storage
+        scale_row_indices = tl.arange(0, ROW_TILE_SIZE)[:, None] + (
+            pid_row * ROW_TILE_SIZE
+        )
+        scale_col_indices = tl.arange(0, BLOCKS_PER_COL_TILE)[None, :] + (
+            pid_col * BLOCKS_PER_COL_TILE
+        )
+
+        # Calculate linear indices into scale tensor
+        scale_offsets = scale_row_indices * scales_per_row + scale_col_indices
+
+        # Create masks for valid scale indices
+        scale_row_mask = scale_row_indices < n_rows
+        scale_col_mask = scale_col_indices < scales_per_row
+        scale_mask = scale_row_mask & scale_col_mask
+
+        # Reshape scale values to match the flattened layout
+        row_scale_e8m0_2d = row_scale_e8m0_r.reshape(ROW_TILE_SIZE, BLOCKS_PER_COL_TILE)
+
+        # Store the scales with proper masking
+        tl.store(row_scale_ptr + scale_offsets, row_scale_e8m0_2d, mask=scale_mask)
+
+    @triton_op("torchao::triton_to_mxfp8_dim0", mutates_args={})
+    def triton_to_mxfp8_dim0(
+        x: torch.Tensor, inner_block_size: int = 32
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Input:
+        * `x` - input tensor, in row major memory layout
+        * `inner_block_size` - size of tiles to scale across, default is 32 for MX recipes
+
+        Output:
+        * `output`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 (rowwise)
+        * `row_scale`: the `e8m0` scales used to cast `x` to mxfp8 across dim0
+        """
+        assert x.is_contiguous(), "`x` must be contiguous"
+        assert inner_block_size <= 32
+
+        # Reshape tensor to 2d if necessary and get shape
+        x_orig_shape = x.shape
+        x = x.reshape(-1, x.shape[-1])
+        n_rows, n_cols = x.shape
+
+        assert n_cols % inner_block_size == 0, (
+            "columns must be divisible by inner block size"
+        )
+
+        # Create output tensor
+        output = torch.empty(
+            (n_rows, n_cols), dtype=torch.float8_e4m3fn, device=x.device
+        )
+
+        # Create scale tensor for rowwise scaling
+        row_scale = torch.empty(
+            (n_rows, n_cols // inner_block_size),
+            dtype=torch.uint8,
+            device=x.device,
+        )
+
+        # Calculate grid dimensions based on tile size
+        grid = lambda META: (
+            triton.cdiv(n_rows, META["ROW_TILE_SIZE"]),
+            triton.cdiv(n_cols, META["COL_TILE_SIZE"]),
+        )
+
+        # Launch the kernel
+        wrap_triton(to_mxfp8_dim0_kernel)[grid](
+            x_ptr=x,
+            output_ptr=output,
+            row_scale_ptr=row_scale,
+            n_rows=n_rows,
+            n_cols=n_cols,
+            INNER_BLOCK_SIZE=inner_block_size,
+        )
+
+        # Reshape output back to original shape
+        output = output.reshape(x_orig_shape)
+        row_scale = row_scale.reshape(*x_orig_shape[:-1], row_scale.shape[-1])
+
+        return (
+            output,
+            row_scale.view(torch.float8_e8m0fnu),
+        )
+
     @triton_op("torchao::triton_to_mxfp8_dim1", mutates_args={})
     def triton_to_mxfp8_dim1(
         x: torch.Tensor, inner_block_size: int = 32
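For readers following the new rowwise path, the sketch below (not part of this diff) is a minimal pure-PyTorch reference of what to_mxfp8_dim0_kernel computes per 1x32 block: take the absolute maximum of each block, derive a power-of-two e8m0-style scale, divide, and cast to float8_e4m3fn. The exact e8m0 rounding lives in _triton_calculate_scale; the floor(log2(amax)) - 8 exponent used here (8 being the e4m3 element emax from the MX spec) is an assumption for illustration only.

import torch

def mxfp8_dim0_reference(x: torch.Tensor, block_size: int = 32):
    # (n_rows, n_cols) -> (n_rows, n_cols // block_size, block_size)
    n_rows, n_cols = x.shape
    x_r = x.reshape(n_rows, n_cols // block_size, block_size).to(torch.float32)
    # absolute maximum per 1x32 block
    amax = x_r.abs().amax(dim=-1, keepdim=True)
    # assumed MX-style shared exponent: floor(log2(amax)) - emax(e4m3)
    exponent = torch.floor(torch.log2(amax.clamp(min=2**-127))) - 8
    scale = torch.exp2(exponent)
    # divide by the power-of-two scale and saturate to the e4m3 range
    out = (x_r / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return out.reshape(n_rows, n_cols), scale.reshape(n_rows, n_cols // block_size)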
@@ -1459,6 +1629,12 @@ def _(scale_tensor):
         return scale_tensor.new_empty((padded_rows, padded_cols))
 else:
 
+    def triton_to_mxfp8_dim0(
+        x: torch.Tensor,
+        inner_block_size=32,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise AssertionError("needs torch version 2.8+ and triton")
+
     def triton_to_mxfp8_dim1(
         x, inner_block_size=32
     ) -> Tuple[torch.Tensor, torch.Tensor]:
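A hypothetical usage sketch of the new op, not shown in this diff: it assumes torch 2.8+ with triton and a CUDA device (so the real kernel rather than the stub above is used), and the torchao.prototype.mx_formats.kernels import path is an assumption.

import torch
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0  # assumed module path

x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
data, row_scale = triton_to_mxfp8_dim0(x, inner_block_size=32)
print(data.shape, data.dtype)            # torch.Size([128, 256]) torch.float8_e4m3fn
print(row_scale.shape, row_scale.dtype)  # torch.Size([128, 8]) torch.float8_e8m0fnu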