
Commit 10c3e6e

karthickai authored and pytorchmergebot committed
[inductor][dynamo] Include operator name in size/stride/alignment assertion (pytorch#152353)
Fixes pytorch#151930

This PR updates the `assert_size_stride` and `assert_alignment` functions in [guards.cpp](https://github.com/pytorch/pytorch/blob/main/torch/csrc/dynamo/guards.cpp) to accept an optional `op_name` argument and include it in the error messages. The corresponding type stubs in [guards.pyi](https://github.com/pytorch/pytorch/blob/main/torch/_C/_dynamo/guards.pyi) are updated to match the new function argument.

In [inductor/ir.py](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py), the operator name is extracted from the FX graph and passed into `codegen_size_asserts` and `codegen_alignment_asserts`, so the assertions emitted in generated Triton code include the op name for easier debugging.

Unit tests were added in [test_torchinductor.py](https://github.com/pytorch/pytorch/blob/main/test/inductor/test_torchinductor.py):
- Verified that both successful and failing assertion cases include the operator name.
- Verified that the generated Triton code contains the op name inside the asserts.

Pull Request resolved: pytorch#152353
Approved by: https://github.com/jansel, https://github.com/shunting314
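For illustration, a hedged sketch of how the generated wrapper assertions change. The buffer name, shape, stride, the 16-byte alignment value, and the op string below are placeholders; the real values depend on the compiled graph and on GPU_ALIGN_BYTES.

# before this change
assert_size_stride(buf0, (16, 32), (32, 1))
assert_alignment(buf0, 16)

# after this change: the operator name is appended as an optional trailing string
assert_size_stride(buf0, (16, 32), (32, 1), 'torch.ops.test.foo.default')
assert_alignment(buf0, 16, 'torch.ops.test.foo.default')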
1 parent cc96feb commit 10c3e6e

File tree

6 files changed: +189 -17 lines changed


test/distributed/test_functional_api.py

Lines changed: 7 additions & 0 deletions

@@ -715,6 +715,13 @@ def run_with_backward():
         _, codes = run_and_get_code(run_with_backward)
         for code in codes:
+            assert_keywords = ["assert_size_stride", "assert_alignment"]
+            filtered_lines = [
+                line
+                for line in code.splitlines()
+                if not any(assert_key in line for assert_key in assert_keywords)
+            ]
+            code = "\n".join(filtered_lines)
             FileCheck().check_count(
                 "_c10d_functional.all_to_all_single.default", 1, exactly=True
             ).check_count("_c10d_functional.wait_tensor.default", 1, exactly=True).run(

test/inductor/test_mkldnn_pattern_matcher.py

Lines changed: 8 additions & 0 deletions

@@ -231,6 +231,14 @@ def _test_code_common(
             torch.compile(mod, fullgraph=True, dynamic=check_dynamic),
             *clone_inputs,
         )
+        assert_keywords = ["assert_size_stride", "assert_alignment"]
+        filtered_lines = [
+            line
+            for line in source_code.splitlines()
+            if not any(assert_key in line for assert_key in assert_keywords)
+        ]
+        source_code = "\n".join(filtered_lines)
+
         for op in include_ops:
             self.assertIn(op, source_code)
         if num_include_ops is not None:

test/inductor/test_torchinductor.py

Lines changed: 105 additions & 5 deletions

@@ -30,6 +30,7 @@
 import torch._dynamo.config as dynamo_config
 import torch._inductor.aoti_eager
 import torch.nn as nn
+from torch._C._dynamo.guards import assert_alignment, assert_size_stride
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.debug_utils import aot_graph_input_parser
 from torch._dynamo.device_interface import get_interface_for_device

@@ -1409,7 +1410,14 @@ def fn(a, b):
         )
         _, code = run_and_get_code(fn, x, y)
         code = " ".join(code)
-        self.assertEqual(
+        assert_keywords = ["assert_size_stride", "assert_alignment"]
+        filtered_lines = [
+            line
+            for line in code.splitlines()
+            if not any(assert_key in line for assert_key in assert_keywords)
+        ]
+        code = "\n".join(filtered_lines)
+        self.assertGreaterEqual(
             code.count("view_dtype" if config.cpp_wrapper else "aten.view"), 3
         )

@@ -11923,6 +11931,98 @@ def fn(x):
             check_lowp=False,
         )

+    @requires_gpu()
+    @skip_if_not_triton
+    @skip_if_cpp_wrapper("skip cpp_wrapper tests")
+    @config.patch(implicit_fallbacks=True)
+    def test_generated_code_has_size_stride_assert(self):
+        def foo(x):
+            return 3 * x
+
+        def foo_meta(x):
+            return torch.empty_like(x)
+
+        define_custom_op_for_test("foo", foo, foo_meta)
+
+        def fn(x):
+            a = torch.nn.functional.relu(x)
+            b = torch.ops.test.foo(a)
+            return b
+
+        a = torch.randn((16, 32), device=self.device)
+
+        _, code = run_and_get_code(
+            torch.compile(fn),
+            a,
+        )
+        if not is_dynamic_shape_enabled():
+            if code and len(code) > 0 and "assert_size_stride(" in code[0]:
+                try:
+                    FileCheck().check_regex(
+                        r"assert_size_stride\s*\(\s*[^,]+,\s*\([^\)]*\),\s*\([^\)]*\),\s*'[^']+'\s*\)"
+                    ).run(code[0])
+                except Exception as e:
+                    print(f"Failed regex match for assert_size_stride: {e}")
+                    print(code[0])
+                    raise e
+            else:
+                print("Skipping: No assert_size_stride found.")
+
+    @requires_gpu()
+    @skip_if_not_triton
+    @skip_if_cpp_wrapper("skip cpp_wrapper tests")
+    @config.patch(implicit_fallbacks=True)
+    def test_generated_code_has_alignment_assert(self):
+        def foo(x):
+            return 3 * x
+
+        def foo_meta(x):
+            return torch.empty_like(x)
+
+        define_custom_op_for_test("foo", foo, foo_meta)
+
+        def fn(x):
+            a = torch.nn.functional.relu(x)
+            b = torch.ops.test.foo(a)
+            return b
+
+        a = torch.randn((16, 32), device=self.device)
+
+        _, code = run_and_get_code(
+            torch.compile(fn),
+            a,
+        )
+        if not is_dynamic_shape_enabled():
+            if code and len(code) > 0 and "assert_alignment(" in code[0]:
+                try:
+                    FileCheck().check_regex(
+                        r"assert_alignment\s*\(\s*[^,]+,\s*[^,]+,\s*'[^']+'\s*\)"
+                    ).run(code[0])
+                except Exception as e:
+                    print(f"Failed regex match for assert_alignment: {e}")
+                    print(code[0])
+                    raise e
+            else:
+                print("Skipping: No assert_alignment found.")
+
+    def test_assert_size_stride_op_name_pass(self):
+        tensor = torch.empty((16, 32))
+        assert_size_stride(tensor, (16, 32), (32, 1), "torch.ops.dummy.op_name")
+
+    def test_assert_size_stride_op_name_fail(self):
+        tensor = torch.empty((16, 32))
+        with self.assertRaisesRegex(AssertionError, "torch.ops.dummy.op_name"):
+            assert_size_stride(tensor, (32, 64), (32, 1), "torch.ops.dummy.op_name")
+
+    def test_assert_alignment_op_name_pass(self):
+        tensor = torch.empty((16, 32))
+        assert_alignment(tensor, 16, "torch.ops.dummy.op_name")
+
+    def test_assert_alignment_op_name_fail(self):
+        tensor = torch.empty((16, 32))
+        with self.assertRaisesRegex(AssertionError, "torch.ops.dummy.op_name"):
+            assert_alignment(tensor, 0, "torch.ops.dummy.op_name")
+
     @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
     @torch._inductor.config.patch(implicit_fallbacks=True)
     def test_custom_op_unbacked_symints(self):

@@ -13056,12 +13156,12 @@ def f(x):
         code = run_and_get_triton_code(f, x)

         if is_dynamic_shape_enabled():
-            FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1))").check(
-                "assert_size_stride(buf2, (s77, s27), (s27, 1))"
+            FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1)").check(
+                "assert_size_stride(buf2, (s77, s27), (s27, 1)"
             ).run(code)
         else:
-            FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1))").check(
-                "assert_size_stride(buf2, (16, 32), (32, 1))"
+            FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1)").check(
+                "assert_size_stride(buf2, (16, 32), (32, 1)"
             ).run(code)

     @requires_cuda

torch/_C/_dynamo/guards.pyi

Lines changed: 6 additions & 0 deletions

@@ -176,6 +176,12 @@ def assert_size_stride(
     item: torch.Tensor,
     size: torch.types._size,
     stride: torch.types._size,
+    op_name: str | None = None,
+): ...
+def assert_alignment(
+    item: torch.Tensor,
+    alignment: int,
+    op_name: str | None = None,
 ): ...
 def check_obj_id(obj: object, expected: int) -> bool: ...
 def check_type_id(obj: object, expected: int) -> bool: ...
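A brief usage sketch of the updated bindings (hedged: the op string below is an arbitrary label, and the calls assume a freshly allocated contiguous tensor with zero storage offset so both checks pass):

import torch
from torch._C._dynamo.guards import assert_alignment, assert_size_stride

t = torch.empty((16, 32))
assert_size_stride(t, (16, 32), (32, 1))  # op_name is optional; the old call form still works
assert_size_stride(t, (16, 32), (32, 1), "torch.ops.dummy.op_name")  # name is echoed in any failure message
assert_alignment(t, 16, "torch.ops.dummy.op_name")  # passes: storage_offset is 0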

torch/_inductor/ir.py

Lines changed: 20 additions & 4 deletions

@@ -5818,26 +5818,42 @@ def codegen_kwargs(self, skip_out=False):  # type: ignore[no-untyped-def]
         ]
         return kwargs

+    def get_op_name(self) -> str:
+        if self.fx_node is not None:
+            target = self.fx_node.target
+            op_namespace = getattr(target, "__module__", "unknown_namespace")
+            op_namespace = op_namespace.replace("._ops.", ".ops.")
+            op_namespace = op_namespace.rsplit(".", 1)[0]
+            op_name = f"{op_namespace}.{target}"
+        else:
+            op_name = "unknown_op"
+        return op_name
+
     def codegen_size_asserts(self, wrapper) -> None:  # type: ignore[no-untyped-def]
         if config.size_asserts and not V.graph.cpp_wrapper:
             # comparing strides for 0 size tensor is tricky. Ignore them for now.
             if sympy_product(self.get_size()) == 0:
                 return
             size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size())
             stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride())
-
+            op_name = self.get_op_name()
             wrapper.writeline(
-                f"assert_size_stride({self.get_name()}, {size}, {stride})"
+                f"assert_size_stride({self.get_name()}, {size}, {stride}, {op_name!r})"
             )

     def codegen_alignment_asserts(self, wrapper) -> None:  # type: ignore[no-untyped-def]
         if config.alignment_asserts and not V.graph.cpp_wrapper:
             name = self.get_name()
             aligned = name not in V.graph.unaligned_buffers
+            op_name = self.get_op_name()
             if aligned:
-                wrapper.writeline(f"assert_alignment({name}, {GPU_ALIGN_BYTES})")
+                wrapper.writeline(
+                    f"assert_alignment({name}, {GPU_ALIGN_BYTES}, {op_name!r})"
+                )
             else:
-                wrapper.writeline(f"# buffer {name} is assumed to be not aligned")
+                wrapper.writeline(
+                    f"# buffer {name} (op: {op_name}) is assumed to be not aligned"
+                )

     def get_group_stride(self):  # type: ignore[no-untyped-def]
         """

torch/csrc/dynamo/guards.cpp

Lines changed: 43 additions & 8 deletions

@@ -844,21 +844,38 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
   PyObject* item = nullptr;
   PyObject* size = nullptr;
   PyObject* stride = nullptr;
-  if (!PyArg_ParseTuple(args, "OOO", &item, &size, &stride)) {
+  const char* op_name = nullptr;
+
+  if (!PyArg_ParseTuple(args, "OOO|s", &item, &size, &stride, &op_name)) {
     return nullptr;
   }
   if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
-    PyErr_SetString(PyExc_TypeError, "expected Tensor()");
+    std::stringstream msg;
+    msg << "expected Tensor()";
+    if (op_name) {
+      msg << " for op: " << op_name;
+    }
+    PyErr_SetString(PyExc_TypeError, msg.str().c_str());
     return nullptr;
   }
   if (!PyTuple_CheckExact(size) || !PyTuple_CheckExact(stride)) {
-    PyErr_SetString(PyExc_TypeError, "expected tuple()");
+    std::stringstream msg;
+    msg << "expected tuple()";
+    if (op_name) {
+      msg << " for op: " << op_name;
+    }
+    PyErr_SetString(PyExc_TypeError, msg.str().c_str());
     return nullptr;
   }
   at::Tensor tensor = THPVariable_Unpack(item);
   int64_t ndim = tensor.ndimension();
   if (PyTuple_GET_SIZE(size) != ndim || PyTuple_GET_SIZE(stride) != ndim) {
-    PyErr_SetString(PyExc_AssertionError, "wrong number of dimensions");
+    std::stringstream msg;
+    msg << "wrong number of dimensions" << ndim;
+    if (op_name) {
+      msg << " for op: " << op_name;
+    }
+    PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
     return nullptr;
   }

@@ -887,6 +904,9 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
   }

   if (num_errors) {
+    if (op_name) {
+      msg << "\nError in op: " << op_name;
+    }
     msg << "\nThis error most often comes from a incorrect fake (aka meta) kernel for a custom op.";
     msg << "\nUse torch.library.opcheck to test your custom op.";
     msg << "\nSee https://pytorch.org/docs/stable/library.html#torch.library.opcheck";

@@ -904,15 +924,27 @@ static PyObject* assert_alignment(PyObject* dummy, PyObject* args) {
   */
   PyObject* item = nullptr;
   unsigned long alignment = 0;
-  if (!PyArg_ParseTuple(args, "Ok", &item, &alignment)) {
+  const char* op_name = nullptr;
+
+  if (!PyArg_ParseTuple(args, "Ok|s", &item, &alignment, &op_name)) {
     return nullptr;
   }
   if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
-    PyErr_SetString(PyExc_TypeError, "expected Tensor()");
+    std::stringstream msg;
+    msg << "expected Tensor()";
+    if (op_name) {
+      msg << " for op: " << op_name;
+    }
+    PyErr_SetString(PyExc_TypeError, msg.str().c_str());
     return nullptr;
   }
   if (alignment == 0) {
-    PyErr_SetString(PyExc_AssertionError, "alignment can not be 0");
+    std::stringstream msg;
+    msg << "alignment cannot be 0";
+    if (op_name) {
+      msg << " in op: " << op_name;
+    }
+    PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
     return nullptr;
   }

@@ -922,7 +954,10 @@
   size_t itemsize = tensor.itemsize();
   if (storage_offset * itemsize % alignment != 0) {
     std::stringstream msg;
-    msg << "Expect the tensor to be " << alignment
+    if (op_name) {
+      msg << "\nError in op: " << op_name;
+    }
+    msg << "\nExpect the tensor to be " << alignment
         << " bytes aligned. Fail due to storage_offset=" << storage_offset
         << " itemsize=" << itemsize;
     PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
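With a name attached, a failing size/stride check now ends with lines like the following. This is a hedged sketch: the op string is a placeholder, the per-dimension mismatch details emitted before these lines are elided, and only wording visible in this diff is reproduced.

AssertionError: ...
Error in op: torch.ops.test.foo.default
This error most often comes from a incorrect fake (aka meta) kernel for a custom op.
Use torch.library.opcheck to test your custom op.
See https://pytorch.org/docs/stable/library.html#torch.library.opcheck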
