
Commit ccf6426

PaulZhang12 authored and timocafe committed
[Inductor] Fix consolidating _scaled_mm into mm template TMA error (pytorch#150686)

Summary: The previous diff broke a few tests that didn't run on internal or GitHub CI (T220169086); this fixes that issue. The {% if %} block is only supposed to support autotuned parameters (constexpr) and, judging from other examples, should not be used for locals.

Test Plan:
buck test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:fp8 -- --exact 'caffe2/test/inductor:fp8 - test_tensorwise_scaling_bfloat16_shape_16,32,32_has_bias_False_use_fast_accum_True_persistent_matmul_True (caffe2.test.inductor.test_fp8.TestFP8Lowering)'

Reviewed By: NikhilAPatel

Differential Revision: D72460516

Pull Request resolved: pytorch#150686
Approved by: https://github.com/eellison, https://github.com/NikhilAPatel
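
For context, a minimal sketch of why that distinction matters, using plain jinja2 as a stand-in for Inductor's Jinja-style templating (the EVEN_K and k_mask names here are hypothetical): {% if %} is evaluated once, when the template is rendered into Triton source, so it can only branch on render-time values such as autotuned constexpr parameters. A loop index like ki only exists while the generated kernel runs, so that branch has to survive rendering as a device-side if.

    from jinja2 import Template

    # Render-time branch: fine for an autotuned constexpr like EVEN_K,
    # which is already known when the kernel source is generated.
    render_time = Template(
        "{% if EVEN_K %}"
        "a = tl.load(a_ptr)"
        "{% else %}"
        "a = tl.load(a_ptr, mask=k_mask)"
        "{% endif %}"
    )
    print(render_time.render(EVEN_K=True))  # -> a = tl.load(a_ptr)

    # Runtime branch: ki is a loop variable inside the kernel, so the `if`
    # must be emitted verbatim into the generated source, not resolved here.
    runtime = Template("if ki == k_tiles - 1:\n    tl.store(c_ptr, acc)")
    print(runtime.render())  # the `if` survives into the kernel source
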
1 parent a7d5e86 commit ccf6426

File tree (1 file changed: +35 -36 lines)

  • torch/_inductor/kernel/mm.py


torch/_inductor/kernel/mm.py

Lines changed: 35 additions & 36 deletions
@@ -312,18 +312,18 @@
             allow_tf32=ALLOW_TF32,
         )

-        {% if ki == k_tiles - 1 %}
-        # rematerialize rm and rn to save registers
-        rcm = rm + tl.arange(0, BLOCK_M)
-        rcn = rn + tl.arange(0, BLOCK_N)
-        idx_m = rcm[:, None]
-        idx_n = rcn[None, :]
-        mask = (idx_m < M) & (idx_n < N)
-
-        # inductor generates a suffix
-        {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}}
-        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
-        {% endif %}
+        if ki == k_tiles - 1:
+            # rematerialize rm and rn to save registers
+            rcm = rm + tl.arange(0, BLOCK_M)
+            rcn = rn + tl.arange(0, BLOCK_N)
+            idx_m = rcm[:, None]
+            idx_n = rcn[None, :]
+            mask = (idx_m < M) & (idx_n < N)
+
+            # inductor generates a suffix
+            {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}}
+            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+
 """,
 )
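
A pure-Python sketch of the persistent-loop shape this hunk lives in (tile counts are hypothetical; this is not the real template): ki cycles through the k-tiles at runtime, so the epilogue (store plus accumulator reset) can only be guarded by a device-side if, never by a render-time {% if %}.

    # Hypothetical tile counts; the epilogue fires on the last k-tile of
    # each output tile, which is only known while the kernel is running.
    k_tiles, tiles_per_SM = 4, 2
    ki, acc = -1, 0
    for _ in range(k_tiles * tiles_per_SM):
        ki = (ki + 1) % k_tiles
        acc += 1  # stand-in for the tl.dot accumulation
        if ki == k_tiles - 1:  # runtime condition, not a template-time one
            print("store acc =", acc)
            acc = 0  # reset the accumulator for the next output tile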

@@ -467,31 +467,30 @@ def apply_scaling(
         else:
             accumulator += tl.dot(a, b.T)

-        {% if ki == k_tiles - 1 %}
-        # Apply inverse scaling
-        offs_cm = offs_am + tl.arange(0, BLOCK_M)
-        offs_cn = offs_bn + tl.arange(0, BLOCK_N)
-        # Apply scaling
-        accumulator = apply_scaling(
-            accumulator,
-            a_scale,
-            b_scale,
-            SCALING_ROWWISE,
-            offs_cm,
-            offs_cn,
-            M,
-            N,
-            stride_a_scale_m,
-            stride_b_scale_n,
-        )
+        if ki == k_tiles - 1:
+            # Apply inverse scaling
+            offs_cm = offs_am + tl.arange(0, BLOCK_M)
+            offs_cn = offs_bn + tl.arange(0, BLOCK_N)
+            # Apply scaling
+            accumulator = apply_scaling(
+                accumulator,
+                a_scale,
+                b_scale,
+                SCALING_ROWWISE,
+                offs_cm,
+                offs_cn,
+                M,
+                N,
+                stride_a_scale_m,
+                stride_b_scale_n,
+            )

-        idx_m = offs_cm[:, None]
-        idx_n = offs_cn[None, :]
-        mask = (idx_m < M) & (idx_n < N)
-        # inductor generates a suffix
-        {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}}
-        accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-        {% endif %}
+            idx_m = offs_cm[:, None]
+            idx_n = offs_cn[None, :]
+            mask = (idx_m < M) & (idx_n < N)
+            # inductor generates a suffix
+            {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}}
+            accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
 """
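
To make the apply_scaling call above concrete, a hedged PyTorch sketch of the two scaling modes it selects between (hypothetical shapes and scale values; not the template's actual implementation):

    import torch

    acc = torch.randn(16, 32)  # fp32 accumulator tile (BLOCK_M x BLOCK_N)

    # SCALING_ROWWISE == True: per-row scales for A and per-column scales
    # for B, broadcast across the accumulator tile.
    a_scale = torch.rand(16, 1)
    b_scale = torch.rand(1, 32)
    rowwise = acc * a_scale * b_scale

    # SCALING_ROWWISE == False: a single tensorwise scale per operand.
    tensorwise = acc * 0.5 * 2.0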
