diff --git a/torchao/prototype/moe_training/kernels/mxfp8/comms.py b/torchao/prototype/moe_training/kernels/mxfp8/comms.py
index 7c6999fbf1..d302e4e244 100644
--- a/torchao/prototype/moe_training/kernels/mxfp8/comms.py
+++ b/torchao/prototype/moe_training/kernels/mxfp8/comms.py
@@ -11,7 +11,7 @@
     blockwise_barrier,
     sync_threads,
 )
-from torchao.prototype.mx_formats.config import ScaleCalculationMode
+from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
 from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx
 
 
@@ -473,11 +473,9 @@ def forward(
         """
         # Quantize input
         block_size = 32
-        input_scales, input_data = to_mx(
+        input_data, input_scales = triton_to_mxfp8_dim0(
             input,
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-            scaling_mode=ScaleCalculationMode.RCEIL,
+            inner_block_size=block_size,
         )
 
         # Dispatch data (async)
@@ -529,11 +527,9 @@ def backward(ctx, grad_output_hp):
 
         # Quantize grad_output
         block_size = 32
-        grad_out_scales, grad_out_data = to_mx(
+        grad_out_data, grad_out_scales = triton_to_mxfp8_dim0(
             grad_output_hp,
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-            scaling_mode=ScaleCalculationMode.RCEIL,
+            inner_block_size=block_size,
         )
 
         # Dispatch data (async)
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index e14a79e774..6321e1a37d 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -1040,7 +1040,7 @@ def to_mxfp8_dim1_kernel(
 
 @triton.autotune(
     configs=_get_mxfp8_dim1_kernel_autotune_configs(),
-    key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"],
+    key=["n_cols", "INNER_BLOCK_SIZE"],
 )
 @triton.jit
 def to_mxfp8_dim0_kernel(
@@ -1118,33 +1118,31 @@ def to_mxfp8_dim0_kernel(
     # Store the row-normalized result in row-major format
     tl.store(output_ptr + row_major_offsets, row_normalized, mask=mask)
 
-    # reshape row_scale_e8m0_r for proper storage
-    # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
-    row_scale_e8m0 = row_scale_e8m0_r.reshape(ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)
+    # For rowwise quantization, scale tensor has shape (n_rows, n_cols // INNER_BLOCK_SIZE)
+    # Calculate base offset for this tile's scales
+    scales_per_row = n_cols // INNER_BLOCK_SIZE
 
-    row_scale_start_offsets = (
-        (pid_row * ROW_TILE_SIZE * (n_cols // COL_TILE_SIZE))
-        * BLOCKS_PER_COL_TILE  # number of blocks seen so far
-        + pid_col * BLOCKS_PER_COL_TILE  # increment BLOCKS_PER_COL_TILE
+    # Create row and column indices for scale storage
+    scale_row_indices = tl.arange(0, ROW_TILE_SIZE)[:, None] + (
+        pid_row * ROW_TILE_SIZE
+    )
+    scale_col_indices = tl.arange(0, BLOCKS_PER_COL_TILE)[None, :] + (
+        pid_col * BLOCKS_PER_COL_TILE
     )
 
-    row_scale_start_ptr = row_scale_ptr + row_scale_start_offsets
-
-    # calculate row_scale_indices
-    row_scale_indices = tl.arange(0, ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)
+    # Calculate linear indices into scale tensor
+    scale_offsets = scale_row_indices * scales_per_row + scale_col_indices
 
-    # How many values are in all the other rows for this col_pid, need to jump
-    # over them for every BLOCKS_PER_COL_TILE values
-    jump_vals_per_row = (n_cols - COL_TILE_SIZE) // INNER_BLOCK_SIZE
+    # Create masks for valid scale indices
+    scale_row_mask = scale_row_indices < n_rows
+    scale_col_mask = scale_col_indices < scales_per_row
+    scale_mask = scale_row_mask & scale_col_mask
 
-    # example transformation (specifics depend on tile sizes):
-    # [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13]
-    row_scale_indices = row_scale_indices + (
-        (row_scale_indices // BLOCKS_PER_COL_TILE) * jump_vals_per_row
-    )
+    # Reshape scale values to match the 2D scale layout
+    row_scale_e8m0_2d = row_scale_e8m0_r.reshape(ROW_TILE_SIZE, BLOCKS_PER_COL_TILE)
 
-    # Store the scales
-    tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0)
+    # Store the scales with proper masking
+    tl.store(row_scale_ptr + scale_offsets, row_scale_e8m0_2d, mask=scale_mask)
 
 
 @triton_op("torchao::triton_to_mxfp8_dim0", mutates_args={})
 def triton_to_mxfp8_dim0(
@@ -1167,14 +1165,9 @@ def triton_to_mxfp8_dim0(
     x = x.reshape(-1, x.shape[-1])
     n_rows, n_cols = x.shape
 
-    # Masking of loads and stores is not well tested yet, so for now enforce
-    # shapes which do not need masking. Note that this condition depends on max values of
-    # ROW_TILE_SIZE and COL_TILE_SIZE, which are autotuned above.
-    # TODO(future): implement and test masking and remove this restriction
-    max_row_tile_size = 128
-    max_col_tile_size = 128
-    assert n_rows % max_row_tile_size == 0, "unsupported"
-    assert n_cols % max_col_tile_size == 0, "unsupported"
+    assert n_cols % inner_block_size == 0, (
+        "columns must be divisible by inner block size"
+    )
 
     # Create output tensors
    output = torch.empty(
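For reference, below is a minimal usage sketch of the new call-site pattern introduced in comms.py, assuming the signatures exactly as they appear in the hunks above; the input shape, dtype, and device are illustrative and not part of the diff.

# Sketch only: illustrates the to_mx -> triton_to_mxfp8_dim0 swap made above.
import torch

from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

# Per the relaxed assert in kernels.py, only the last dim must be divisible
# by inner_block_size; the old 128-divisibility restriction on rows and
# cols (which forced unmasked tiles) is removed by the masked stores.
x = torch.randn(129, 256, device="cuda", dtype=torch.bfloat16)

# Return order is (data, scales) -- the reverse of to_mx's (scales, data) --
# which is why the unpacking at both call sites is swapped in the diff.
x_data, x_scales = triton_to_mxfp8_dim0(x, inner_block_size=32)

# Rowwise (dim0) scaling: one e8m0 scale per 32 contiguous values along the
# last dim, i.e. scales laid out as (n_rows, n_cols // 32) per the kernel
# comment above.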