Skip to content

Commit e694280

Browse files
marpiochjansel
authored andcommitted
Custom FX pass for inductor's backend registration (pytorch#154841)
This PR is related to RFC pytorch#153532. It is an extension to Inductor's backend registration interface to allow to register custom FX passes by the backend. Pull Request resolved: pytorch#154841 Approved by: https://github.com/jansel Co-authored-by: Jason Ansel <[email protected]>
1 parent c6b4f98 commit e694280

File tree

10 files changed

+213
-36
lines changed

10 files changed

+213
-36
lines changed

test/inductor/test_codecache.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@
2727
TensorMetadata,
2828
TensorMetadataAndValues,
2929
)
30-
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
30+
from torch._inductor.custom_graph_pass import (
31+
CustomGraphModulePass,
32+
CustomGraphPass,
33+
get_hash_for_files,
34+
)
3135
from torch._inductor.graph import GraphLowering
3236
from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
3337
from torch._inductor.runtime.runtime_utils import cache_dir
@@ -53,6 +57,7 @@
5357
HAS_GPU,
5458
HAS_MULTIGPU,
5559
HAS_TRITON,
60+
patch_inductor_backend,
5661
requires_gpu,
5762
requires_triton,
5863
)
@@ -2183,6 +2188,42 @@ def uuid(self) -> Optional[Union[bytes, str]]:
21832188
pickler.dumps(details3),
21842189
)
21852190

2191+
def test_hash_custom_backend_pass(self):
2192+
"""
2193+
Test CustomGraphModulePass usage.
2194+
"""
2195+
2196+
class TestCustomGraphModulePass(CustomGraphModulePass):
2197+
def __init__(self):
2198+
self._uuid = None
2199+
2200+
def __call__(self, gm: torch.fx.GraphModule) -> None:
2201+
return None
2202+
2203+
def uuid(self) -> Optional[Union[bytes, str]]:
2204+
return self._uuid
2205+
2206+
custom_pass = TestCustomGraphModulePass()
2207+
with patch_inductor_backend("cpu", custom_pass=custom_pass):
2208+
custom_pass._uuid = "1"
2209+
details1 = FxGraphHashDetails(None, [], {}, [])
2210+
details2 = FxGraphHashDetails(None, [], {}, [])
2211+
2212+
custom_pass._uuid = "2"
2213+
details3 = FxGraphHashDetails(None, [], {}, [])
2214+
2215+
gm = torch.fx.GraphModule({}, torch.fx.Graph())
2216+
pickler = FxGraphCachePickler(gm)
2217+
2218+
self.assertEqual(
2219+
pickler.dumps(details1),
2220+
pickler.dumps(details2),
2221+
)
2222+
self.assertNotEqual(
2223+
pickler.dumps(details1),
2224+
pickler.dumps(details3),
2225+
)
2226+
21862227
def test_bypass_unsupported(self):
21872228
"""
21882229
Test _reduce_unsupported

test/inductor/test_custom_post_grad_passes.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,17 @@
88
import torch.fx as fx
99
from torch._dynamo.utils import counters
1010
from torch._inductor import config
11-
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
11+
from torch._inductor.codegen.common import get_custom_backend_pass_for_device
12+
from torch._inductor.custom_graph_pass import (
13+
CustomGraphModulePass,
14+
CustomGraphPass,
15+
get_hash_for_files,
16+
)
1217
from torch._inductor.lowering import lowerings as L
1318
from torch._inductor.pattern_matcher import Arg, CallFunction, PatternMatcherPass
1419
from torch._inductor.test_case import run_tests, TestCase
1520
from torch.testing._internal.common_utils import IS_LINUX
16-
from torch.testing._internal.inductor_utils import HAS_CPU
21+
from torch.testing._internal.inductor_utils import HAS_CPU, patch_inductor_backend
1722

1823

1924
@config.patch({"freezing": True})
@@ -264,6 +269,35 @@ def f(W, nested_seqs):
264269

265270
inner_test()
266271

272+
def test_custom_backend_pass(self):
273+
class CustomBackendPass(CustomGraphModulePass):
274+
def __init__(self, existing_pass: CustomGraphModulePass = None):
275+
super().__init__()
276+
self.existing_pass = existing_pass
277+
278+
def __call__(self, gm: fx.GraphModule) -> None:
279+
if self.existing_pass:
280+
self.existing_pass(gm)
281+
282+
change_cos_pass(gm.graph)
283+
284+
def uuid(self) -> bytes:
285+
return get_hash_for_files((__file__,))
286+
287+
custom_backend_pass = CustomBackendPass(
288+
get_custom_backend_pass_for_device("cpu")
289+
)
290+
with patch_inductor_backend("cpu", custom_pass=custom_backend_pass):
291+
292+
def g(x):
293+
return x.sin().sin().sin()
294+
295+
def f(x):
296+
return x.cos().cos().cos()
297+
298+
x = torch.randn(8, dtype=torch.float32)
299+
torch.testing.assert_close(torch.compile(f)(x), g(x))
300+
267301

268302
if __name__ == "__main__":
269303
if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available():

test/inductor/test_torchinductor_dynamic_shapes.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
import torch.library
1313
from torch._dynamo.testing import CompileCounterWithBackend, make_test_cls_with_patches
1414
from torch._inductor import metrics
15-
from torch._inductor.codegen.common import device_codegens, register_backend_for_device
16-
from torch._inductor.codegen.cpp import CppScheduling
1715
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
1816
from torch._inductor.test_case import TestCase
1917
from torch._inductor.utils import run_and_get_code
@@ -34,7 +32,12 @@
3432
TEST_WITH_ASAN,
3533
TEST_WITH_ROCM,
3634
)
37-
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
35+
from torch.testing._internal.inductor_utils import (
36+
GPU_TYPE,
37+
HAS_CPU,
38+
HAS_GPU,
39+
patch_inductor_backend,
40+
)
3841

3942

4043
# Make the helper files in test/ importable
@@ -932,23 +935,13 @@ def generate(self, is_inference, *args, **kwargs):
932935
_test_wrapper_codegen_statically_known_int_or_none_in_context()
933936
return super().generate(is_inference, *args, **kwargs)
934937

935-
if "cpu" not in device_codegens:
936-
register_backend_for_device("cpu", CppScheduling, PythonWrapperCodegen)
937-
orig_cpu_codegens = device_codegens["cpu"]
938-
try:
939-
register_backend_for_device(
940-
"cpu", orig_cpu_codegens.scheduling, TestWrapperCodegen
941-
)
938+
with patch_inductor_backend("cpu", python_wrapper_codegen=TestWrapperCodegen):
942939
# Compile each of the functions above, with an example input
943940
# that has 5 in the first dimension, but is marked as dynamic
944941

945942
torch.compile(backend="inductor", dynamic=None)(fn_1)(_x)
946943
torch.compile(backend="inductor", dynamic=None)(fn_2)(_x)
947944
torch.compile(backend="inductor", dynamic=None)(fn_3)(_x)
948-
finally:
949-
register_backend_for_device(
950-
"cpu", orig_cpu_codegens.scheduling, orig_cpu_codegens.wrapper_codegen
951-
)
952945

953946
@torch._dynamo.config.patch(capture_scalar_outputs=True)
954947
def test_item_unbacked_stride_nobreak(self, device):

torch/_inductor/codecache.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from torch._dynamo.exc import SkipFrame
5252
from torch._dynamo.utils import CompileEventLogger, counters, dynamo_timed
5353
from torch._inductor import config, exc, metrics
54+
from torch._inductor.codegen.common import custom_backend_passes
5455
from torch._inductor.codegen.cuda import cuda_env
5556
from torch._inductor.codegen.rocm.compile_command import (
5657
rocm_compile_command,
@@ -72,7 +73,11 @@
7273
normalize_path_separator,
7374
)
7475
from torch._inductor.cpu_vec_isa import pick_vec_isa
75-
from torch._inductor.custom_graph_pass import CustomGraphPass, CustomGraphPassType
76+
from torch._inductor.custom_graph_pass import (
77+
CustomGraphModulePass,
78+
CustomGraphPass,
79+
CustomGraphPassType,
80+
)
7681
from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param
7782
from torch._inductor.runtime.compile_tasks import _reload_python_module
7883
from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
@@ -891,12 +896,16 @@ def __init__(
891896
config.post_grad_custom_post_pass
892897
)
893898

899+
self.custom_backend_passes = tuple(
900+
map(self._get_custom_pass_detail, custom_backend_passes.values())
901+
)
902+
894903
def _get_custom_pass_detail(
895-
self, custom_pass: CustomGraphPassType
904+
self, custom_pass: Union[CustomGraphPassType, CustomGraphModulePass]
896905
) -> Optional[Any]:
897906
if not custom_pass:
898907
return None
899-
assert isinstance(custom_pass, CustomGraphPass)
908+
assert isinstance(custom_pass, (CustomGraphPass, CustomGraphModulePass))
900909
return custom_pass.uuid()
901910

902911

torch/_inductor/codegen/common.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565

6666
from torch.fx import GraphModule
6767

68+
from ..custom_graph_pass import CustomGraphModulePass
6869
from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
6970
from ..loop_body import LoopBody
7071
from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
@@ -351,6 +352,7 @@ def cpp_global_scratch(self, idx: int) -> Optional[tuple[str, str]]:
351352

352353

353354
device_op_overrides_dict: dict[str, DeviceOpOverrides] = {}
355+
custom_backend_passes: dict[str, Optional[CustomGraphModulePass]] = {}
354356

355357

356358
# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
@@ -379,10 +381,12 @@ def register_backend_for_device(
379381
device_scheduling: SchedulingConstructor,
380382
device_wrapper_codegen: WrapperConstructor,
381383
device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None,
384+
device_custom_pass: Optional[CustomGraphModulePass] = None,
382385
) -> None:
383386
device_codegens[device] = DeviceCodegen(
384387
device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen
385388
)
389+
custom_backend_passes[device] = device_custom_pass
386390

387391

388392
class BackendFeature(Enum):
@@ -441,6 +445,10 @@ def get_wrapper_codegen_for_device(
441445
return None
442446

443447

448+
def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModulePass]:
449+
return custom_backend_passes[device] if device in custom_backend_passes else None
450+
451+
444452
@functools.lru_cache(None)
445453
def init_backend_registration() -> None:
446454
from .cpp import CppScheduling

torch/_inductor/compile_fx.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
BoxedBool,
7777
count_tangents,
7878
fresh_inductor_cache,
79+
get_all_devices,
7980
InputType,
8081
is_gpu,
8182
should_assume_input_aligned,
@@ -1901,22 +1902,6 @@ def get_cpp_wrapper_config() -> dict[str, object]:
19011902
}
19021903

19031904

1904-
def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]:
1905-
placeholder_nodes = gm.graph.find_nodes(op="placeholder")
1906-
input_devices: OrderedSet[torch.device] = OrderedSet(
1907-
node.meta["val"].device
1908-
for node in placeholder_nodes
1909-
if isinstance(node.meta.get("val"), torch.Tensor)
1910-
)
1911-
1912-
out_devices: OrderedSet[torch.device] = OrderedSet(
1913-
arg.meta["val"].device
1914-
for arg in output_node(gm).args[0] # type: ignore[union-attr]
1915-
if isinstance(arg, fx.Node) and isinstance(arg.meta.get("val"), torch.Tensor)
1916-
)
1917-
return input_devices | out_devices
1918-
1919-
19201905
def get_cuda_device_context(gm: torch.fx.GraphModule) -> AbstractContextManager[None]:
19211906
"""
19221907
Returns a cuda device context manager if there is a single device in the graph

torch/_inductor/custom_graph_pass.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,38 @@ def uuid(self) -> Optional[Any]:
5353
"""
5454

5555

56+
class CustomGraphModulePass(ABC):
57+
"""
58+
Implement this interface for custom Graph passes:
59+
60+
1) The __call__() method contains the implementation of the custom pass.
61+
62+
2) The uuid() method enables inductor to cache compiled graphs when your custom
63+
passes are applied. This method can return any identifier as long as it uniquely
64+
identifies your implementation (and can be pickled). The caching logic includes this
65+
identifier in its key calculation, i.e., any new value will effectively invalidate
66+
existing entries. We expect custom passes would typically depend purely on the
67+
textual reprensentation of the implementation. In that case, we recommend using the
68+
'get_hash_for_files' helper below to compute a unique hash from the contents of a
69+
static list of source files, i.e., the source(s) containing the custom pass
70+
implementation. That approach ensures that any change to the implementation will
71+
mean a new uuid.
72+
"""
73+
74+
@abstractmethod
75+
def __call__(self, gm: torch.fx.GraphModule) -> None:
76+
"""
77+
Implementation of the custom pass.
78+
"""
79+
80+
@abstractmethod
81+
def uuid(self) -> Optional[Any]:
82+
"""
83+
Return an ID to uniquely identify your custom pass implementation. Return None
84+
to skip inductor code caching entirely.
85+
"""
86+
87+
5688
CustomGraphPassType: TypeAlias = Optional[
5789
Union[CustomGraphPass, Callable[[torch.fx.graph.Graph], None]]
5890
]

torch/_inductor/fx_passes/post_grad.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from torch.utils._ordered_set import OrderedSet
2323

2424
from .. import config, ir, pattern_matcher
25+
from ..codegen.common import custom_backend_passes
2526
from ..comms import remove_fsdp2_unsharded_param_graph_input_usage
2627
from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage
2728
from ..lowering import lowerings as L
@@ -48,6 +49,7 @@
4849
)
4950
from ..utils import (
5051
decode_device,
52+
get_all_devices,
5153
get_gpu_type,
5254
is_gpu,
5355
is_pointwise_use,
@@ -182,6 +184,13 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
182184

183185
fake_tensor_updater.incremental_update()
184186

187+
for device, custom_backend_pass in custom_backend_passes.items():
188+
if custom_backend_pass is not None:
189+
gm_devices = [d.type for d in get_all_devices(gm)]
190+
if device in gm_devices:
191+
pass_name = "custom_backend_passes_" + device
192+
GraphTransformObserver(gm, pass_name).apply_gm_pass(custom_backend_pass)
193+
185194
# Keep these last, since they introduces mutation. Look at
186195
# ./fx_passes/README.md for a discussion of mutation invariants.
187196
GraphTransformObserver(gm, "reinplace_inplaceable_ops").apply_graph_pass(

torch/_inductor/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,25 @@ def output_node(gm: torch.fx.GraphModule) -> Node:
994994
return last_node
995995

996996

997+
def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]:
998+
placeholder_nodes = gm.graph.find_nodes(op="placeholder")
999+
input_devices: OrderedSet[torch.device] = OrderedSet(
1000+
node.meta["val"].device
1001+
for node in placeholder_nodes
1002+
if isinstance(node.meta.get("val"), torch.Tensor)
1003+
)
1004+
1005+
out_arg = output_node(gm).args[0] # type: ignore[union-attr]
1006+
out_args = out_arg if isinstance(out_arg, tuple) else (out_arg,)
1007+
out_devices: OrderedSet[torch.device] = OrderedSet(
1008+
arg.meta["val"].device
1009+
for arg in out_args
1010+
if isinstance(arg, torch.fx.Node)
1011+
and isinstance(arg.meta.get("val"), torch.Tensor)
1012+
)
1013+
return input_devices | out_devices
1014+
1015+
9971016
_registered_caches: list[Any] = []
9981017

9991018

0 commit comments

Comments
 (0)