 import torch._dynamo.config as dynamo_config
 import torch._inductor.aoti_eager
 import torch.nn as nn
+from torch._C._dynamo.guards import assert_alignment, assert_size_stride
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.debug_utils import aot_graph_input_parser
 from torch._dynamo.device_interface import get_interface_for_device
@@ -1409,7 +1410,14 @@ def fn(a, b):
         )
         _, code = run_and_get_code(fn, x, y)
         code = " ".join(code)
-        self.assertEqual(
+        assert_keywords = ["assert_size_stride", "assert_alignment"]
+        filtered_lines = [
+            line
+            for line in code.splitlines()
+            if not any(assert_key in line for assert_key in assert_keywords)
+        ]
+        code = "\n".join(filtered_lines)
+        self.assertGreaterEqual(
             code.count("view_dtype" if config.cpp_wrapper else "aten.view"), 3
         )
 
@@ -11923,6 +11931,98 @@ def fn(x):
             check_lowp=False,
         )
 
+    @requires_gpu()
+    @skip_if_not_triton
+    @skip_if_cpp_wrapper("the FileCheck patterns target Python wrapper codegen")
+    @config.patch(implicit_fallbacks=True)
+    def test_generated_code_has_size_stride_assert(self):
+        def foo(x):
+            return 3 * x
+
+        def foo_meta(x):
+            return torch.empty_like(x)
+
+        define_custom_op_for_test("foo", foo, foo_meta)
+
+        def fn(x):
+            a = torch.nn.functional.relu(x)
+            b = torch.ops.test.foo(a)
+            return b
+
+        a = torch.randn((16, 32), device=self.device)
+
+        _, code = run_and_get_code(
+            torch.compile(fn),
+            a,
+        )
+        if not is_dynamic_shape_enabled():
+            if code and "assert_size_stride(" in code[0]:
+                try:
+                    FileCheck().check_regex(
+                        r"assert_size_stride\s*\(\s*[^,]+,\s*\([^\)]*\),\s*\([^\)]*\),\s*'[^']+'\s*\)"
+                    ).run(code[0])
+                except Exception as e:
+                    print(f"Failed regex match for assert_size_stride: {e}")
+                    print(code[0])
+                    raise
+            else:
+                self.skipTest("no assert_size_stride call in the generated code")
+
+    @requires_gpu()
+    @skip_if_not_triton
+    @skip_if_cpp_wrapper("the FileCheck patterns target Python wrapper codegen")
+    @config.patch(implicit_fallbacks=True)
+    def test_generated_code_has_alignment_assert(self):
+        def foo(x):
+            return 3 * x
+
+        def foo_meta(x):
+            return torch.empty_like(x)
+
+        define_custom_op_for_test("foo", foo, foo_meta)
+
+        def fn(x):
+            a = torch.nn.functional.relu(x)
+            b = torch.ops.test.foo(a)
+            return b
+
+        a = torch.randn((16, 32), device=self.device)
+
+        _, code = run_and_get_code(
+            torch.compile(fn),
+            a,
+        )
+        if not is_dynamic_shape_enabled():
+            if code and "assert_alignment(" in code[0]:
+                try:
+                    FileCheck().check_regex(
+                        r"assert_alignment\s*\(\s*[^,]+,\s*[^,]+,\s*'[^']+'\s*\)"
+                    ).run(code[0])
+                except Exception as e:
+                    print(f"Failed regex match for assert_alignment: {e}")
+                    print(code[0])
+                    raise
+            else:
+                self.skipTest("no assert_alignment call in the generated code")
+
+    def test_assert_size_stride_op_name_pass(self):
+        tensor = torch.empty((16, 32))
+        assert_size_stride(tensor, (16, 32), (32, 1), "torch.ops.dummy.op_name")
+
+    def test_assert_size_stride_op_name_fail(self):
+        tensor = torch.empty((16, 32))
+        with self.assertRaisesRegex(AssertionError, "torch.ops.dummy.op_name"):
+            assert_size_stride(tensor, (32, 64), (32, 1), "torch.ops.dummy.op_name")
+
+    def test_assert_alignment_op_name_pass(self):
+        tensor = torch.empty((16, 32))
+        assert_alignment(tensor, 16, "torch.ops.dummy.op_name")
+
+    def test_assert_alignment_op_name_fail(self):
+        tensor = torch.empty((16, 32))
+        with self.assertRaisesRegex(AssertionError, "torch.ops.dummy.op_name"):
+            assert_alignment(tensor, 0, "torch.ops.dummy.op_name")
+
     @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
     @torch._inductor.config.patch(implicit_fallbacks=True)
     def test_custom_op_unbacked_symints(self):
@@ -13056,12 +13156,12 @@ def f(x):
         code = run_and_get_triton_code(f, x)
 
         if is_dynamic_shape_enabled():
-            FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1))").check(
-                "assert_size_stride(buf2, (s77, s27), (s27, 1))"
+            FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1)").check(
+                "assert_size_stride(buf2, (s77, s27), (s27, 1)"
             ).run(code)
         else:
-            FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1))").check(
-                "assert_size_stride(buf2, (16, 32), (32, 1))"
+            FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1)").check(
+                "assert_size_stride(buf2, (16, 32), (32, 1)"
             ).run(code)
 
     @requires_cuda
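
As a usage note: the two helpers exercised above are importable from the private module `torch._C._dynamo.guards` (as the new import in this diff does) and can be called directly. Below is a minimal sketch mirroring the unit tests added here; all shapes, strides, alignment, and `op_name` values are taken from those tests, and since this is internal API the exact error text may vary between builds.

import torch
from torch._C._dynamo.guards import assert_alignment, assert_size_stride

t = torch.empty((16, 32))  # contiguous, so strides are (32, 1)

# Passes silently when both the expected size and strides match.
assert_size_stride(t, (16, 32), (32, 1), "torch.ops.dummy.op_name")

# On mismatch, the raised AssertionError names the offending op,
# which is what the *_op_name_fail tests above assert on.
try:
    assert_size_stride(t, (32, 64), (32, 1), "torch.ops.dummy.op_name")
except AssertionError as e:
    assert "torch.ops.dummy.op_name" in str(e)

# Freshly allocated storage satisfies a 16-byte alignment check.
assert_alignment(t, 16, "torch.ops.dummy.op_name")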