Description
Hey team, AO provides awesome FP8 support with torch.compile, giving both speed and memory improvements. However, torch.compile is not always easy to apply to some models, such as the Hugging Face MoE implementation, which needs changes along the lines of the gpt-fast code (and those changes can also cause memory issues in the prefill stage, since that layout is better suited to fast decoding).
Currently, AO does not directly offer fusion of the quant/dequant steps into the matmul (scaled_mm), so one has to rely on torch.compile to fuse them and get the speedup; otherwise performance suffers from the separate quant/dequant stages (see #685, #704 and here).
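For reference, here is a minimal sketch of what the unfused eager path looks like (illustrative only, not torchao's actual code; torch._scaled_mm's signature has changed across PyTorch releases, so the keyword-argument form below assumes a recent 2.x build). Each quantization step is its own kernel launch and materializes an FP8 tensor before the GEMM runs, which is exactly what torch.compile is needed to fuse today:

```python
import torch

def fp8_linear_unfused(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """x: [M, K] bf16 activations, w: [N, K] bf16 weight (row-major)."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # 1) Per-tensor scales + quantization: separate kernels and extra
    #    memory traffic before the matmul ever starts.
    x_scale = (x.abs().max().clamp(min=1e-12) / fp8_max).float()
    w_scale = (w.abs().max().clamp(min=1e-12) / fp8_max).float()
    x_fp8 = (x / x_scale).to(torch.float8_e4m3fn)
    w_fp8 = (w / w_scale).to(torch.float8_e4m3fn)
    # 2) FP8 GEMM; mat2 must be column-major, which w_fp8.t() gives us here.
    return torch._scaled_mm(
        x_fp8,
        w_fp8.t(),
        scale_a=x_scale,
        scale_b=w_scale,
        out_dtype=torch.bfloat16,
    )
```

In eager mode the quantization above runs as standalone ops, so the FP8 GEMM speedup gets partly eaten by the extra launches and reads/writes unless torch.compile fuses them.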
Transformer Engine has custom kernels that do this, so I feel a small missing piece that could be added is a fused version of this kernel (either Triton or torch ops with a CUDA implementation), so we can get the speedup from AO's FP8 linear even without torch.compile, as in the sketch below. This would benefit both training and inference. Thank you!
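To make the request concrete, below is a rough Triton sketch (hypothetical and untested, not code from torchao or Transformer Engine) of the kind of kernel I mean: the activation tile is scaled and cast to FP8 inside the GEMM main loop, with dequantization in the epilogue, so there is no standalone quant/dequant kernel left to fuse away. It assumes a Hopper/Ada-class GPU, a Triton version whose tl.dot accepts float8e4nv inputs, and a weight that is already pre-quantized to FP8.

```python
import triton
import triton.language as tl


@triton.jit
def fused_quant_fp8_gemm_kernel(
    a_ptr, b_ptr, c_ptr,            # A: [M, K] bf16, B: [K, N] already fp8, C: [M, N] bf16
    scale_a_ptr, scale_b_ptr,       # per-tensor dequant scales, fp32 scalars
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    scale_a = tl.load(scale_a_ptr)
    scale_b = tl.load(scale_b_ptr)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        a_mask = (offs_m[:, None] < M) & ((k + offs_k)[None, :] < K)
        b_mask = ((k + offs_k)[:, None] < K) & (offs_n[None, :] < N)
        a = tl.load(
            a_ptr + offs_m[:, None] * stride_am + (k + offs_k)[None, :] * stride_ak,
            mask=a_mask, other=0.0,
        )
        b = tl.load(
            b_ptr + (k + offs_k)[:, None] * stride_bk + offs_n[None, :] * stride_bn,
            mask=b_mask, other=0.0,
        )
        # The "fusion": scale and cast the activation tile to FP8 right here,
        # inside the GEMM main loop, instead of launching a separate quant
        # kernel and materializing an FP8 activation tensor first.
        a_fp8 = (a / scale_a).to(tl.float16).to(tl.float8e4nv)
        acc += tl.dot(a_fp8, b)

    # Dequantize once in the epilogue and write out in bf16.
    acc = acc * scale_a * scale_b
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(
        c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn,
        acc.to(tl.bfloat16), mask=c_mask,
    )
```

Something along these lines (or an equivalent CUDA op), exposed as a regular torch op, would let Float8Linear-style layers get the fused speedup in plain eager mode.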