benchmarks/float8/bench_grouped_mm.py (11 changes: 3 additions & 8 deletions)

@@ -64,7 +64,7 @@ def run(
 
     # Run bf16 torch._grouped_mm baseline.
     A = torch.randn(M, K, device=device, dtype=dtype)
-    B = torch.randn(E, K, N, device=device, dtype=dtype)
+    B = torch.randn(E, N, K, device=device, dtype=dtype)
     offs = generate_jagged_offs(E, M)
     print(f"offs: {offs}")
     ref_time_sec, ref_tops_sec, ref_pct_top_peak = do_benchmarks(
@@ -73,7 +73,7 @@
         use_gpu_kernel_time,
         torch._grouped_mm,
         A,
-        B,
+        B.transpose(-2, -1),
         offs,
     )
     print(
@@ -84,12 +84,7 @@
 
     # Run scaled_grouped_mm.
     A_hp = torch.randn(M, K, device=device)
-    B_hp_t = (
-        torch.randn(E, K, N, device=device)
-        .transpose(-2, -1)
-        .contiguous()
-        .transpose(-2, -1)
-    )
+    B_hp_t = torch.randn(E, N, K, device=device).transpose(-2, -1)
 
     if recipe == "rowwise":
         # TODO: add e5m2
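Note: the rewritten B_hp_t is equivalent to the old four-op chain; both produce an (E, K, N) view over row-major (E, N, K) storage, i.e. a column-major K x N operand per expert. A minimal sketch verifying the equivalence (toy sizes, not the benchmark's):

import torch

E, K, N = 2, 8, 16

# Old: allocate (E, K, N), round-trip through contiguous (E, N, K) storage.
old = torch.randn(E, K, N).transpose(-2, -1).contiguous().transpose(-2, -1)
# New: allocate (E, N, K) directly and view it as (E, K, N).
new = torch.randn(E, N, K).transpose(-2, -1)

assert old.shape == new.shape == (E, K, N)
assert old.stride() == new.stride()  # same per-expert column-major layout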
benchmarks/float8/utils.py (6 changes: 3 additions & 3 deletions)

@@ -219,7 +219,7 @@ def get_name_to_moe_shapes_iter(
     N: Optional[int] = None,
     E: Optional[int] = None,
 ):
-    M = 8192 if M is None else M
+    M = 16640 if M is None else M
     if shape_gen_name == "llama4_17bx16e":
         # num_experts=16, dim=5120
         names_to_shapes = {
@@ -232,8 +232,8 @@
         # num_experts=128, dim=5120
         names_to_shapes = {
             # M, K, N, E
-            "moe.experts.w1": (M, 5120, 8192, 128),
-            "moe.experts.w2": (M, 8192, 5120, 128),
+            "moe.experts.w1": (M, 5120, 4 * 5120, 128),
+            "moe.experts.w2": (M, 4 * 5120, 5120, 128),
         }
         return names_to_shapes.items()
     elif shape_gen_name == "custom":
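Note: with dim=5120, the new entries make hidden_dim exactly 4 * dim rather than the hard-coded 8192. A small sketch of how the (M, K, N) tuples map onto the two expert GEMMs (toy M and illustrative variable names, not the benchmark's code):

import torch

M, dim, hidden_dim = 1024, 5120, 4 * 5120  # the benchmark defaults to M=16640

x = torch.randn(M, dim)
w1 = torch.randn(dim, hidden_dim)   # "moe.experts.w1": K=5120, N=4*5120
w2 = torch.randn(hidden_dim, dim)   # "moe.experts.w2": K=4*5120, N=5120

y = (x @ w1) @ w2  # (M, 5120) -> (M, 20480) -> (M, 5120)
assert y.shape == (M, dim)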
benchmarks/prototype/moe_training/benchmark_kernels.py (15 changes: 9 additions & 6 deletions)

@@ -15,8 +15,8 @@
 from triton.testing import do_bench
 
 from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
-    triton_fp8_col_major_jagged_colwise_scales,
-    triton_fp8_row_major_jagged_rowwise_scales,
+    triton_fp8_per_group_colwise_scales,
+    triton_fp8_per_group_rowwise_scales,
 )
 from torchao.prototype.moe_training.utils import (
     torch_to_float8_per_group_colwise,
@@ -49,8 +49,8 @@ class Experiment:
 
 
 def get_configs() -> List[ExperimentConfig]:
-    input_shapes = [(2**8, 4096), (2**12, 4096), (2**16, 4096)]
-    n_groups_list = [4, 8, 16]
+    input_shapes = [(16640, 5120)]  # (Mg, K)
+    n_groups_list = [16, 128]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
     for input_shape, n_groups, high_precision_dtype in itertools.product(
@@ -114,13 +114,13 @@ def run_torch(
 def run_triton(
     input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor
 ):
-    _ = triton_fp8_row_major_jagged_rowwise_scales(
+    _ = triton_fp8_per_group_rowwise_scales(
         input_row_major,
         offs,
         output_dtype=torch.float8_e4m3fn,
         round_scales_to_power_of_2=True,
     )
-    _ = triton_fp8_col_major_jagged_colwise_scales(
+    _ = triton_fp8_per_group_colwise_scales(
         input_col_major,
         offs,
         output_dtype=torch.float8_e4m3fn,
@@ -129,6 +129,7 @@ def run_triton(
 
     # bench torch
     compiled_run_torch = torch.compile(run_torch)
+    warmup(compiled_run_torch, input_row_major, input_col_major, offs)
     torch_time_us = benchmark_cuda_function_in_microseconds(
         compiled_run_torch, input_row_major, input_col_major, offs
     )
@@ -152,6 +153,7 @@ def print_results(experiments: List[Experiment]):
         "high_precision_dtype",
         "torch_time_us",
         "triton_time_us",
+        "triton_speedup",
     ]
     rows = []
     for experiment in experiments:
@@ -165,6 +167,7 @@
                 experiment.config.high_precision_dtype,
                 experiment.result.torch_time_us,
                 experiment.result.triton_time_us,
+                f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
             ]
         )
     print(tabulate(rows, headers=headers))
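Note: the new warmup call keeps torch.compile compilation and autotuning out of the timed region, and the triton_speedup column is simply torch_time_us / triton_time_us (values above 1.00x mean the Triton kernels win). A minimal sketch of what such a warmup helper might look like (an assumption; the helper's definition is outside this diff):

import torch

def warmup(fn, *args, n_iters=3):
    # Run a few untimed iterations so compilation and autotuning
    # complete before benchmarking starts.
    for _ in range(n_iters):
        fn(*args)
    torch.cuda.synchronize()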
benchmarks/prototype/moe_training/benchmark_moe_layer.py (32 changes: 21 additions & 11 deletions)

@@ -30,16 +30,18 @@
     "CUDA not available or compute capability < 8.9", allow_module_level=True
 )
 
-from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
+from torchao.prototype.moe_training.conversion_utils import (
+    MoEScalingType,
+    MoETrainingConfig,
+)
 from torchao.quantization.quant_api import quantize_
 
-# this test requires torchtitan
+# this benchmark requires torchtitan
 try:
-    from torchtitan.experiments.llama4.infra.expert_parallel import (
+    from torchtitan.distributed.expert_parallel import (
         set_token_group_alignment_size_m,
     )
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
@@ -54,16 +56,15 @@ def bench_moe_float8_training_fsdp(enable_profile=False):
 
     # define model args
     target_fqns = ["experts"]
-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=16,
-        dim=5120,
     )
     init_std = 0.02
     device = torch.device("cuda")
 
     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    dim, hidden_dim = 5120, 4 * 5120
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(42)
     ref_model.init_weights(init_std, device)
 
@@ -82,20 +83,27 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False
 
     # quantize test model
-    config = MoETrainingConfig()
+    config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE)
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
     # FSDP2
     fully_shard(model)
     fully_shard(ref_model)
 
     # inputs (llama4 shapes)
-    batch, seq, dim = 1, 8192, 5120
+    batch, seq = 1, 8192
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
     x = ref_x.detach().clone().requires_grad_(True)
 
+    def warmup(model, input):
+        for _ in range(3):
+            out = model(input)
+            loss = F.mse_loss(out, torch.ones_like(out))
+            loss.backward()
+        torch.cuda.synchronize()
+
     def bench_fn_microseconds(model, input):
         labels = torch.ones_like(input)
         times = []
@@ -142,6 +150,7 @@ def profile_fn(model, input, profile_name="profile"):
     model = torch.compile(model, fullgraph=False)
 
     print("Benchmarking MoE with FSDP2 using bf16 training")
+    warmup(ref_model, ref_x)
     bf16_us = bench_fn_microseconds(ref_model, ref_x)
     print(f"bf16 time: {bf16_us} us")
     if enable_profile:
@@ -152,6 +161,7 @@ def profile_fn(model, input, profile_name="profile"):
     set_token_group_alignment_size_m(16)
 
     print("Benchmarking MoE with FSDP2 using fp8 rowwise training")
+    warmup(model, x)
     fp8_us = bench_fn_microseconds(model, x)
    print(f"fp8 time: {fp8_us} us")
    if enable_profile:
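Note: set_token_group_alignment_size_m(16) matters here because each token group feeds a per-expert fp8 GEMM whose M dimension must be a multiple of 16. A sketch of the rounding this implies (round_up is a hypothetical helper for illustration, not torchtitan API; torchtitan configures the alignment globally):

def round_up(x: int, alignment: int = 16) -> int:
    # Round a token-group size up to the required alignment.
    return ((x + alignment - 1) // alignment) * alignment

assert round_up(100) == 112  # a 100-token group pads to 112 for fp8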
(file name not captured in this view)

@@ -15,8 +15,8 @@
 from triton.testing import do_bench
 
 from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
-    triton_fp8_col_major_jagged_colwise_scales,
-    triton_fp8_row_major_jagged_rowwise_scales,
+    triton_fp8_per_group_colwise_scales,
+    triton_fp8_per_group_rowwise_scales,
 )
 from torchao.prototype.moe_training.utils import (
     torch_to_float8_per_group_colwise,
@@ -114,13 +114,13 @@ def run_torch(
 def run_triton(
     input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor
 ):
-    _ = triton_fp8_row_major_jagged_rowwise_scales(
+    _ = triton_fp8_per_group_rowwise_scales(
         input_row_major,
         offs,
         output_dtype=torch.float8_e4m3fn,
         round_scales_to_power_of_2=True,
     )
-    _ = triton_fp8_col_major_jagged_colwise_scales(
+    _ = triton_fp8_per_group_colwise_scales(
         input_col_major,
         offs,
         output_dtype=torch.float8_e4m3fn,
(file name not captured in this view)

@@ -46,8 +46,11 @@ class Experiment:
 
 
 def get_configs() -> List[ExperimentConfig]:
-    # Llama4 and DeepSeekV3 shapes
-    input_shapes = [(8, 4096, 1024), (16, 5120 * 4, 5120)]
+    # Llama4 shapes
+    input_shapes = [
+        (16, 8192, 5120),  # w1, w3
+        (16, 5120, 8192),  # w2
+    ]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
     for input_shape, high_precision_dtype in itertools.product(
@@ -117,6 +120,7 @@ def print_results(experiments: List[Experiment]):
         "input_shape",
         "torch_time_us",
         "triton_time_us",
+        "triton_speedup",
     ]
     rows = []
     for experiment in experiments:
@@ -126,6 +130,7 @@
             input_shape,
             experiment.result.torch_time_us,
             experiment.result.triton_time_us,
+            f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
         ]
     )
     print(tabulate(rows, headers=headers))
test/prototype/moe_training/test_fsdp.py (33 changes: 21 additions & 12 deletions)

@@ -35,8 +35,10 @@
 
 # this test requires torchtitan
 try:
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.distributed.expert_parallel import (
+        set_token_group_alignment_size_m,
+    )
+    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
@@ -49,18 +51,20 @@ def test_moe_float8_training_fsdp():
     # setup distributed for fsdp
     setup_distributed()
 
+    # token group alignment size must be 16 for fp8
+    set_token_group_alignment_size_m(16)
+
     # define model args
     target_fqns = ["experts"]
-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=8,
-        dim=256,
     )
     init_std = 0.02
     device = torch.device("cuda")
 
     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    dim, hidden_dim = 5120, 4 * 5120
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(42)
     ref_model.init_weights(init_std, device)
 
@@ -93,7 +97,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     fully_shard(ref_model)
 
     # inputs
-    batch, seq, dim = 8, 2048, 256
+    batch, seq = 8, 2048
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
@@ -105,7 +109,10 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate output
     out_sqnr = compute_error(out, ref_out)
-    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."
+    min_out_sqnr = 29.0
+    assert out_sqnr.item() >= min_out_sqnr, (
+        f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
+    )
 
     # compute loss
     labels = torch.ones_like(ref_out)
@@ -118,15 +125,17 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate input gradient
     input_grad_sqnr = compute_error(x.grad, ref_x.grad)
-    assert input_grad_sqnr.item() >= 30.0, (
-        f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}."
+    min_input_grad_sqnr = 29.0
+    assert input_grad_sqnr.item() >= min_input_grad_sqnr, (
+        f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}."
     )
 
     # validate param gradients
+    min_param_grad_sqnr = 23.0
     for param1, param2 in zip(model.parameters(), ref_model.parameters()):
         param_grad_sqnr = compute_error(param1.grad, param2.grad)
-        assert param_grad_sqnr.item() >= 25.0, (
-            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
+        assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
+            f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}."
        )
 
    dist.destroy_process_group()
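Note: the loosened thresholds (29.0 and 23.0 dB) are still SQNR checks. Assuming compute_error follows the usual torchao definition, the metric is equivalent to this sketch:

import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # SQNR in decibels: 20 * log10(||ref|| / ||ref - test||); higher is better.
    # 29 dB corresponds to a relative error of roughly 3.5%.
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - test))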