[mxfp8 moe training] fix mxfp8 a2a bench script; set mxfp8 a2a scaling type to RCEIL #3114
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3114. Note: links to docs will display an error until the docs builds have completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Investigating why the quant/dequant kernels are so slow, I verified that our cast_bench.py script is hitting 5.6 TB/s+ for the dim0 cast torch.compile kernel, then ran the inductor codegen perf script. It seems inductor is (1) generating 2 kernels for [...]. (Note: I commented out the all-to-all calls and everything except the initial quantization kernels in the inductor codegen, so that I could easily bench the quantization kernels it generated on a single device: https://www.internalfb.com/phabricator/paste/view/P1973241868)

@eellison @bdhirsh would you mind helping take a look at why the inductor codegen [...]?

Tlparse Repro
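As a sanity check on the 5.6 TB/s+ figure above: achieved bandwidth for a memory-bound cast kernel is just bytes moved over kernel time. A minimal sketch of that arithmetic, with hypothetical tensor sizes and kernel time (not numbers from the actual benchmark):

```python
def achieved_bandwidth_tbps(bytes_read: int, bytes_written: int, seconds: float) -> float:
    """Effective memory bandwidth in TB/s for a memory-bound kernel."""
    return (bytes_read + bytes_written) / seconds / 1e12

# Hypothetical dim0 cast: read a (16384 x 16384) bf16 tensor (2 bytes/elem),
# write fp8 data (1 byte/elem) plus one e8m0 scale byte per 32-element block.
n = 16384 * 16384
read_bytes = n * 2
write_bytes = n * 1 + n // 32
t = 150e-6  # hypothetical kernel time: 150 microseconds
print(f"{achieved_bandwidth_tbps(read_bytes, write_bytes, t):.2f} TB/s")
```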
Force-pushed 5fad182 to b74fb31 (Compare)
```python
    input,
    elem_dtype=torch.float8_e4m3fn,
    block_size=block_size,
    scaling_mode=ScaleCalculationMode.RCEIL,
```
nit: put scaling mode in the docblock
Force-pushed b74fb31 to acee0cd (Compare)
Summary

- `to_mxfp8_a2a_dequant` change: set scaling type to RCEIL for extra speedup with the hardware-accelerated fp32->e8m0 casting instruction

New results:
default/bf16 fwd+bwd trace:
mxfp8 fwd + bwd trace:
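For reference, the RCEIL scale selection the summary refers to can be sketched in plain Python. This is an illustrative model of e8m0 scale computation (the function names and the floor-based baseline are for comparison only, not torchao's actual implementation); it shows why rounding the scale's exponent up both avoids overflowing the fp8 range and maps to a single round-up fp32->e8m0 cast in hardware:

```python
import math

E4M3_MAX = 448.0  # largest finite value of torch.float8_e4m3fn

def scale_exp_floor(amax: float) -> int:
    # Floor-style scale: exponent = floor(log2(amax)) - floor(log2(E4M3_MAX)).
    # Can under-scale: the scaled amax may still exceed E4M3_MAX.
    return math.floor(math.log2(amax)) - math.floor(math.log2(E4M3_MAX))

def scale_exp_rceil(amax: float) -> int:
    # RCEIL: round amax / E4M3_MAX up to the next power of two.
    # Guarantees amax / 2**e <= E4M3_MAX, and corresponds to a
    # round-up fp32 -> e8m0 cast, which is a single hardware instruction.
    return math.ceil(math.log2(amax / E4M3_MAX))

amax = 1000.0
e_floor, e_rceil = scale_exp_floor(amax), scale_exp_rceil(amax)
print(e_floor, amax / 2**e_floor)   # floor scale leaves 500.0 > 448 (clips)
print(e_rceil, amax / 2**e_rceil)   # RCEIL scale gives 250.0 <= 448
```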