
Commit a73f248

oniononion36 authored and markc-614 committed
[AMD] Fix AMD User Defined Kernel Autotune (pytorch#160671)
Summary: AMD-specific kwargs need to be removed from the guard; otherwise a KeyError is raised when executing the kernel.

Test Plan:
```
buck2 run mode/opt-amd-gpu -m rocm641 -c fbcode.split-dwarf=true -c fbcode.use_link_groups=true -c fbcode.enable_gpu_sections=true //hpc/new/models/feed/benchmark:feed_lower_benchmark -- --load=manifold://ads_storage_fblearner/tree/user/facebook/fblearner/predictor/894698382/0/gpu_lowering/new_input8 --skip-eager --skip-flop-estimation --sync-mode=0 --lower-backend=AOT_INDUCTOR
```
The benchmark succeeds after this change.

Differential Revision: D80285441

Pull Request resolved: pytorch#160671
Approved by: https://github.com/muchulee8
1 parent 10f07af commit a73f248
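
To make the failure mode in the summary concrete, here is a minimal, self-contained sketch of how a guard built from all of a ROCm config's kwargs breaks when the runtime `meta` dict only carries the block-size keys. The `meta` contents, the `AMD_ONLY` set, and the `eval` harness are illustrative assumptions for this note, not code from the commit.

```python
# Illustrative repro of the bug: the emitted guard references every config
# kwarg via meta['...'], but the AMD-only launch hints never appear in meta,
# so evaluating the unfiltered guard raises KeyError.
rocm_kwargs = {
    "BLOCK_SIZE_M": 16,
    "BLOCK_SIZE_N": 16,
    "matrix_instr_nonkdim": 16,  # AMD-only launch hints
    "waves_per_eu": 3,
    "kpack": 2,
}
meta = {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16}  # what the guard sees (assumed)

unfiltered = " and ".join(f"meta['{k}'] == {v}" for k, v in rocm_kwargs.items())
try:
    eval(unfiltered)
except KeyError as exc:
    print("unfiltered guard fails:", exc)  # KeyError: 'matrix_instr_nonkdim'

# The fix in spirit: drop the AMD-only keys before building the guard.
AMD_ONLY = {"matrix_instr_nonkdim", "waves_per_eu", "kpack"}
filtered = " and ".join(
    f"meta['{k}'] == {v}" for k, v in rocm_kwargs.items() if k not in AMD_ONLY
)
print(filtered, "->", eval(filtered))  # -> True
```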

File tree

3 files changed: +103 −26 lines

  test/inductor/test_aot_inductor.py
  torch/_inductor/codegen/wrapper.py
  torch/testing/_internal/triton_utils.py

test/inductor/test_aot_inductor.py

Lines changed: 38 additions & 0 deletions
@@ -60,6 +60,7 @@
     MACOS_VERSION,
     MI300_ARCH,
     parametrize,
+    runOnRocm,
     skipIfMPS,
     skipIfRocm,
     skipIfRocmArch,
@@ -6416,6 +6417,43 @@ def forward(self, x):
             rtol=1e-3,
         )

+    @runOnRocm
+    def test_rocm_triton_autotuning(self):
+        if self.device != GPU_TYPE:
+            raise unittest.SkipTest("requires GPU")
+
+        class Model(torch.nn.Module):
+            def forward(self, x, y, m):
+                _M, K = x.shape
+                K, N = y.shape
+                M = torch.abs(m)
+                out = torch.empty((_M, N), device=x.device, dtype=torch.float32)
+                grid = lambda META: (  # noqa: E731
+                    triton.cdiv(
+                        4096 * 2046, META["BLOCK_SIZE_M"] * META["BLOCK_SIZE_N"]
+                    ),
+                )
+                strange_config_matmul_kernel[grid](
+                    x,
+                    y,
+                    out,
+                    M,
+                    N,
+                    K,
+                )
+                return out
+
+        x = torch.randn(4096, 1024, device=self.device)
+        y = torch.randn(1024, 2048, device=self.device)
+        m = torch.tensor([4096], dtype=torch.int32, device=self.device)
+
+        with config.patch("triton.autotune_with_sample_inputs", True):
+            # The tuned best config on XPU is different with CUDA.
+            grid_0 = 32736 if GPU_TYPE == "xpu" else 1023
+            self.code_check_count(
+                Model(), (x, y, m), f"uint32_t grid_0 = {grid_0}L;", 1
+            )
+
     @skipIfRocm  # RoCM does not support the config block size in test suite.
     def test_triton_autotuning(self):
         if self.device != GPU_TYPE:
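
As a sanity check on the grid_0 values asserted in the new test, the grid lambda's ceiling division can be reproduced by hand. The block-size pairs below come from _get_strange_configs() in the triton_utils.py change further down; that 128×64 is the winning config on CUDA/ROCm and 16×16 on XPU is an inference from the expected values, not something the diff states.

```python
# Reproduce the expected grid_0 values from the test's grid lambda.
def cdiv(a: int, b: int) -> int:
    # Same ceiling division that triton.cdiv performs.
    return -(-a // b)

work = 4096 * 2046  # fixed numerator used by the grid lambda above

print(cdiv(work, 128 * 64))  # 1023  -> grid_0 expected on CUDA/ROCm
print(cdiv(work, 16 * 16))   # 32736 -> grid_0 expected on XPU
```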

torch/_inductor/codegen/wrapper.py

Lines changed: 11 additions & 4 deletions
@@ -228,11 +228,18 @@ def writeline(line: str, example_grid: Optional[str] = None):
             key=lambda x: len(x[1].kwargs),
             reverse=True,
         ):
+            guardslist = []
             if c.kwargs:
-                guards = [
-                    f"meta['{name}'] == {val}" for name, val in c.kwargs.items()
-                ]
-                guards = " and ".join(guards)
+                # Remove AMD specific kwargs.
+                for kwarg in c.kwargs:
+                    if kwarg not in [
+                        "matrix_instr_nonkdim",
+                        "waves_per_eu",
+                        "kpack",
+                    ]:
+                        guardslist.append(f"meta['{kwarg}'] == {c.kwargs[kwarg]}")
+            if guardslist:
+                guards = " and ".join(guardslist)
             else:
                 guards = "True"  # for configs with empty kwargs
             grid, example_grid = determine_grid(grid, example_grid)
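
For readers who want to poke at the new guard logic outside Inductor, the following standalone sketch mirrors the filtering added above. The helper name `build_guard` and the sample config are made up for illustration; the real logic lives inline in torch/_inductor/codegen/wrapper.py.

```python
# Standalone sketch of the guard construction with AMD-only kwargs filtered out.
AMD_SPECIFIC_KWARGS = ("matrix_instr_nonkdim", "waves_per_eu", "kpack")

def build_guard(kwargs: dict) -> str:
    guardslist = [
        f"meta['{name}'] == {val}"
        for name, val in kwargs.items()
        if name not in AMD_SPECIFIC_KWARGS
    ]
    # Configs whose kwargs are empty (or all AMD-specific) fall back to "True",
    # matching the existing empty-kwargs branch in wrapper.py.
    return " and ".join(guardslist) if guardslist else "True"

# Example with a ROCm config from _get_strange_configs():
print(build_guard({
    "BLOCK_SIZE_M": 128,
    "BLOCK_SIZE_N": 64,
    "matrix_instr_nonkdim": 16,
    "waves_per_eu": 3,
    "kpack": 2,
}))
# -> meta['BLOCK_SIZE_M'] == 128 and meta['BLOCK_SIZE_N'] == 64
```

Note the dedented `if guardslist:` in the diff: a config whose kwargs are all AMD-specific now falls through to the existing `guards = "True"` branch rather than emitting an empty guard.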

torch/testing/_internal/triton_utils.py

Lines changed: 54 additions & 22 deletions
@@ -15,6 +15,59 @@
 import triton
 from triton import language as tl

+import torch
+
+def _get_strange_configs() -> list[triton.Config]:
+    if torch.version.hip:
+        configs = [
+            triton.Config(
+                {
+                    "BLOCK_SIZE_M": 16,
+                    "BLOCK_SIZE_N": 16,
+                    "matrix_instr_nonkdim": 16,
+                    "waves_per_eu": 3,
+                    "kpack": 2,
+                },
+                num_stages=4,
+                num_warps=4,
+            ),
+            triton.Config(
+                {
+                    "BLOCK_SIZE_M": 128,
+                    "BLOCK_SIZE_N": 64,
+                    "matrix_instr_nonkdim": 16,
+                    "waves_per_eu": 3,
+                    "kpack": 2,
+                },
+                num_stages=4,
+                num_warps=4,
+            ),
+        ]
+    else:
+        configs = [
+            triton.Config(
+                {
+                    "BLOCK_SIZE_M": 16,
+                    "BLOCK_SIZE_N": 16,
+                    "BLOCK_SIZE_K": 16,
+                    "GROUP_SIZE_M": 4,
+                },
+                num_stages=4,
+                num_warps=4,
+            ),
+            triton.Config(
+                {
+                    "BLOCK_SIZE_M": 128,
+                    "BLOCK_SIZE_N": 64,
+                    "BLOCK_SIZE_K": 32,
+                    "GROUP_SIZE_M": 8,
+                },
+                num_stages=4,
+                num_warps=4,
+            ),
+        ]
+    return configs
+
 # Define here so that multiple tests can take advantage of it
 @triton.jit
 def add_kernel(
@@ -786,28 +839,7 @@ def add_kernel_out_of_order_fn2(
     tl.store(out_ptr + offsets, output, mask=mask)

 @triton.autotune(
-    configs=[
-        triton.Config(
-            {
-                "BLOCK_SIZE_M": 16,
-                "BLOCK_SIZE_N": 16,
-                "BLOCK_SIZE_K": 16,
-                "GROUP_SIZE_M": 4,
-            },
-            num_stages=4,
-            num_warps=4,
-        ),
-        triton.Config(
-            {
-                "BLOCK_SIZE_M": 128,
-                "BLOCK_SIZE_N": 64,
-                "BLOCK_SIZE_K": 32,
-                "GROUP_SIZE_M": 8,
-            },
-            num_stages=4,
-            num_warps=4,
-        ),
-    ],
+    configs=_get_strange_configs(),
     key=["M_ptr", "N", "K"],
 )
 @triton.jit
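
As a quick, hedged way to see the two config families side by side, the snippet below imports the new helper and checks that the AMD-only launch hints appear exactly when running a ROCm build. It assumes a working Triton install plus this commit's version of torch.testing._internal.triton_utils; AMD_ONLY_KEYS is an illustrative name, not something from the diff.

```python
# Inspect the configs returned by the new helper and confirm that AMD-only
# launch hints show up only when torch.version.hip is set (ROCm build).
import torch
from torch.testing._internal.triton_utils import _get_strange_configs

AMD_ONLY_KEYS = {"matrix_instr_nonkdim", "waves_per_eu", "kpack"}

for cfg in _get_strange_configs():
    amd_hints = AMD_ONLY_KEYS & cfg.kwargs.keys()
    # ROCm configs carry the hints; CUDA/XPU configs carry none.
    assert bool(amd_hints) == bool(torch.version.hip)
    print(cfg.kwargs, "num_warps:", cfg.num_warps, "num_stages:", cfg.num_stages)
```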
