
Conversation

@danielvegamyhre (Contributor) commented Oct 2, 2025

Summary

  • Bench script fix: synchronize the device around the timed region to get accurate results; without it the benchmark seems to only measure dispatch time of the async a2a (see the sketch after this list).
  • to_mxfp8_a2a_dequant change: set the scale calculation mode to RCEIL for an extra speedup from the hardware-accelerated fp32 -> e8m0 casting instruction.
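
For context, here is a minimal sketch of the kind of synchronization the bench script fix adds (the helper name and structure are illustrative, not the actual script):

    import torch

    def bench_fwd_bwd_ms(fn, warmup: int = 5, iters: int = 20) -> float:
        # Hypothetical timing helper for CUDA workloads.
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn()
        end.record()
        # Without this sync (or the events above), the timer stops once the
        # async a2a has been dispatched, not when it has actually finished.
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters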

New results:

input_shape         num_splits    bf16_ms    mxfp8_ms
----------------  ------------  ---------  ----------
(16, 8192, 5120)             4    7.24235     7.63828

default/bf16 fwd+bwd trace:

(screenshot attached)

mxfp8 fwd + bwd trace:

(screenshot attached)

@pytorch-bot (bot) commented Oct 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3114

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@danielvegamyhre added the topic: not user facing label Oct 2, 2025
@meta-cla bot added the CLA Signed label Oct 2, 2025
@danielvegamyhre added the mx and moe labels and removed the CLA Signed label Oct 2, 2025
@meta-cla bot added the CLA Signed label Oct 2, 2025
@danielvegamyhre (Contributor, Author) commented Oct 2, 2025

Investigating why the quant/dequant kernels are so slow, I verified that our cast_bench.py script is hitting 5.6+ TB/s for the dim0 cast torch.compile kernel, then ran the inductor codegen perf script:

(torch) [[email protected] ~/ao (benchfix)]$ python /tmp/torchinductor_danvm/f7/cf7ddzzvhsvznkcpjzyno5vz7thqhm6xwkyl4bumbbobecmh2lg3.py
0.000703
Peak GPU memory usage 4056.474 MB

Seems like inductor is (1) generating 2 kernels for to_mx (I think it should only generate one normally?), and (2) getting lower mem bw utilization than the version we are benchmarking.

(Note: I commented out the all_to_all calls and everything except the initial quantization kernels in the inductor codegen, so that I could easily bench the quantization kernels it generated on a single device: https://www.internalfb.com/phabricator/paste/view/P1973241868)

@eellison @bdhirsh would you mind helping take a look at why the inductor codegen for the to_mx function is different/worse here than when it is compiled and benchmarked in isolation? Context: the autograd func I'm benchmarking does "mxfp8 quantization on inputs -> all_to_all_single -> dequantize outputs", and for some reason the quant/dequant kernels are slower than they are when to_mx(..) is compiled and benched in isolation.
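
To make the isolated comparison concrete, here is a rough sketch of how one might bench the compiled dim0 cast on a single device and estimate achieved bandwidth (the to_mx import path/signature and the byte accounting are assumptions from this PR's context, not copied from the repo):

    import torch
    from torchao.prototype.mx_formats.mx_tensor import to_mx  # assumed import path

    M, K, block_size = 16 * 8192, 5120, 32
    x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    cast = torch.compile(lambda t: to_mx(t, torch.float8_e4m3fn, block_size))

    for _ in range(5):  # warmup / compile
        cast(x)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(20):
        cast(x)
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end) / 20

    # Back-of-envelope bandwidth: bf16 read + fp8 write + one e8m0 scale per block.
    bytes_moved = x.numel() * 2 + x.numel() * 1 + (x.numel() // block_size) * 1
    print(f"{ms:.4f} ms  ~{bytes_moved / (ms * 1e-3) / 1e12:.2f} TB/s")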

Tlparse

https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpeMX52G/rank_0/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

Repro

  1. On a B200 devgpu, clone the latest torchao and check out this PR
  2. Run benchmark for mxfp8 dim0 quantization compiled in isolation:
    • TORCH_LOGS_OUT="log.txt" TORCH_LOGS="output_code" CUDA_VISIBLE_DEVICES=4 python benchmarks/mx_formats/cast_bench.py --mode dim0_mxfp8_floor (expect 5.6+ TB/s and a single fused kernel)
  3. Run the benchmark for this autograd func, which does "mxfp8 quantization -> a2a -> dequant", then view and bench the inductor codegen:
    • rm -rf /tmp/torchinductor_${USER}; CUDA_VISIBLE_DEVICES="4,5,6,7" TORCH_TRACE=/tmp/trace3 TORCH_LOGS="output_code" TORCH_LOGS_OUT="log.txt" torchrun --nproc-per-node=4 --local-ranks-filter=0 benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py

@danielvegamyhre danielvegamyhre changed the title [mxfp8 moe training] fix mxfp8 a2a bench script [mxfp8 moe training] fix mxfp8 a2a bench script; set mxfp8 a2a scaling type to RCeIL Oct 2, 2025
Inline review comment (from a Contributor) on:

    input,
    elem_dtype=torch.float8_e4m3fn,
    block_size=block_size,
    scaling_mode=ScaleCalculationMode.RCEIL,

nit: put scaling mode in the docblock
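
A hypothetical sketch of what addressing this nit could look like (the signature and wording are illustrative, not the actual torchao code):

    def to_mxfp8_a2a_dequant(input_tensor, *a2a_args, **a2a_kwargs):
        """Quantize inputs to mxfp8, run all_to_all_single, then dequantize outputs.

        Note: block scales are computed with ScaleCalculationMode.RCEIL so the
        fp32 -> e8m0 cast can use the hardware-accelerated rounding instruction.
        """
        ...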

@danielvegamyhre danielvegamyhre changed the title [mxfp8 moe training] fix mxfp8 a2a bench script; set mxfp8 a2a scaling type to RCeIL [mxfp8 moe training] fix mxfp8 a2a bench script; set mxfp8 a2a scaling type to RCEIL Oct 3, 2025
@danielvegamyhre danielvegamyhre merged commit cd21d0e into main Oct 3, 2025
12 of 17 checks passed