Cherry-pick for Inductor Autotune refactor #2392
base: release/2.7
Conversation
…pytorch#147452)

This change was reverted in pytorch#147388 for regressing an internal workload. I have removed the additional ir.device_type calls in mm_scaled and unpack_mixed_mm.py, which could be contributing to the additional compile time.

Pull Request resolved: pytorch#147452
Approved by: https://github.com/jansel

(cherry picked from commit 32299e5)
…n mm and addmm (pytorch#150587)

Summary: This PR introduces additional autotuning configurations for the persistent+TMA version of the Triton `mm` and `addmm` operations. The new configurations are:

* `(128, 128, 64, 5, 8)`
* `(256, 128, 64, 4, 8)`
* `(128, 128, 64, 5, 4)`

These configurations were selected based on exhaustive autotuning performed on commonly used shapes from an internal foundational model. While the new configs are generally more performant across the board, we see notable gains in a few specific cases:

* In scenarios where `n >> m, k`, the configurations `(128, 128, 64, 5, 8)` and `(256, 128, 64, 4, 8)` tend to produce an additional 5-10% speedup over the aten baseline compared to the original configurations.
* Similarly, the configuration `(128, 128, 64, 5, 4)` yields approximately an 8% improvement in scenarios where `k >> m, n`.

These enhancements are expected to provide performance benefits across diverse use cases, particularly when compared to the original set of configurations.

Test Plan: contbuild & OSS CI

Reviewers: paulzhan

Pull Request resolved: pytorch#150587
Approved by: https://github.com/PaulZhang12, https://github.com/drisspg, https://github.com/eellison

(cherry picked from commit 5acc3e2)
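A minimal sketch of how such config tuples could be expressed as Triton autotune configs, assuming each tuple is `(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)`; the names and mapping here are illustrative assumptions, not the actual Inductor source:

```python
# Hypothetical sketch: map the new (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
# tuples from the commit message onto triton.Config objects. The exact format
# Inductor uses internally is an assumption for illustration only.
import triton

new_persistent_tma_configs = [
    (128, 128, 64, 5, 8),
    (256, 128, 64, 4, 8),
    (128, 128, 64, 5, 4),
]

triton_configs = [
    triton.Config(
        {"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk},
        num_stages=num_stages,
        num_warps=num_warps,
    )
    for (bm, bn, bk, num_stages, num_warps) in new_persistent_tma_configs
]
```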
This PR primarily unifies the flex attention config logic with the GEMM/Conv config approach from pytorch#147452, which will make it much easier to handle optimisation pathways for particular Triton backends.

This PR also:

1. Introduces an exhaustive tuning mode for flex attention via TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE="EXHAUSTIVE" to allow wide-scale benchmarking for perf investigation use cases.
2. Updates configs for the ROCm flex autotune path, providing perf optimisations.

AMD perf numbers on the score mod benchmark (default inputs):

flex_attn | mode | Speedup (Avg) | Speedup (Max)
-- | -- | -- | --
fwd | autotune before PR | 2.608 | 20.56
fwd | autotune after PR | 2.862 | 22
fwd | exhaustive_autotune | 2.943 | 22.471
bwd | autotune before PR | 2.196 | 9.831
bwd | autotune after PR | 2.423 | 11.331
bwd | exhaustive_autotune | 2.566 | 13.87

Pull Request resolved: pytorch#156307
Approved by: https://github.com/drisspg, https://github.com/jansel

(cherry picked from commit 03023f1)
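A minimal sketch of exercising the flex attention autotune path described above; the tensor shapes and score_mod function are illustrative assumptions, and only the environment variable name comes from the PR:

```python
# Select the exhaustive flex attention search space before torch/inductor is imported
# (assumption: the env var from the PR is read when Inductor's config is initialized).
import os
os.environ["TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE"] = "EXHAUSTIVE"

import torch
from torch.nn.attention.flex_attention import flex_attention


def relative_bias(score, b, h, q_idx, kv_idx):
    # Example score_mod: add a distance-based bias to each attention score.
    return score + (q_idx - kv_idx)


# Illustrative shapes: (batch, heads, seq_len, head_dim).
q, k, v = (
    torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
    for _ in range(3)
)

# max-autotune triggers the Inductor autotuning path that the new configs feed into.
compiled_flex = torch.compile(flex_attention, mode="max-autotune")
out = compiled_flex(q, k, v, score_mod=relative_bias)
```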
Jenkins build for 820aa47fd4b2cf232f11c42ee9f8ae84b9ef63af commit finished as FAILURE
Some model results after tuning the cherry pick, with TORCHINDUCTOR_FLEX_SEARCH_SPACE="EXHAUSTIVE", on the score_mod bench.
Jenkins build for 77b19c13455ae93e0ffd70bff7d33b261df438ac commit finished as FAILURE
Required for flex attention and gemm improvements