|
312 | 312 | allow_tf32=ALLOW_TF32,
|
313 | 313 | )
|
314 | 314 |
|
315 |
| - {% if ki == k_tiles - 1 %} |
316 |
| - # rematerialize rm and rn to save registers |
317 |
| - rcm = rm + tl.arange(0, BLOCK_M) |
318 |
| - rcn = rn + tl.arange(0, BLOCK_N) |
319 |
| - idx_m = rcm[:, None] |
320 |
| - idx_n = rcn[None, :] |
321 |
| - mask = (idx_m < M) & (idx_n < N) |
322 |
| -
|
323 |
| - # inductor generates a suffix |
324 |
| - {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}} |
325 |
| - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) |
326 |
| - {% endif %} |
| 315 | + if ki == k_tiles - 1: |
| 316 | + # rematerialize rm and rn to save registers |
| 317 | + rcm = rm + tl.arange(0, BLOCK_M) |
| 318 | + rcn = rn + tl.arange(0, BLOCK_N) |
| 319 | + idx_m = rcm[:, None] |
| 320 | + idx_n = rcn[None, :] |
| 321 | + mask = (idx_m < M) & (idx_n < N) |
| 322 | +
|
| 323 | + # inductor generates a suffix |
| 324 | + {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}} |
| 325 | + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) |
| 326 | +
|
327 | 327 | """,
|
328 | 328 | )
|
329 | 329 |
|
@@ -467,31 +467,30 @@ def apply_scaling(
|
467 | 467 | else:
|
468 | 468 | accumulator += tl.dot(a, b.T)
|
469 | 469 |
|
470 |
| - {% if ki == k_tiles - 1 %} |
471 |
| - # Apply inverse scaling |
472 |
| - offs_cm = offs_am + tl.arange(0, BLOCK_M) |
473 |
| - offs_cn = offs_bn + tl.arange(0, BLOCK_N) |
474 |
| - # Apply scaling |
475 |
| - accumulator = apply_scaling( |
476 |
| - accumulator, |
477 |
| - a_scale, |
478 |
| - b_scale, |
479 |
| - SCALING_ROWWISE, |
480 |
| - offs_cm, |
481 |
| - offs_cn, |
482 |
| - M, |
483 |
| - N, |
484 |
| - stride_a_scale_m, |
485 |
| - stride_b_scale_n, |
486 |
| - ) |
| 470 | + if ki == k_tiles - 1: |
| 471 | + # Apply inverse scaling |
| 472 | + offs_cm = offs_am + tl.arange(0, BLOCK_M) |
| 473 | + offs_cn = offs_bn + tl.arange(0, BLOCK_N) |
| 474 | + # Apply scaling |
| 475 | + accumulator = apply_scaling( |
| 476 | + accumulator, |
| 477 | + a_scale, |
| 478 | + b_scale, |
| 479 | + SCALING_ROWWISE, |
| 480 | + offs_cm, |
| 481 | + offs_cn, |
| 482 | + M, |
| 483 | + N, |
| 484 | + stride_a_scale_m, |
| 485 | + stride_b_scale_n, |
| 486 | + ) |
487 | 487 |
|
488 |
| - idx_m = offs_cm[:, None] |
489 |
| - idx_n = offs_cn[None, :] |
490 |
| - mask = (idx_m < M) & (idx_n < N) |
491 |
| - # inductor generates a suffix |
492 |
| - {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}} |
493 |
| - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) |
494 |
| - {% endif %} |
| 488 | + idx_m = offs_cm[:, None] |
| 489 | + idx_n = offs_cn[None, :] |
| 490 | + mask = (idx_m < M) & (idx_n < N) |
| 491 | + # inductor generates a suffix |
| 492 | + {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}} |
| 493 | + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) |
495 | 494 | """
|
496 | 495 |
|
497 | 496 |
|
|
0 commit comments