|
12 | 12 | import torch.library |
13 | 13 | from torch._dynamo.testing import CompileCounterWithBackend, make_test_cls_with_patches |
14 | 14 | from torch._inductor import metrics |
15 | | -from torch._inductor.codegen.common import device_codegens, register_backend_for_device |
16 | | -from torch._inductor.codegen.cpp import CppScheduling |
17 | 15 | from torch._inductor.codegen.wrapper import PythonWrapperCodegen |
18 | 16 | from torch._inductor.test_case import TestCase |
19 | 17 | from torch._inductor.utils import run_and_get_code |
|
34 | 32 | TEST_WITH_ASAN, |
35 | 33 | TEST_WITH_ROCM, |
36 | 34 | ) |
37 | | -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU |
| 35 | +from torch.testing._internal.inductor_utils import ( |
| 36 | + GPU_TYPE, |
| 37 | + HAS_CPU, |
| 38 | + HAS_GPU, |
| 39 | + patch_inductor_backend, |
| 40 | +) |
38 | 41 |
|
39 | 42 |
|
40 | 43 | # Make the helper files in test/ importable |
@@ -932,23 +935,13 @@ def generate(self, is_inference, *args, **kwargs): |
932 | 935 | _test_wrapper_codegen_statically_known_int_or_none_in_context() |
933 | 936 | return super().generate(is_inference, *args, **kwargs) |
934 | 937 |
|
935 | | - if "cpu" not in device_codegens: |
936 | | - register_backend_for_device("cpu", CppScheduling, PythonWrapperCodegen) |
937 | | - orig_cpu_codegens = device_codegens["cpu"] |
938 | | - try: |
939 | | - register_backend_for_device( |
940 | | - "cpu", orig_cpu_codegens.scheduling, TestWrapperCodegen |
941 | | - ) |
| 938 | + with patch_inductor_backend("cpu", python_wrapper_codegen=TestWrapperCodegen): |
942 | 939 | # Compile each of the functions above, with an example input |
943 | 940 | # that has 5 in the first dimension, but is marked as dynamic |
944 | 941 |
|
945 | 942 | torch.compile(backend="inductor", dynamic=None)(fn_1)(_x) |
946 | 943 | torch.compile(backend="inductor", dynamic=None)(fn_2)(_x) |
947 | 944 | torch.compile(backend="inductor", dynamic=None)(fn_3)(_x) |
948 | | - finally: |
949 | | - register_backend_for_device( |
950 | | - "cpu", orig_cpu_codegens.scheduling, orig_cpu_codegens.wrapper_codegen |
951 | | - ) |
952 | 945 |
|
953 | 946 | @torch._dynamo.config.patch(capture_scalar_outputs=True) |
954 | 947 | def test_item_unbacked_stride_nobreak(self, device): |
|
0 commit comments