[Feature Request] Fused fp8 matmul kernel (quant + dequant + matmul) #752

@qingquansong

Hey team, AO provides awesome FP8 support that, combined with torch.compile, gives speed and memory improvements. However, torch.compile is not always easy to apply to some models, such as the HF MoE implementations, which need changes like the ones made in the GPT-fast code. (Those changes also cause memory issues in the prefill stage, since that style of implementation is better suited to fast decoding.)

Currently, AO does not directly offer a fusion of the quant and dequant steps into the matmul (scaled_mm), so we have to rely on torch.compile to fuse them and get the speedup. Otherwise, performance is hurt by the separate quant/dequant stages (see #685, #704 and here).
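To make the three stages concrete, here is a minimal framework-free sketch of the arithmetic in the unfused path: quantize both operands, multiply, then dequantize. It is only an illustration, not AO's implementation. The names `round_to_e4m3`, `quantize`, `scaled_matmul` and `fp8_linear_unfused` are hypothetical, per-tensor scaling is an assumption, and fp8 is emulated in plain Python floats so the snippet runs without a GPU.

```python
import math

E4M3_MAX = 448.0  # largest finite value of the float8 e4m3fn format


def round_to_e4m3(x: float) -> float:
    """Emulate rounding a float to the e4m3 grid, saturating at +-448."""
    if x == 0.0:
        return 0.0
    x = max(-E4M3_MAX, min(E4M3_MAX, x))
    # frexp returns x = m * 2**e with 0.5 <= |m| < 1, so floor(log2|x|) = e - 1.
    exp = max(math.frexp(x)[1] - 1, -6)  # -6 is the smallest normal exponent
    step = 2.0 ** (exp - 3)              # 3 mantissa bits -> 8 steps per binade
    return round(x / step) * step


def quantize(mat):
    """Stage 1 (quant): per-tensor scale so that amax maps to E4M3_MAX."""
    amax = max(abs(v) for row in mat for v in row)
    scale = E4M3_MAX / amax if amax > 0 else 1.0
    return [[round_to_e4m3(v * scale) for v in row] for row in mat], scale


def scaled_matmul(qa, sa, qb, sb):
    """Stages 2 and 3 (matmul + dequant): one multiply by 1 / (sa * sb)."""
    inv = 1.0 / (sa * sb)
    k, m = len(qb), len(qb[0])
    return [[inv * sum(row[p] * qb[p][j] for p in range(k)) for j in range(m)]
            for row in qa]


def fp8_linear_unfused(a, b):
    """Each call below is a separate pass over the data when nothing is fused."""
    qa, sa = quantize(a)
    qb, sb = quantize(b)
    return scaled_matmul(qa, sa, qb, sb)


if __name__ == "__main__":
    a = [[1.0, 2.0], [3.0, 4.0]]
    b = [[1.0, 0.0], [0.0, 1.0]]
    print(fp8_linear_unfused(a, b))  # close to a, up to fp8 rounding error
```

In this sketch the dequant step is a single scalar multiply on the output and the quant step is elementwise on the inputs, which is why both look like natural candidates for folding into the matmul kernel's prologue and epilogue instead of running as separate passes.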

Transformer Engine has a custom kernel for this, so I feel a small missing piece that could be added is a fused version of this kernel (either Triton, or torch ops with a CUDA implementation), so that we can get the speedup from AO fp8linear even without torch.compile. This would benefit both training and inference. Thank you!
