diff --git a/benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py b/benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py
index 9c49033a9d..2e0ea9434f 100644
--- a/benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py
+++ b/benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py
@@ -216,12 +216,12 @@ def bench_fp8_rowwise_grouped_mm(A, B_t, offs) -> float:
 
 
 def bench_mxfp8_grouped_mm(A, B_t, offs, block_size=32) -> float:
-    # A_mx shape: (M, K)
-    # A_scale shape: (M, K//block_size)
+    # A_fp8 shape: (M, K)
+    # A_scales shape: (M, K//block_size)
     A_scales, A_fp8 = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
 
-    # B_mx shape: (E, N, K)
-    # B_scale shape: (E, N, K//block_size)
+    # B_fp8 shape: (E, N, K)
+    # B_scales shape: (E, N, K//block_size)
     B_scales, B_fp8 = to_mx(
         B_t.transpose(-2, -1),
         elem_dtype=torch.float8_e4m3fn,
@@ -230,26 +230,19 @@ def bench_mxfp8_grouped_mm(A, B_t, offs, block_size=32) -> float:
 
     # Convert scales for each group to blocked format.
     Mg, K = A_fp8.shape
-    A_scales_blocked, starting_row_after_padding = torch_to_blocked_2d_M_groups(
-        A_scales, offs, K
+    A_scales_blocked, _ = torch_to_blocked_2d_M_groups(
+        A_scales, offs, block_size=block_size
     )
     B_scales_blocked = torch_to_blocked_per_group_3d(B_scales)
 
-    # From this, we compute `group_sizes` and `starting_row_after_padding`:
-    # group_sizes = [32, 32, 64]
-    # starting_row_after_padding = [0, 32, 64, 128]
-    zero = torch.tensor([0], dtype=offs.dtype, device=offs.device)
-    group_sizes = torch.diff(offs, prepend=zero).to(torch.int64)
-
     # Run the grouped mm
     mxfp8_us = benchmark_cuda_function_in_microseconds(
-        torch.ops.fbgemm.mx8mx8bf16_grouped_stacked,
+        torch._scaled_grouped_mm,
         A_fp8,
-        B_fp8,
+        B_fp8.transpose(-2, -1),
         A_scales_blocked,
         B_scales_blocked,
-        group_sizes,
-        starting_row_after_padding=starting_row_after_padding,
+        offs=offs,
     )
     return mxfp8_us
 
diff --git a/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_M_groups.py b/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_M_groups.py
index 8854135131..9e544a9059 100644
--- a/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_M_groups.py
+++ b/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_M_groups.py
@@ -51,6 +51,7 @@ def get_configs() -> List[ExperimentConfig]:
     block_size = 32
     input_shapes = [
         (16640, 5120 // block_size),
+        (131072, 5120 // block_size),
     ]
     num_groups = [16]
     configs = []
@@ -78,18 +79,21 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     )
     Mg, K = input_shape
 
-    input_group_offsets = generate_jagged_offs(num_groups, Mg, multiple_of=32)
+    block_size = 32
+    input_group_offsets = generate_jagged_offs(num_groups, Mg, multiple_of=block_size)
 
     # bench torch
     compiled_run_torch = torch.compile(torch_to_blocked_2d_M_groups)
     torch_out_scales, torch_group_offs = compiled_run_torch(
-        input_tensor, input_group_offsets, K
+        input_tensor,
+        input_group_offsets,
+        block_size=block_size,
     )
     torch_time_us = benchmark_cuda_function_in_microseconds(
         compiled_run_torch,
         input_tensor,
         input_group_offsets,
-        K,
+        block_size=block_size,
     )
 
     # bench triton
diff --git a/test/prototype/moe_training/test_kernels.py b/test/prototype/moe_training/test_kernels.py
index 0db8dcf899..ecd4cefe6a 100644
--- a/test/prototype/moe_training/test_kernels.py
+++ b/test/prototype/moe_training/test_kernels.py
@@ -230,7 +230,7 @@ def test_triton_mx_block_rearrange_2d_M_groups(
 
     # torch reference
     ref_out_scales, _ = torch_to_blocked_2d_M_groups(
-        e8m0_scales, input_group_offsets, k, block_size=block_size
+        e8m0_scales, input_group_offsets, block_size=block_size
     )
 
     # triton kernel
diff --git a/torchao/prototype/moe_training/kernels/mxfp8/quant.py b/torchao/prototype/moe_training/kernels/mxfp8/quant.py
index 60e469d318..24915d6359 100644
--- a/torchao/prototype/moe_training/kernels/mxfp8/quant.py
+++ b/torchao/prototype/moe_training/kernels/mxfp8/quant.py
@@ -15,7 +15,7 @@
 
 
 def torch_to_blocked_2d_M_groups(
-    x_scales: Tensor, group_offs: Tensor, K: int, block_size: int = 32
+    x_scales: Tensor, group_offs: Tensor, block_size: int = 32
 ) -> Tuple[Tensor, Tensor]:
     """
     Convert scales to blocked format for a 2D tensor (input activations / token groups),
@@ -34,14 +34,14 @@ def torch_to_blocked_2d_M_groups(
     assert x_scales.ndim == 2, "x_scales must be 2D"
     assert block_size == 32, "Only block_size=32 is supported for now"
 
-    total_M, _ = x_scales.shape
+    total_M, scale_cols = x_scales.shape
     num_groups = group_offs.shape[0]
     # Each group will require a variable amount of padding, so to avoid a d2h sync caused by iterating over each group,
     # the Triton kernel will use an upper bound of adding 128 padding rows to each group.
     # (This torch impl is used as a reference for correctness, so we must match the triton kernel's impl).
     total_M_padded = total_M + num_groups * 128
-    blocked_scales = x_scales.new_zeros(total_M_padded, K // block_size)
+    blocked_scales = x_scales.new_zeros(total_M_padded, scale_cols)
 
     start_row_after_padding_list = [0]
     group_start_idx = 0
     for i, group_end_idx in enumerate(group_offs.tolist()):
@@ -56,8 +56,7 @@ def torch_to_blocked_2d_M_groups(
         group_scales_blocked = to_blocked(group_scales)
 
         # Calculate the start row after padding
-        scaling_groups_per_row = K // block_size
-        rows_for_group = group_scales_blocked.numel() // scaling_groups_per_row
+        rows_for_group = group_scales_blocked.numel() // scale_cols
         new_start_row = prev_start_row_after_padding + rows_for_group
         start_row_after_padding_list.append(new_start_row)
 
@@ -67,7 +66,7 @@ def torch_to_blocked_2d_M_groups(
             prev_start_row_after_padding : prev_start_row_after_padding + group_rows_padded,
             :,
-        ] = group_scales_blocked.reshape(-1, K // block_size)
+        ] = group_scales_blocked.reshape(-1, scale_cols)
 
         # Update next group start index
         group_start_idx = group_end_idx
diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py
index d96a8b48af..78bfd48ab7 100644
--- a/torchao/prototype/mx_formats/utils.py
+++ b/torchao/prototype/mx_formats/utils.py
@@ -68,7 +68,6 @@ def to_blocked(input_matrix, use_triton_kernel: bool = False) -> Tensor:
 
     # Rearrange the blocks
     blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
     rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
-
     return rearranged.flatten()
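
For readers following the API swap in `bench_mxfp8_grouped_mm`, the snippet below condenses the new call path into a standalone sketch. It is not part of this PR: the import paths for `to_mx` and `torch_to_blocked_per_group_3d`, the illustrative shapes, and the availability of the private `torch._scaled_grouped_mm` op on your PyTorch build (plus MXFP8-capable hardware) are all assumptions.

```python
# Minimal sketch of the new MXFP8 grouped-mm call path, mirroring the
# benchmark above. Import paths below are assumptions; only quant.py's
# location is confirmed by this diff.
import torch

from torchao.prototype.moe_training.kernels.mxfp8.quant import (
    torch_to_blocked_2d_M_groups,
    torch_to_blocked_per_group_3d,  # assumed to live alongside the 2D helper
)
from torchao.prototype.mx_formats.mx_tensor import to_mx  # assumed path

E, Mg, K, N, block_size = 4, 128, 256, 512, 32
A = torch.randn(Mg, K, dtype=torch.bfloat16, device="cuda")
B_t = torch.randn(E, K, N, dtype=torch.bfloat16, device="cuda")
# End-of-group offsets along Mg; each group size is a multiple of block_size.
offs = torch.arange(Mg // E, Mg + 1, Mg // E, dtype=torch.int32, device="cuda")

# Quantize to MXFP8: to_mx returns (e8m0 scales, fp8 data).
A_scales, A_fp8 = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
B_scales, B_fp8 = to_mx(
    B_t.transpose(-2, -1), elem_dtype=torch.float8_e4m3fn, block_size=block_size
)

# Swizzle scales into blocked layout. Note the new signature: the K argument
# is gone; the column count is read from the scales tensor itself.
A_scales_blocked, _ = torch_to_blocked_2d_M_groups(
    A_scales, offs, block_size=block_size
)
B_scales_blocked = torch_to_blocked_per_group_3d(B_scales)

# offs replaces the explicit group_sizes / starting_row_after_padding
# bookkeeping the fbgemm op required.
out = torch._scaled_grouped_mm(
    A_fp8,
    B_fp8.transpose(-2, -1),  # (E, N, K) -> (E, K, N)
    A_scales_blocked,
    B_scales_blocked,
    offs=offs,
)
```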
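Dropping the `K` parameter also removes a redundancy inside the reference implementation: the padded output width must match the input scales' column count anyway, so reading `scale_cols` from `x_scales.shape` replaces the repeated `K // block_size` computation and eliminates one value callers could get wrong.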