diff --git a/benchmarks/float8/bench_grouped_mm.py b/benchmarks/float8/bench_grouped_mm.py index 5b0bea1822..1bded14c44 100644 --- a/benchmarks/float8/bench_grouped_mm.py +++ b/benchmarks/float8/bench_grouped_mm.py @@ -64,7 +64,7 @@ def run( # Run bf16 torch._grouped_mm baseline. A = torch.randn(M, K, device=device, dtype=dtype) - B = torch.randn(E, K, N, device=device, dtype=dtype) + B = torch.randn(E, N, K, device=device, dtype=dtype) offs = generate_jagged_offs(E, M) print(f"offs: {offs}") ref_time_sec, ref_tops_sec, ref_pct_top_peak = do_benchmarks( @@ -73,7 +73,7 @@ def run( use_gpu_kernel_time, torch._grouped_mm, A, - B, + B.transpose(-2, -1), offs, ) print( @@ -84,12 +84,7 @@ def run( # Run scaled_grouped_mm. A_hp = torch.randn(M, K, device=device) - B_hp_t = ( - torch.randn(E, K, N, device=device) - .transpose(-2, -1) - .contiguous() - .transpose(-2, -1) - ) + B_hp_t = torch.randn(E, N, K, device=device).transpose(-2, -1) if recipe == "rowwise": # TODO: add e5m2 diff --git a/benchmarks/float8/utils.py b/benchmarks/float8/utils.py index d4cdfeef20..744bbcad0d 100644 --- a/benchmarks/float8/utils.py +++ b/benchmarks/float8/utils.py @@ -219,7 +219,7 @@ def get_name_to_moe_shapes_iter( N: Optional[int] = None, E: Optional[int] = None, ): - M = 8192 if M is None else M + M = 16640 if M is None else M if shape_gen_name == "llama4_17bx16e": # num_experts=16, dim=5120 names_to_shapes = { @@ -232,8 +232,8 @@ def get_name_to_moe_shapes_iter( # num_experts=128, dim=5120 names_to_shapes = { # M, K, N, E - "moe.experts.w1": (M, 5120, 8192, 128), - "moe.experts.w2": (M, 8192, 5120, 128), + "moe.experts.w1": (M, 5120, 4 * 5120, 128), + "moe.experts.w2": (M, 4 * 5120, 5120, 128), } return names_to_shapes.items() elif shape_gen_name == "custom": diff --git a/benchmarks/prototype/moe_training/benchmark_kernels.py b/benchmarks/prototype/moe_training/benchmark_kernels.py index d9e79c6cf3..f180bb15ac 100644 --- a/benchmarks/prototype/moe_training/benchmark_kernels.py +++ b/benchmarks/prototype/moe_training/benchmark_kernels.py @@ -15,8 +15,8 @@ from triton.testing import do_bench from torchao.prototype.moe_training.kernels.jagged_float8_scales import ( - triton_fp8_col_major_jagged_colwise_scales, - triton_fp8_row_major_jagged_rowwise_scales, + triton_fp8_per_group_colwise_scales, + triton_fp8_per_group_rowwise_scales, ) from torchao.prototype.moe_training.utils import ( torch_to_float8_per_group_colwise, @@ -49,8 +49,8 @@ class Experiment: def get_configs() -> List[ExperimentConfig]: - input_shapes = [(2**8, 4096), (2**12, 4096), (2**16, 4096)] - n_groups_list = [4, 8, 16] + input_shapes = [(16640, 5120)] # (Mg, K) + n_groups_list = [16, 128] high_precision_dtypes = [torch.bfloat16] configs = [] for input_shape, n_groups, high_precision_dtype in itertools.product( @@ -114,13 +114,13 @@ def run_torch( def run_triton( input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor ): - _ = triton_fp8_row_major_jagged_rowwise_scales( + _ = triton_fp8_per_group_rowwise_scales( input_row_major, offs, output_dtype=torch.float8_e4m3fn, round_scales_to_power_of_2=True, ) - _ = triton_fp8_col_major_jagged_colwise_scales( + _ = triton_fp8_per_group_colwise_scales( input_col_major, offs, output_dtype=torch.float8_e4m3fn, @@ -129,6 +129,7 @@ def run_triton( # bench torch compiled_run_torch = torch.compile(run_torch) + warmup(compiled_run_torch, input_row_major, input_col_major, offs) torch_time_us = benchmark_cuda_function_in_microseconds( compiled_run_torch, input_row_major, input_col_major, 
offs ) @@ -152,6 +153,7 @@ def print_results(experiments: List[Experiment]): "high_precision_dtype", "torch_time_us", "triton_time_us", + "triton_speedup", ] rows = [] for experiment in experiments: @@ -165,6 +167,7 @@ def print_results(experiments: List[Experiment]): experiment.config.high_precision_dtype, experiment.result.torch_time_us, experiment.result.triton_time_us, + f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x", ] ) print(tabulate(rows, headers=headers)) diff --git a/benchmarks/prototype/moe_training/benchmark_moe_layer.py b/benchmarks/prototype/moe_training/benchmark_moe_layer.py index 549aae5a5e..d18c6dc176 100644 --- a/benchmarks/prototype/moe_training/benchmark_moe_layer.py +++ b/benchmarks/prototype/moe_training/benchmark_moe_layer.py @@ -30,16 +30,18 @@ "CUDA not available or compute capability < 8.9", allow_module_level=True ) -from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig +from torchao.prototype.moe_training.conversion_utils import ( + MoEScalingType, + MoETrainingConfig, +) from torchao.quantization.quant_api import quantize_ -# this test requires torchtitan +# this benchmark requires torchtitan try: - from torchtitan.experiments.llama4.infra.expert_parallel import ( + from torchtitan.distributed.expert_parallel import ( set_token_group_alignment_size_m, ) - from torchtitan.experiments.llama4.model.args import TransformerModelArgs - from torchtitan.experiments.llama4.model.moe import MoE + from torchtitan.models.moe import MoE, MoEArgs except ImportError: pytest.skip( "torchtitan not installed, skipping MoE tests.", allow_module_level=True @@ -54,16 +56,15 @@ def bench_moe_float8_training_fsdp(enable_profile=False): # define model args target_fqns = ["experts"] - model_args = TransformerModelArgs( - moe_enabled=True, + model_args = MoEArgs( num_experts=16, - dim=5120, ) init_std = 0.02 device = torch.device("cuda") # reference bf16 MoE - ref_model = MoE(model_args).to(torch.bfloat16).cuda() + dim, hidden_dim = 5120, 4 * 5120 + ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda() torch.manual_seed(42) ref_model.init_weights(init_std, device) @@ -82,7 +83,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: return False # quantize test model - config = MoETrainingConfig() + config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE) quantize_(model, config=config, filter_fn=moe_module_filter_fn) # FSDP2 @@ -90,12 +91,19 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: fully_shard(ref_model) # inputs (llama4 shapes) - batch, seq, dim = 1, 8192, 5120 + batch, seq = 1, 8192 ref_x = torch.randn( batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device ) x = ref_x.detach().clone().requires_grad_(True) + def warmup(model, input): + for _ in range(3): + out = model(input) + loss = F.mse_loss(out, torch.ones_like(out)) + loss.backward() + torch.cuda.synchronize() + def bench_fn_microseconds(model, input): labels = torch.ones_like(input) times = [] @@ -142,6 +150,7 @@ def profile_fn(model, input, profile_name="profile"): model = torch.compile(model, fullgraph=False) print("Benchmarking MoE with FSDP2 using bf16 training") + warmup(ref_model, ref_x) bf16_us = bench_fn_microseconds(ref_model, ref_x) print(f"bf16 time: {bf16_us} us") if enable_profile: @@ -152,6 +161,7 @@ def profile_fn(model, input, profile_name="profile"): set_token_group_alignment_size_m(16) print("Benchmarking MoE with FSDP2 using fp8 rowwise training") + warmup(model, x) fp8_us 
= bench_fn_microseconds(model, x) print(f"fp8 time: {fp8_us} us") if enable_profile: diff --git a/benchmarks/prototype/moe_training/benchmark_per_group_scaling_kernels.py b/benchmarks/prototype/moe_training/benchmark_per_group_scaling_kernels.py index d9e79c6cf3..45c9c7c22b 100644 --- a/benchmarks/prototype/moe_training/benchmark_per_group_scaling_kernels.py +++ b/benchmarks/prototype/moe_training/benchmark_per_group_scaling_kernels.py @@ -15,8 +15,8 @@ from triton.testing import do_bench from torchao.prototype.moe_training.kernels.jagged_float8_scales import ( - triton_fp8_col_major_jagged_colwise_scales, - triton_fp8_row_major_jagged_rowwise_scales, + triton_fp8_per_group_colwise_scales, + triton_fp8_per_group_rowwise_scales, ) from torchao.prototype.moe_training.utils import ( torch_to_float8_per_group_colwise, @@ -114,13 +114,13 @@ def run_torch( def run_triton( input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor ): - _ = triton_fp8_row_major_jagged_rowwise_scales( + _ = triton_fp8_per_group_rowwise_scales( input_row_major, offs, output_dtype=torch.float8_e4m3fn, round_scales_to_power_of_2=True, ) - _ = triton_fp8_col_major_jagged_colwise_scales( + _ = triton_fp8_per_group_colwise_scales( input_col_major, offs, output_dtype=torch.float8_e4m3fn, diff --git a/benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py b/benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py index 0cdb1c4957..66a7c91f53 100644 --- a/benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py +++ b/benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py @@ -46,8 +46,11 @@ class Experiment: def get_configs() -> List[ExperimentConfig]: - # Llama4 and DeepSeekV3 shapes - input_shapes = [(8, 4096, 1024), (16, 5120 * 4, 5120)] + # Llama4 shapes + input_shapes = [ + (16, 8192, 5120), # w1, w3 + (16, 5120, 8192), # w2 + ] high_precision_dtypes = [torch.bfloat16] configs = [] for input_shape, high_precision_dtype in itertools.product( @@ -117,6 +120,7 @@ def print_results(experiments: List[Experiment]): "input_shape", "torch_time_us", "triton_time_us", + "triton_speedup", ] rows = [] for experiment in experiments: @@ -126,6 +130,7 @@ def print_results(experiments: List[Experiment]): input_shape, experiment.result.torch_time_us, experiment.result.triton_time_us, + f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x", ] ) print(tabulate(rows, headers=headers)) diff --git a/test/prototype/moe_training/test_fsdp.py b/test/prototype/moe_training/test_fsdp.py index 69c15e2253..b205675527 100644 --- a/test/prototype/moe_training/test_fsdp.py +++ b/test/prototype/moe_training/test_fsdp.py @@ -35,8 +35,10 @@ # this test requires torchtitan try: - from torchtitan.experiments.llama4.model.args import TransformerModelArgs - from torchtitan.experiments.llama4.model.moe import MoE + from torchtitan.distributed.expert_parallel import ( + set_token_group_alignment_size_m, + ) + from torchtitan.models.moe import MoE, MoEArgs except ImportError: pytest.skip( "torchtitan not installed, skipping MoE tests.", allow_module_level=True @@ -49,18 +51,20 @@ def test_moe_float8_training_fsdp(): # setup distributed for fsdp setup_distributed() + # token group aligment size must be 16 for fp8 + set_token_group_alignment_size_m(16) + # define model args target_fqns = ["experts"] - model_args = TransformerModelArgs( - moe_enabled=True, + model_args = MoEArgs( num_experts=8, - dim=256, ) init_std = 0.02 device = torch.device("cuda") # 
reference bf16 MoE - ref_model = MoE(model_args).to(torch.bfloat16).cuda() + dim, hidden_dim = 5120, 4 * 5120 + ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda() torch.manual_seed(42) ref_model.init_weights(init_std, device) @@ -93,7 +97,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: fully_shard(ref_model) # inputs - batch, seq, dim = 8, 2048, 256 + batch, seq = 8, 2048 ref_x = torch.randn( batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device ) @@ -105,7 +109,10 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: # validate output out_sqnr = compute_error(out, ref_out) - assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}." + min_out_sqnr = 29.0 + assert out_sqnr.item() >= min_out_sqnr, ( + f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}." + ) # compute loss labels = torch.ones_like(ref_out) @@ -118,15 +125,17 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: # validate input gradient input_grad_sqnr = compute_error(x.grad, ref_x.grad) - assert input_grad_sqnr.item() >= 30.0, ( - f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}." + min_input_grad_sqnr = 29.0 + assert input_grad_sqnr.item() >= min_input_grad_sqnr, ( + f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}." ) # validate param gradients + min_param_grad_sqnr = 23.0 for param1, param2 in zip(model.parameters(), ref_model.parameters()): param_grad_sqnr = compute_error(param1.grad, param2.grad) - assert param_grad_sqnr.item() >= 25.0, ( - f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}." + assert param_grad_sqnr.item() >= min_param_grad_sqnr, ( + f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}." ) dist.destroy_process_group() diff --git a/test/prototype/moe_training/test_fsdp_tp.py b/test/prototype/moe_training/test_fsdp_tp.py index 083d9de1b9..4a7c1356c0 100644 --- a/test/prototype/moe_training/test_fsdp_tp.py +++ b/test/prototype/moe_training/test_fsdp_tp.py @@ -49,14 +49,14 @@ # this test requires torchtitan try: - from torchtitan.experiments.llama4.infra.expert_parallel import ( + from torchtitan.distributed.expert_parallel import ( ExpertParallel, ExpertTensorParallel, NoParallel, TensorParallel, + set_token_group_alignment_size_m, ) - from torchtitan.experiments.llama4.model.args import TransformerModelArgs - from torchtitan.experiments.llama4.model.moe import MoE + from torchtitan.models.moe import MoE, MoEArgs except ImportError: pytest.skip( "torchtitan not installed, skipping MoE tests.", allow_module_level=True @@ -74,21 +74,22 @@ def test_moe_float8_training_fsdp_tp(target_fqns: list[str]): assert torch.cuda.is_available() + # token group aligment size must be 16 for fp8 + set_token_group_alignment_size_m(16) + # setup distributed for tp mesh = setup_distributed() # define model args - model_args = TransformerModelArgs( - moe_enabled=True, + model_args = MoEArgs( num_experts=8, - dim=256, - vocab_size=1024, ) + dim, hidden_dim = 5120, 4 * 5120 init_std = 0.02 device = torch.device("cuda") # reference bf16 MoE - ref_model = MoE(model_args).to(torch.bfloat16).cuda() + ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda() torch.manual_seed(1) ref_model.init_weights(init_std, device) @@ -146,7 +147,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: ) # inputs - batch, seq, dim = 8, 2048, 256 + batch, seq = 8, 2048 ref_x = torch.randn( batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device ) @@ 
-158,7 +159,10 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: # validate output out_sqnr = compute_error(out, ref_out) - assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}." + min_out_sqnr = 30.0 + assert out_sqnr.item() >= min_out_sqnr, ( + f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}." + ) # compute loss labels = torch.ones_like(ref_out) @@ -171,15 +175,17 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: # validate input gradient input_grad_sqnr = compute_error(x.grad, ref_x.grad) - assert input_grad_sqnr.item() >= 28.0, ( - f"SQNR must be >= 28.0, got {input_grad_sqnr.item()}." + min_input_grad_sqnr = 28.0 + assert input_grad_sqnr.item() >= min_input_grad_sqnr, ( + f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}." ) # validate param gradients + min_param_grad_sqnr = 23.0 for param1, param2 in zip(model.parameters(), ref_model.parameters()): param_grad_sqnr = compute_error(param1.grad, param2.grad) - assert param_grad_sqnr.item() >= 25.0, ( - f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}." + assert param_grad_sqnr.item() >= min_param_grad_sqnr, ( + f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}." ) dist.destroy_process_group() diff --git a/test/prototype/moe_training/test_kernels.py b/test/prototype/moe_training/test_kernels.py index b24b61be8c..ea4afa5c90 100644 --- a/test/prototype/moe_training/test_kernels.py +++ b/test/prototype/moe_training/test_kernels.py @@ -23,8 +23,8 @@ triton_fp8_rowwise_3d_transpose_rhs, ) from torchao.prototype.moe_training.kernels.jagged_float8_scales import ( - triton_fp8_col_major_jagged_colwise_scales, - triton_fp8_row_major_jagged_rowwise_scales, + triton_fp8_per_group_colwise_scales, + triton_fp8_per_group_rowwise_scales, ) from torchao.prototype.moe_training.utils import ( _is_column_major, @@ -52,7 +52,7 @@ def test_row_major_with_jagged_rowwise_scales(round_scales_to_power_of_2: bool): target_dtype=torch.float8_e4m3fn, round_scales_to_power_of_2=round_scales_to_power_of_2, ) - kernel_fp8_data, kernel_scales = triton_fp8_row_major_jagged_rowwise_scales( + kernel_fp8_data, kernel_scales = triton_fp8_per_group_rowwise_scales( x, colwise_offs, output_dtype=torch.float8_e4m3fn, @@ -80,7 +80,7 @@ def test_column_major_with_jagged_colwise_scales(round_scales_to_power_of_2: boo target_dtype=torch.float8_e4m3fn, round_scales_to_power_of_2=round_scales_to_power_of_2, ) - kernel_fp8_data, kernel_scales = triton_fp8_col_major_jagged_colwise_scales( + kernel_fp8_data, kernel_scales = triton_fp8_per_group_colwise_scales( x, rowwise_offs, output_dtype=torch.float8_e4m3fn, diff --git a/test/prototype/moe_training/test_tp.py b/test/prototype/moe_training/test_tp.py index 46ba544791..bf913a69b3 100644 --- a/test/prototype/moe_training/test_tp.py +++ b/test/prototype/moe_training/test_tp.py @@ -49,14 +49,14 @@ # this test requires torchtitan try: - from torchtitan.experiments.llama4.infra.expert_parallel import ( + from torchtitan.distributed.expert_parallel import ( ExpertParallel, ExpertTensorParallel, NoParallel, TensorParallel, + set_token_group_alignment_size_m, ) - from torchtitan.experiments.llama4.model.args import TransformerModelArgs - from torchtitan.experiments.llama4.model.moe import MoE + from torchtitan.models.moe import MoE, MoEArgs except ImportError: pytest.skip( "torchtitan not installed, skipping MoE tests.", allow_module_level=True @@ -74,21 +74,22 @@ def test_moe_float8_training_tp(target_fqns: list[str]): assert 
torch.cuda.is_available() + # token group aligment size must be 16 for fp8 + set_token_group_alignment_size_m(16) + # setup distributed for tp mesh = setup_distributed() # define model args - model_args = TransformerModelArgs( - moe_enabled=True, + model_args = MoEArgs( num_experts=8, - dim=256, - vocab_size=1024, ) + dim, hidden_dim = 5120, 4 * 5120 init_std = 0.02 device = torch.device("cuda") # reference bf16 MoE - ref_model = MoE(model_args).to(torch.bfloat16).cuda() + ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda() torch.manual_seed(1) ref_model.init_weights(init_std, device) @@ -141,7 +142,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: ) # inputs - batch, seq, dim = 8, 2048, 256 + batch, seq = 8, 2048 ref_x = torch.randn( batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device ) @@ -153,7 +154,10 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: # validate output out_sqnr = compute_error(out, ref_out) - assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}." + min_out_sqnr = 29.0 + assert out_sqnr.item() >= min_out_sqnr, ( + f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}." + ) # compute loss labels = torch.ones_like(ref_out) @@ -166,15 +170,17 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: # validate input gradient input_grad_sqnr = compute_error(x.grad, ref_x.grad) - assert input_grad_sqnr.item() >= 28.0, ( - f"SQNR must be >= 28.0, got {input_grad_sqnr.item()}." + min_input_grad_sqnr = 28.0 + assert input_grad_sqnr.item() >= min_input_grad_sqnr, ( + f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}." ) # validate param gradients + min_param_grad_sqnr = 23.0 for param1, param2 in zip(model.parameters(), ref_model.parameters()): param_grad_sqnr = compute_error(param1.grad, param2.grad) - assert param_grad_sqnr.item() >= 25.0, ( - f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}." + assert param_grad_sqnr.item() >= min_param_grad_sqnr, ( + f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}." 
) dist.destroy_process_group() @@ -203,7 +209,7 @@ def apply_moe_ep_tp( moe_layer_plan = { # input / output sharding on the seqlen dim # all-gather for input, reduce-scatter for output - "moe": PrepareModuleInputOutput( + "": PrepareModuleInputOutput( input_layouts=(Shard(1),), desired_input_layouts=(Replicate(),), use_local_input=True, @@ -211,9 +217,9 @@ def apply_moe_ep_tp( desired_output_layouts=(Shard(1),), ), # replicate computation for the router - "moe.router.gate": NoParallel(), + "router.gate": NoParallel(), # input Replicate, output Partial - "moe.shared_expert": TensorParallel(), + "shared_expert": TensorParallel(), } parallelize_module( module=model, diff --git a/test/prototype/moe_training/test_training.py b/test/prototype/moe_training/test_training.py index d08f218842..98f9fb266a 100644 --- a/test/prototype/moe_training/test_training.py +++ b/test/prototype/moe_training/test_training.py @@ -22,11 +22,10 @@ # this test requires torchtitan try: - from torchtitan.experiments.llama4.infra.expert_parallel import ( + from torchtitan.distributed.expert_parallel import ( set_token_group_alignment_size_m, ) - from torchtitan.experiments.llama4.model.args import TransformerModelArgs - from torchtitan.experiments.llama4.model.moe import MoE + from torchtitan.models.moe import MoE, MoEArgs except ImportError: pytest.skip( "torchtitan not installed, skipping MoE tests.", allow_module_level=True @@ -47,16 +46,15 @@ def test_moe_float8_training(target_fqns: list[str], compile: bool): # has the contraction dim be divisible by 16. 16 byte alignment is required # for the slowest moving dim (stride 1), so 16 bytes / 1 byte per element in fp8 = 16 elements. set_token_group_alignment_size_m(16) - model_args = TransformerModelArgs( - moe_enabled=True, + model_args = MoEArgs( num_experts=8, - dim=256, ) init_std = 0.02 device = torch.device("cuda") # reference bf16 MoE - ref_model = MoE(model_args).to(torch.bfloat16).cuda() + dim, hidden_dim = 5120, 4 * 5120 + ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda() torch.manual_seed(42) ref_model.init_weights(init_std, device) @@ -75,7 +73,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: return False # quantize test model - config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE) + config = MoETrainingConfig() quantize_(model, config=config, filter_fn=moe_module_filter_fn) # validate that only the experts were converted @@ -83,14 +81,13 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: model, target_fqns=target_fqns, ) - if compile: # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it model = torch.compile(model, fullgraph=False) ref_model = torch.compile(ref_model, fullgraph=False) # inputs - batch, seq, dim = 8, 2048, 256 + batch, seq = 8, 2048 ref_x = torch.randn( batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device ) @@ -124,7 +121,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: ) # validate param gradients - min_param_grad_sqnr = 25.0 + min_param_grad_sqnr = 23.0 for param1, param2 in zip(model.parameters(), ref_model.parameters()): param_grad_sqnr = compute_error(param1.grad, param2.grad) assert param_grad_sqnr.item() >= min_param_grad_sqnr, ( @@ -145,18 +142,15 @@ def test_moe_mxfp8_training(target_fqns: list[str]): # Token groups must be divisible by 32 for mxfp8 set_token_group_alignment_size_m(block_size) - model_args = TransformerModelArgs( - moe_enabled=True, + model_args = MoEArgs( num_experts=8, - dim=256, - 
multiple_of=block_size, - ffn_dim_multiplier=1.0, ) init_std = 0.02 device = torch.device("cuda") # reference bf16 MoE - ref_model = MoE(model_args).to(torch.bfloat16).cuda() + dim, hidden_dim = 256, 4 * 256 + ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda() torch.manual_seed(42) ref_model.init_weights(init_std, device) @@ -185,7 +179,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: ) # inputs - batch, seq, dim = 8, 2048, 256 + batch, seq = 8, 2048 ref_x = torch.randn( batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device ) diff --git a/torchao/prototype/moe_training/kernels/__init__.py b/torchao/prototype/moe_training/kernels/__init__.py index 8fb16579e5..0b88cc08a2 100644 --- a/torchao/prototype/moe_training/kernels/__init__.py +++ b/torchao/prototype/moe_training/kernels/__init__.py @@ -2,8 +2,8 @@ triton_fp8_rowwise_3d_transpose_rhs as triton_fp8_rowwise_3d_transpose_rhs, ) from torchao.prototype.moe_training.kernels.jagged_float8_scales import ( - triton_fp8_col_major_jagged_colwise_scales as triton_fp8_col_major_jagged_colwise_scales, + triton_fp8_per_group_colwise_scales as triton_fp8_per_group_colwise_scales, ) from torchao.prototype.moe_training.kernels.jagged_float8_scales import ( - triton_fp8_row_major_jagged_rowwise_scales as triton_fp8_row_major_jagged_rowwise_scales, + triton_fp8_per_group_rowwise_scales as triton_fp8_per_group_rowwise_scales, ) diff --git a/torchao/prototype/moe_training/kernels/float8_rowwise.py b/torchao/prototype/moe_training/kernels/float8_rowwise.py index 9d7a7768d4..3449b89336 100644 --- a/torchao/prototype/moe_training/kernels/float8_rowwise.py +++ b/torchao/prototype/moe_training/kernels/float8_rowwise.py @@ -29,7 +29,7 @@ block_sizes_n = [32, 128, 512] # large dim (output_features) block_sizes_k = [32, 128, 512] # small dim (input_features) num_warps = [8] -num_stages = [2, 3] +num_stages = [2, 4] kernel_configs_2D = [ triton.Config( {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}, @@ -42,10 +42,8 @@ for stages in num_stages ] -from torch.library import triton_op, wrap_triton - -@triton_op("torchao::triton_fp8_rowwise_transpose_rhs", mutates_args={}) +@torch.library.custom_op("torchao::triton_fp8_rowwise_transpose_rhs", mutates_args={}) def triton_fp8_rowwise_3d_transpose_rhs( hp_tensor: torch.Tensor, # (E, K, N) output_dtype: torch.dtype = torch.float8_e4m3fn, @@ -80,7 +78,7 @@ def triton_fp8_rowwise_3d_transpose_rhs( ) # compute scales - wrap_triton(_triton_fp8_rowwise_3d_transpose_scales_rhs_kernel)[grid]( + _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel[grid]( hp_tensor, hp_tensor.stride(0), hp_tensor.stride(1), @@ -100,7 +98,7 @@ def triton_fp8_rowwise_3d_transpose_rhs( ) # perform casting - wrap_triton(_triton_fp8_rowwise_3d_transpose_cast_rhs_kernel)[grid]( + _triton_fp8_rowwise_3d_transpose_cast_rhs_kernel[grid]( hp_tensor, hp_tensor.stride(0), hp_tensor.stride(1), @@ -124,6 +122,22 @@ def triton_fp8_rowwise_3d_transpose_rhs( return output_buffer, scales_buffer +@triton_fp8_rowwise_3d_transpose_rhs.register_fake +def _fake_triton_fp8_rowwise_3d_transpose_rhs( + hp_tensor: torch.Tensor, # (E, K, N) + output_dtype: torch.dtype = torch.float8_e4m3fn, + round_scales_to_power_of_2: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert hp_tensor.ndim == 3, "input tensor must be 3D" + e, k, n = hp_tensor.shape + output_buffer = torch.empty( + (e, n, k), dtype=output_dtype, device=hp_tensor.device + ).as_strided((e, n, k), (n * k, 1, n)) + + scales_buffer = 
torch.empty((e, k), dtype=torch.float32, device=hp_tensor.device) + return output_buffer, scales_buffer + + @triton.autotune(configs=kernel_configs_2D, key=["num_elements"]) @triton.jit def _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel( diff --git a/torchao/prototype/moe_training/kernels/jagged_float8_scales.py b/torchao/prototype/moe_training/kernels/jagged_float8_scales.py index ff0b11acba..a9b8528975 100644 --- a/torchao/prototype/moe_training/kernels/jagged_float8_scales.py +++ b/torchao/prototype/moe_training/kernels/jagged_float8_scales.py @@ -32,9 +32,9 @@ } block_sizes = [1, 16, 32, 64] -block_sizes_iter = [32, 64, 128, 256] -num_warps = [1, 4] -num_stages = [2, 3] +block_sizes_iter = [64, 128, 256] +num_warps = [4] +num_stages = [3] kernel_configs_2D = [ triton.Config( {"BLOCK_SIZE": block_size, "BLOCK_SIZE_ITER": block_size_iter}, @@ -47,11 +47,11 @@ for stages in num_stages ] -from torch.library import triton_op, wrap_triton - -@triton_op("torchao::triton_fp8_row_major_jagged_rowwise_scales", mutates_args={}) -def triton_fp8_row_major_jagged_rowwise_scales( +@torch.library.custom_op( + "torchao::triton_fp8_per_group_rowwise_scales", mutates_args={} +) +def triton_fp8_per_group_rowwise_scales( hp_tensor: torch.Tensor, offsets: torch.Tensor, output_dtype: torch.dtype = torch.float8_e4m3fn, @@ -85,7 +85,12 @@ def triton_fp8_row_major_jagged_rowwise_scales( n_groups = offsets.numel() # allocate on-device buffers for output and scales - output_buffer = torch.empty((m, k), dtype=output_dtype, device=hp_tensor.device) + output_buffer = torch.empty( + (m, k), dtype=output_dtype, device=hp_tensor.device + ).as_strided( + (m, k), # shape + (1, m), # stride + ) scales_buffer = torch.empty( (m * n_groups), dtype=torch.float32, device=hp_tensor.device ) @@ -95,7 +100,7 @@ def triton_fp8_row_major_jagged_rowwise_scales( triton.cdiv(m, meta["BLOCK_SIZE"]), offsets.numel(), ) - wrap_triton(_triton_fp8_row_major_jagged_rowwise_scales)[grid]( + _triton_fp8_per_group_rowwise_scales_kernel[grid]( hp_tensor, offsets, output_buffer, @@ -114,7 +119,25 @@ def triton_fp8_row_major_jagged_rowwise_scales( round_scales_to_power_of_2, EPS=EPS, ) - return output_buffer, scales_buffer + return output_buffer.transpose(-2, -1).contiguous().transpose(-2, -1), scales_buffer + + +@triton_fp8_per_group_rowwise_scales.register_fake +def _fake_triton_fp8_per_group_rowwise_scales_kernel( + hp_tensor: torch.Tensor, + offsets: torch.Tensor, + output_dtype: torch.dtype = torch.float8_e4m3fn, + round_scales_to_power_of_2: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert hp_tensor.ndim == 2, "input tensor must be 2D" + m, k = hp_tensor.shape + n_groups = offsets.numel() + output = torch.empty_like(hp_tensor, dtype=output_dtype).as_strided( + (m, k), # shape + (k, 1), # stride + ) + scales = torch.empty((m * n_groups), dtype=torch.float32, device=hp_tensor.device) + return output, scales # This kernel is used on grad_output.t() which has shape (K, M), @@ -125,7 +148,7 @@ def triton_fp8_row_major_jagged_rowwise_scales( # to recompile on `token` dim (K, in this case) changes. 
@triton.autotune(configs=kernel_configs_2D, key=["M"]) @triton.jit -def _triton_fp8_row_major_jagged_rowwise_scales( +def _triton_fp8_per_group_rowwise_scales_kernel( input_ptr, offsets_ptr, out_ptr, @@ -215,8 +238,10 @@ def _triton_fp8_row_major_jagged_rowwise_scales( tl.store(out_ptr + out_offs, fp8_data, mask=block_mask) -@triton_op("torchao::triton_fp8_col_major_jagged_colwise_scales", mutates_args={}) -def triton_fp8_col_major_jagged_colwise_scales( +@torch.library.custom_op( + "torchao::triton_fp8_per_group_colwise_scales", mutates_args={} +) +def triton_fp8_per_group_colwise_scales( hp_tensor: torch.Tensor, offsets: torch.Tensor, output_dtype: torch.dtype = torch.float8_e4m3fn, @@ -263,7 +288,7 @@ def triton_fp8_col_major_jagged_colwise_scales( triton.cdiv(n, meta["BLOCK_SIZE"]), offsets.numel(), ) - wrap_triton(_triton_fp8_col_major_jagged_colwise_scales)[grid]( + _triton_fp8_per_group_colwise_scales_kernel[grid]( hp_tensor, offsets, output_buffer, @@ -285,19 +310,39 @@ def triton_fp8_col_major_jagged_colwise_scales( return output_buffer, scales_buffer +@triton_fp8_per_group_colwise_scales.register_fake +def _fake_triton_fp8_per_group_colwise_scales( + hp_tensor: torch.Tensor, + offsets: torch.Tensor, + output_dtype: torch.dtype = torch.float8_e4m3fn, + round_scales_to_power_of_2: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert hp_tensor.ndim == 2, "input tensor must be 2D" + k, n = hp_tensor.shape + n_groups = offsets.numel() + output_buffer = torch.empty_like( + hp_tensor, dtype=output_dtype, device=hp_tensor.device + ).as_strided(hp_tensor.size(), (1, k)) + + scales_buffer = torch.empty( + (n * n_groups), dtype=torch.float32, device=hp_tensor.device + ) + return output_buffer, scales_buffer + + # This kernel is used on `input` which has shape (M, K), # before the calculation `grad_B = grad_output_t @ input`. # The tokens per expert will vary per iteration, so don't want # to recompile on `token` dim (M) changes. @triton.autotune(configs=kernel_configs_2D, key=["K"]) @triton.jit -def _triton_fp8_col_major_jagged_colwise_scales( +def _triton_fp8_per_group_colwise_scales_kernel( input_ptr, offsets_ptr, out_ptr, scales_ptr, + M: int, K: int, - N: int, stride_input_row: int, stride_input_col: int, stride_output_row: int, @@ -332,7 +377,7 @@ def _triton_fp8_col_major_jagged_colwise_scales( + block_col_offs[None, :] * stride_input_col ) block_mask = (block_row_offs[:, None] < group_row_end_idx) & ( - block_col_offs[None, :] < N + block_col_offs[None, :] < K ) data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to( input_dtype @@ -354,8 +399,8 @@ def _triton_fp8_col_major_jagged_colwise_scales( # store colwise scales for each group in contiguous memory: # [group0_col0, group_0_col1, ..., group2_col0, group2_col1] # note: input tensor is in col-major memory layout. 
- scales_offs = block_col_offs + (N * offset_idx) - scales_mask = tl.arange(0, BLOCK_SIZE) < N + scales_offs = block_col_offs + (K * offset_idx) + scales_mask = tl.arange(0, BLOCK_SIZE) < K tl.store(scales_ptr + scales_offs, scales, mask=scales_mask) # perform float8 conversion for this group @@ -366,7 +411,7 @@ def _triton_fp8_col_major_jagged_colwise_scales( + block_col_offs[None, :] * stride_input_col ) block_mask = (block_row_offs[:, None] < group_row_end_idx) & ( - block_col_offs[None, :] < N + block_col_offs[None, :] < K ) data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to( input_dtype diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py index 7dc246e251..58d7aa71d8 100644 --- a/torchao/prototype/moe_training/scaled_grouped_mm.py +++ b/torchao/prototype/moe_training/scaled_grouped_mm.py @@ -13,8 +13,8 @@ from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated from torchao.prototype.moe_training.conversion_utils import MoEScalingType from torchao.prototype.moe_training.kernels import ( - triton_fp8_col_major_jagged_colwise_scales, - triton_fp8_row_major_jagged_rowwise_scales, + triton_fp8_per_group_colwise_scales, + triton_fp8_per_group_rowwise_scales, triton_fp8_rowwise_3d_transpose_rhs, ) from torchao.prototype.moe_training.utils import ( @@ -48,7 +48,7 @@ def _scaled_grouped_mm( """ # TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging. if scaling_type == MoEScalingType.FP8_ROWWISE: - logger.info("Using fp8 rowwise scaled_grouped_mm") + print("Using fp8 rowwise scaled_grouped_mm") return _Float8GroupedMM.apply( A, B_t, @@ -56,7 +56,7 @@ def _scaled_grouped_mm( out_dtype, ) elif scaling_type == MoEScalingType.MXFP8: - logger.info("Using mxfp8 scaled_grouped_mm") + print("Using mxfp8 scaled_grouped_mm") block_size = 32 # TODO: should we make this configurable? plumb it through in a config somehow? return _MXFP8GroupedMM.apply( A, @@ -144,7 +144,7 @@ def forward( # low precision B tensor instead of the high precision B tensor. # In the backward this is needed for grad_A: grad_output @ B. 
B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs( - B_t, + B_t._data, output_dtype=torch.float8_e4m3fn, round_scales_to_power_of_2=True, ) @@ -230,7 +230,7 @@ def backward(ctx, grad_output: torch.Tensor): # Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM # needed for grad_B: grad_output_t @ A grad_output_t_fp8_row_major, grad_output_t_scales = ( - triton_fp8_row_major_jagged_rowwise_scales( + triton_fp8_per_group_rowwise_scales( grad_output.transpose(-2, -1), offs, torch.float8_e4m3fn, @@ -238,7 +238,7 @@ def backward(ctx, grad_output: torch.Tensor): ) ) - A_fp8_col_major, A_scales = triton_fp8_col_major_jagged_colwise_scales( + A_fp8_col_major, A_scales = triton_fp8_per_group_colwise_scales( A, offs, torch.float8_e4m3fn, diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py index 1ddd098675..a861aa6533 100644 --- a/torchao/prototype/moe_training/tensor.py +++ b/torchao/prototype/moe_training/tensor.py @@ -97,9 +97,12 @@ def __torch_function__(cls, func, types, args, kwargs={}): A_is_2d = A.dim() == 2 B_is_3d = B.dim() == 3 has_offs = kwargs.get(cls.offs_arg_name) is not None + other_args = args[2:] if A_is_2d and B_is_3d and has_offs: return _scaled_grouped_mm( - *args, + A, + B, + *other_args, scaling_type=scaling_type, **kwargs, ) @@ -111,16 +114,25 @@ def __torch_function__(cls, func, types, args, kwargs={}): @classmethod def __torch_dispatch__(cls, func, types, args, kwargs={}): - # detach is special case - scaling_type = args[0].scaling_type - if func == torch.ops.aten.detach.default: - return ScaledGroupedMMTensor(args[0]._data, scaling_type) + # unwrap args/kwargs and extract scaling_type + scaling_type = None + + def unwrap(t): + nonlocal scaling_type + if scaling_type is None: + scaling_type = t.scaling_type + else: + assert t.scaling_type == scaling_type + return t._data - # unwrap args/kwargs - unwrap = lambda x: x._data if isinstance(x, ScaledGroupedMMTensor) else x args, kwargs = pytree.tree_map_only( ScaledGroupedMMTensor, unwrap, (args, kwargs or {}) ) + assert scaling_type is not None + + # detach is special case + if func == torch.ops.aten.detach.default: + return ScaledGroupedMMTensor(args[0], scaling_type) # perform op out = func(*args, **kwargs)
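
A few plain-PyTorch sketches follow to illustrate patterns this patch relies on; all helper names and shapes below are illustrative, not torchao or torchtitan APIs unless the patch itself shows them. First, the simplified B_hp_t construction in bench_grouped_mm.py: transposing a contiguous (E, N, K) weight over its last two dims yields the column-major (E, K, N) view that the grouped GEMM's RHS operand expects, the same layout the removed contiguous() round-trip produced.

import torch

E, K, N = 4, 64, 128                      # toy sizes
B = torch.randn(E, N, K)                  # contiguous, strides (N*K, K, 1)
B_t = B.transpose(-2, -1)                 # shape (E, K, N), strides (N*K, 1, K)
assert B_t.shape == (E, K, N)
assert B_t.stride() == (N * K, 1, K)      # column-major in the trailing two dims

# the removed construction produced the same layout via a contiguous() round-trip
B_old = torch.randn(E, K, N).transpose(-2, -1).contiguous().transpose(-2, -1)
assert B_old.stride() == B_t.stride()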
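
On the set_token_group_alignment_size_m(16) calls added across the benchmarks and tests: per the comment kept in test_training.py, fp8 needs each token group's contraction dim to be divisible by 16, since the stride-1 dim must be 16-byte aligned and fp8 elements are 1 byte. A sketch of building aligned jagged offsets (aligned_jagged_offs is a hypothetical helper, not the repo's generate_jagged_offs):

import torch

def aligned_jagged_offs(n_groups: int, total_m: int, align: int = 16) -> torch.Tensor:
    # split total_m tokens into n_groups jagged groups whose sizes are multiples of `align`
    assert total_m % align == 0
    blocks = total_m // align
    cuts = torch.sort(torch.randperm(blocks - 1)[: n_groups - 1] + 1).values
    sizes = torch.diff(cuts, prepend=torch.tensor([0]), append=torch.tensor([blocks])) * align
    return torch.cumsum(sizes, dim=0)  # cumulative end offsets, like the `offs` passed to the kernels

offs = aligned_jagged_offs(8, 16640)
group_sizes = torch.diff(offs, prepend=torch.tensor([0]))
assert offs[-1].item() == 16640 and bool((group_sizes % 16 == 0).all())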
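
On the relaxed SQNR thresholds in the tests (30.0 to 29.0 for outputs, 25.0 to 23.0 for param grads): assuming compute_error is the usual signal-to-quantization-noise ratio in decibels, a one-dB drop admits only slightly more quantization noise, which is plausible given the much larger dim and hidden_dim the tests now use. A minimal sketch of that metric:

import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # 20*log10 of the norm ratio: ~30 dB means the error energy is roughly
    # one thousandth of the signal energy; 23 dB is roughly one two-hundredth
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - test))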
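
The kernel renames (triton_fp8_row_major_jagged_rowwise_scales to triton_fp8_per_group_rowwise_scales, and the colwise counterpart) describe the same computation: fp8 casting with a separate rowwise (or colwise) scale per token group. A simplified pure-PyTorch reference for the rowwise case, where groups partition the last dim as they do for grad_output.t(); the exact scale layout and epsilon in torchao may differ:

import torch

def per_group_rowwise_fp8_reference(x, offs, dtype=torch.float8_e4m3fn, eps=1e-12):
    # x: (rows, total_tokens) with token groups stacked along the last dim;
    # offs: cumulative end index of each group along that dim.
    fp8_max = torch.finfo(dtype).max
    outs, scales = [], []
    start = 0
    for end in offs.tolist():
        group = x[:, start:end]
        amax = group.abs().amax(dim=1, keepdim=True).clamp(min=eps)
        scale = fp8_max / amax                                   # one scale per row, per group
        outs.append((group * scale).clamp(-fp8_max, fp8_max).to(dtype))
        scales.append(scale.squeeze(-1))
        start = end
    return torch.cat(outs, dim=1), torch.cat(scales)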
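
The move from @triton_op/wrap_triton to @torch.library.custom_op plus register_fake follows the standard custom-op pattern: the real implementation launches the Triton kernels, while the fake (meta) implementation only reports output shapes, dtypes, and strides so torch.compile can trace through the op, which is why the patch's fake functions use as_strided to mirror the real outputs' column-major layouts. A toy example of the registration pattern (the demo:: op and its body are illustrative only):

import torch
from typing import Tuple

@torch.library.custom_op("demo::rowwise_cast", mutates_args={})
def rowwise_cast(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # real implementation: rowwise fp8 cast (stands in for a Triton kernel launch)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = fp8_max / x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    return (x * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn), scale

@rowwise_cast.register_fake
def _(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # fake implementation: only shapes/dtypes/strides matter under tracing
    data = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    scale = torch.empty(*x.shape[:-1], 1, dtype=x.dtype, device=x.device)
    return data, scale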
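
Finally, the reworked ScaledGroupedMMTensor.__torch_dispatch__ unwraps every subclass instance in args/kwargs to its _data while capturing a single consistent scaling_type, instead of special-casing detach before unwrapping. A stand-alone sketch of that unwrap pattern with a stand-in wrapper class (Box is illustrative, not the real tensor subclass):

import torch
from torch.utils import _pytree as pytree

class Box:
    # stand-in wrapper: holds raw data plus the metadata we need to carry through ops
    def __init__(self, data: torch.Tensor, scaling_type: str):
        self._data, self.scaling_type = data, scaling_type

def unwrap_all(args, kwargs):
    scaling_type = None
    def unwrap(b: Box) -> torch.Tensor:
        nonlocal scaling_type
        if scaling_type is None:
            scaling_type = b.scaling_type
        else:
            assert b.scaling_type == scaling_type, "mixed scaling types in one op"
        return b._data
    args, kwargs = pytree.tree_map_only(Box, unwrap, (args, kwargs or {}))
    return args, kwargs, scaling_type

args, kwargs, st = unwrap_all((Box(torch.randn(4), "fp8_rowwise"), 2.0), {"out_dtype": torch.bfloat16})
assert st == "fp8_rowwise" and isinstance(args[0], torch.Tensor)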