@@ -829,6 +829,8 @@ def _(uint8_data):
     import triton.language as tl
     from torch.library import triton_op, wrap_triton
 
+    print("importing triton ops")
+
     @triton.jit
     def _triton_calculate_scale(x, axis):
         # There is no good support for accessing globals from a jit'ed triton
@@ -891,13 +893,13 @@ def _get_mxfp8_dim1_kernel_autotune_configs():
 
     @triton.autotune(
         configs=_get_mxfp8_dim1_kernel_autotune_configs(),
-        key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"],
+        key=["n_cols", "INNER_BLOCK_SIZE"],
     )
     @triton.jit
     def to_mxfp8_dim1_kernel(
         x_ptr,  # pointer to input tensor
         output_col_major_ptr,  # pointer to column-major output tensor (column-normalized)
-        col_scale_ptr,  # pointer to store column-wise maximum absolute values
+        col_scale_ptr,  # pointer to store scales
         n_rows,  # number of rows in the tensor
         n_cols,  # number of columns in the tensor
         ROW_TILE_SIZE: tl.constexpr,
@@ -1038,6 +1040,175 @@ def to_mxfp8_dim1_kernel(
         # TODO(future): mask this store
         tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)
 
+    @triton.autotune(
+        configs=_get_mxfp8_dim1_kernel_autotune_configs(),
+        key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"],
+    )
+    @triton.jit
+    def to_mxfp8_dim0_kernel(
+        x_ptr,  # pointer to input tensor
+        output_ptr,  # pointer to row-major output tensor (row-normalized)
+        row_scale_ptr,  # pointer to store scales
+        n_rows,  # number of rows in the tensor
+        n_cols,  # number of columns in the tensor
+        ROW_TILE_SIZE: tl.constexpr,
+        COL_TILE_SIZE: tl.constexpr,
+        INNER_BLOCK_SIZE: tl.constexpr,  # should be 32 for MX
+    ):
+        """
+        Quantizes a high precision tensor to mxfp8 rowwise (1x32 scaling granularity).
+
+        This is the counterpart to to_mxfp8_dim1_kernel, which does columnwise quantization.
+        Instead of transposing and scaling across columns, this kernel scales across rows.
+        """
+
+        BLOCKS_PER_COL_TILE: tl.constexpr = COL_TILE_SIZE // INNER_BLOCK_SIZE
+
+        # Get the program IDs for this tile
+        pid_row = tl.program_id(0)
+        pid_col = tl.program_id(1)
+
+        # Calculate the starting row and column for this tile
+        start_row = pid_row * ROW_TILE_SIZE
+        start_col = pid_col * COL_TILE_SIZE
+
+        # Create offsets for the block
+        row_offsets = tl.arange(0, ROW_TILE_SIZE)
+        col_offsets = tl.arange(0, COL_TILE_SIZE)
+
+        # Compute global row/col positions
+        rows = start_row + row_offsets[:, None]
+        cols = start_col + col_offsets[None, :]
+
+        # Create masks for out-of-bounds accesses
+        row_mask = rows < n_rows
+        col_mask = cols < n_cols
+        mask = row_mask & col_mask
+
+        # Compute memory offsets for row-major layout (rows, cols)
+        row_major_offsets = (rows * n_cols + cols).to(tl.int32)
+
+        # Load the entire block in a single operation
+        # shape: (ROW_TILE_SIZE, COL_TILE_SIZE)
+        x_block = tl.load(x_ptr + row_major_offsets, mask=mask)
+
+        # Reshape to the inner tile size for rowwise scaling
+        # shape: (ROW_TILE_SIZE, COL_TILE_SIZE) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
+        x_block_r = x_block.reshape(
+            ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE
+        )
+
+        # Calculate the absolute values of elements in the block
+        x_block_abs_r = tl.abs(x_block_r)
+
+        # Find the maximum absolute value for each row (across columns)
+        # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
+        row_scale_r, row_scale_e8m0_r = _triton_calculate_scale(x_block_abs_r, axis=1)
+
+        # Divide each row by its scale
+        # Broadcasting row_scale_r to match x_block_r's shape
+        # x_block_r shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
+        # row_scale_r shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, 1)
+        row_normalized_r = x_block_r / row_scale_r[:, None]
+
+        # Reshape back to the original tile size
+        row_normalized = tl.reshape(row_normalized_r, ROW_TILE_SIZE, COL_TILE_SIZE)
+
+        # Quantize to float8
+        row_normalized = row_normalized.to(tl.float8e4nv)
+
+        # Store the row-normalized result in row-major format
+        tl.store(output_ptr + row_major_offsets, row_normalized, mask=mask)
+
+        # Reshape row_scale_e8m0_r for proper storage
+        # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
+        row_scale_e8m0 = row_scale_e8m0_r.reshape(ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)
+
+        row_scale_start_offsets = (
+            (pid_row * ROW_TILE_SIZE * (n_cols // COL_TILE_SIZE))
+            * BLOCKS_PER_COL_TILE  # scale blocks in all rows above this row tile
+            + pid_col * BLOCKS_PER_COL_TILE  # scale blocks to the left in this row
+        )
+
+        row_scale_start_ptr = row_scale_ptr + row_scale_start_offsets
+
+        # Calculate row_scale_indices
+        row_scale_indices = tl.arange(0, ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)
+
+        # How many scale values in each row belong to the other column tiles for this
+        # pid_col; we need to jump over them after every BLOCKS_PER_COL_TILE values
+        jump_vals_per_row = (n_cols - COL_TILE_SIZE) // INNER_BLOCK_SIZE
+
+        # example transformation (specifics depend on tile sizes):
+        # [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13]
+        row_scale_indices = row_scale_indices + (
+            (row_scale_indices // BLOCKS_PER_COL_TILE) * jump_vals_per_row
+        )
+
+        # Store the scales
+        tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0)
+
+    @triton_op("torchao::triton_to_mxfp8_dim0", mutates_args={})
+    def triton_to_mxfp8_dim0(
+        x: torch.Tensor, inner_block_size: int = 32
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Input:
+        * `x` - input tensor, in row major memory layout
+        * `inner_block_size` - size of tiles to scale across, default is 32 for MX recipes
+
+        Output:
+        * `output`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 (rowwise)
+        * `row_scale`: the `e8m0` scale values used to cast `x` to mxfp8 across dim0
+        """
+        assert x.is_contiguous(), "`x` must be contiguous"
+        assert inner_block_size <= 32
+
+        # Get the tensor shape
+        n_rows, n_cols = x.shape
+
+        # Masking of loads and stores is not well tested yet, so for now enforce
+        # shapes which do not need masking. Note that this condition depends on the
+        # max values of ROW_TILE_SIZE and COL_TILE_SIZE, which are autotuned above.
+        # TODO(future): implement and test masking and remove this restriction
+        max_row_tile_size = 128
+        max_col_tile_size = 128
+        assert n_rows % max_row_tile_size == 0, "unsupported"
+        assert n_cols % max_col_tile_size == 0, "unsupported"
+
+        # Create the output tensor
+        output = torch.empty(
+            (n_rows, n_cols), dtype=torch.float8_e4m3fn, device=x.device
+        )
+
+        # Create the scale tensor for rowwise scaling
+        row_scale = torch.empty(
+            (n_rows, n_cols // inner_block_size, 1),
+            dtype=torch.uint8,
+            device=x.device,
+        )
+
+        # Calculate grid dimensions based on tile size
+        grid = lambda META: (
+            triton.cdiv(n_rows, META["ROW_TILE_SIZE"]),
+            triton.cdiv(n_cols, META["COL_TILE_SIZE"]),
+        )
+
+        # Launch the kernel
+        wrap_triton(to_mxfp8_dim0_kernel)[grid](
+            x_ptr=x,
+            output_ptr=output,
+            row_scale_ptr=row_scale,
+            n_rows=n_rows,
+            n_cols=n_cols,
+            INNER_BLOCK_SIZE=inner_block_size,
+        )
+
+        return (
+            output,
+            row_scale.view(torch.float8_e8m0fnu),
+        )
+
     @triton_op("torchao::triton_to_mxfp8_dim1", mutates_args={})
     def triton_to_mxfp8_dim1(
         x: torch.Tensor, inner_block_size: int = 32
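# Illustrative sketch (not part of the PR's changes): the scale-store index arithmetic
# of to_mxfp8_dim0_kernel, replayed in plain Python. The tile sizes below are
# assumptions chosen only to reproduce the "[0, 1, 2, ...] -> [0, 1, 4, 5, ...]"
# comment above; the real values come from the autotune configs.
ROW_TILE_SIZE, COL_TILE_SIZE, INNER_BLOCK_SIZE, n_cols = 4, 64, 32, 128
BLOCKS_PER_COL_TILE = COL_TILE_SIZE // INNER_BLOCK_SIZE  # 2 scale slots per row in this tile
jump_vals_per_row = (n_cols - COL_TILE_SIZE) // INNER_BLOCK_SIZE  # 2 slots owned by other column tiles

# contiguous local slot ids for one (pid_row, pid_col) program
indices = list(range(ROW_TILE_SIZE * BLOCKS_PER_COL_TILE))  # [0, 1, 2, 3, 4, 5, 6, 7]
# skip over the other column tiles' slots after every BLOCKS_PER_COL_TILE values
strided = [i + (i // BLOCKS_PER_COL_TILE) * jump_vals_per_row for i in indices]
print(strided)  # [0, 1, 4, 5, 8, 9, 12, 13]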
@@ -1459,6 +1630,12 @@ def _(scale_tensor):
         return scale_tensor.new_empty((padded_rows, padded_cols))
 else:
 
+    def triton_to_mxfp8_dim0(
+        x: torch.Tensor,
+        inner_block_size=32,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise AssertionError("needs torch version 2.8+ and triton")
+
     def triton_to_mxfp8_dim1(
         x, inner_block_size=32
     ) -> Tuple[torch.Tensor, torch.Tensor]: