You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Summary:
This reverts the part of pytorch#159383 for scaled_mm where now, like before,
we pass through the normal input_nodes (not the triton_input_nodes)
to select_algorithm
- pytorch#159383 refactored how kwargs are retrieved
- it introduced this notion of KernelInputs that wrap input_nodes
- scaled_mm uses unsqueezed input nodes for triton to retrieve params
- the issue: it uses a squeezed (regular) bias for select_algorithm
instead
This fixes that by passing the original input nodes rather
than the triton input nodes.
Test Plan:
```
buck test '@fbcode//mode/opt' fbcode//caffe2/test/inductor:fp8 -- --exact 'caffe2/test/inductor:fp8 - test_rowwise_scaling_shape_1024,1024,512_has_bias_True_use_fast_accum_True_persistent_matmul_False (caffe2.test.inductor.test_fp8.TestFP8Lowering)'
buck test '@fbcode//mode/opt' fbcode//caffe2/test/inductor:fp8 -- --exact 'caffe2/test/inductor:fp8 - test_rowwise_scaling_shape_1024,1024,512_has_bias_True_use_fast_accum_True_persistent_matmul_True (caffe2.test.inductor.test_fp8.TestFP8Lowering)'
```
This set of tests was failing, and is passing now
Side note: these tests were failing I believe because the unsqueezed
bias made the ATEN choice no longer eligible, and there is some minor
numerical discrepancy between ATEN and Triton for this. I'm not sure
the test should be written like that, as we're implicitly relying on
ATEN being the choice here.
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: [D79717654](https://our.internmc.facebook.com/intern/diff/D79717654)
Pull Request resolved: pytorch#159948
Approved by: https://github.com/izaitsevfb, https://github.com/eellison
0 commit comments