
Commit 462dc71

coconutruben authored and markc-614 committed
[inductor][ez] fixup scaled_mm (pytorch#159948)
Summary: This reverts the part of pytorch#159383 for scaled_mm so that, like before, we pass the normal input_nodes (not the triton_input_nodes) through to select_algorithm.

- pytorch#159383 refactored how kwargs are retrieved
- it introduced the notion of KernelInputs, which wrap input_nodes
- scaled_mm uses unsqueezed input nodes for Triton to retrieve params
- the issue: select_algorithm expects the squeezed (regular) bias instead

This fixes that by passing the original input nodes rather than the Triton input nodes to select_algorithm.

Test Plan:
```
buck test '@fbcode//mode/opt' fbcode//caffe2/test/inductor:fp8 -- --exact 'caffe2/test/inductor:fp8 - test_rowwise_scaling_shape_1024,1024,512_has_bias_True_use_fast_accum_True_persistent_matmul_False (caffe2.test.inductor.test_fp8.TestFP8Lowering)'
buck test '@fbcode//mode/opt' fbcode//caffe2/test/inductor:fp8 -- --exact 'caffe2/test/inductor:fp8 - test_rowwise_scaling_shape_1024,1024,512_has_bias_True_use_fast_accum_True_persistent_matmul_True (caffe2.test.inductor.test_fp8.TestFP8Lowering)'
```

This set of tests was failing and is passing now.

Side note: I believe these tests were failing because the unsqueezed bias made the ATEN choice no longer eligible, and there is a minor numerical discrepancy between ATEN and Triton here. I'm not sure the test should be written like that, as we're implicitly relying on ATEN being the choice.

Differential Revision: [D79717654](https://our.internmc.facebook.com/intern/diff/D79717654)

Pull Request resolved: pytorch#159948

Approved by: https://github.com/izaitsevfb, https://github.com/eellison
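The side note above can be sketched in plain Python. The helper name and shapes below are hypothetical (the real eligibility logic lives in Inductor's choice generation, not in a function by this name); the sketch only illustrates why handing select_algorithm the unsqueezed Triton bias can silently drop the ATEN choice.

```python
# Hypothetical sketch of the eligibility issue described in the side note.
# Assumption: the ATEN scaled_mm choice accepts a 1D bias of shape (n,),
# while the Triton templates work with an unsqueezed 2D view of shape (1, n).

def aten_bias_eligible(bias_shape):
    """Hypothetical stand-in for the ATEN choice's bias-shape check."""
    return len(bias_shape) == 1

bias_shape = (512,)           # original (squeezed) bias, as in input_nodes
triton_bias_shape = (1, 512)  # unsqueezed variant used for Triton kwargs

# With the original bias the ATEN choice stays in the autotune choice set;
# with the unsqueezed bias it would be filtered out, leaving only Triton.
assert aten_bias_eligible(bias_shape)
assert not aten_bias_eligible(triton_bias_shape)
```

This is why the fix passes the original input nodes to select_algorithm while keeping the unsqueezed nodes for Triton parameter retrieval only.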
1 parent 77f7018 commit 462dc71

File tree

1 file changed: +1 −3 lines
  • torch/_inductor/kernel

torch/_inductor/kernel/mm.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -1259,9 +1259,7 @@ def tuned_scaled_mm(
     if is_nonzero and use_ck_gemm_template(layout, m, n, k):
         CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes())

-    return autotune_select_algorithm(
-        "scaled_mm", choices, kernel_inputs.nodes(), layout
-    )
+    return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout)


 @functools.cache
```
