Skip to content

Commit 65b1aed

Browse files
benjaminglass1 authored and pytorchmergebot committed
[Inductor] Improve typing, and prepare for ABI-compatible AOTI C-shim dispatching (pytorch#154371)
Prepares for the next PR in the stack by tightening up typing on a `cpp_wrapper` interface that's only used in one (well-typed) place, as well as downstream effects of that change. In particular, this enabled: 1. removing a number of now clearly unnecessary asserts 2. adding a few more targeted asserts to validate the code's current assumptions 3. removing some unneeded control flow in several functions As far as I can tell, this PR should be functionally neutral. One argument was removed from a `cpp_wrapper` public API, but that argument was unused, and only had a single callsite. Pull Request resolved: pytorch#154371 Approved by: https://github.com/desertfire
1 parent 3e05a48 commit 65b1aed

File tree

5 files changed

+132
-175
lines changed

5 files changed

+132
-175
lines changed

torch/_inductor/codegen/cpp_wrapper_cpu.py

Lines changed: 79 additions & 73 deletions
Original file line number · Diff line number · Diff line change
@@ -7,11 +7,12 @@
77
import sys
88
import textwrap
99
from itertools import chain, count
10-
from typing import Callable, Optional, Protocol, TYPE_CHECKING, Union
10+
from typing import Any, Callable, Optional, Protocol, TYPE_CHECKING, Union
1111

1212
import sympy
1313

1414
import torch
15+
import torch._higher_order_ops.torchbind
1516
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
1617
import torch._ops
1718
from torch._inductor.runtime.runtime_utils import dynamo_timed
@@ -38,6 +39,9 @@
3839

3940
from ..graph import GraphLowering
4041

42+
# At most, the list nesting can go one layer deep.
43+
_OUTPUT_ARGS_TYPE = list[Union[Optional[str], list[Optional[str]]]]
44+
4145

4246
class HasWriteLine(Protocol):
4347
def writeline(self, line: Union[LineContext, DeferredLineBase, str]) -> None: ...
@@ -1880,17 +1884,18 @@ def codegen_while_loop(self, while_loop):
18801884

18811885
def generate_extern_kernel_args_decl_if_needed(
18821886
self,
1883-
op_overload,
1884-
raw_args,
1885-
output_args: Optional[list[str]] = None,
1886-
raw_outputs: Optional[list[ir.Buffer]] = None,
1887+
op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator],
1888+
raw_args: Sequence[Any],
1889+
output_args: _OUTPUT_ARGS_TYPE,
1890+
raw_outputs: Sequence[ir.Buffer],
18871891
):
18881892
schema = None
18891893
if isinstance(op_overload, torch._higher_order_ops.torchbind.CallTorchBind):
18901894
obj = raw_args[0]
18911895
method = raw_args[1]
18921896
schema = op_overload.schema(obj, method)
18931897
else:
1898+
assert isinstance(op_overload, torch._ops.OpOverload), type(op_overload)
18941899
schema = op_overload._schema
18951900
assert schema is not None
18961901
arg_types = [x.real_type for x in schema.arguments]
@@ -1986,7 +1991,9 @@ def fill_args(arg, arg_type):
19861991
else:
19871992
fill_args(arg, arg_type)
19881993

1989-
def fill_output_arg(arg, return_type, is_mutated_output: bool):
1994+
def fill_output_arg(
1995+
arg: str, return_type: torch.JitType, is_mutated_output: bool
1996+
) -> None:
19901997
if isinstance(return_type, torch.TensorType):
19911998
if not is_mutated_output:
19921999
self.writeline(f"AtenTensorHandle {arg}_handle; // output buffer")
@@ -2021,8 +2028,9 @@ def fill_output_arg(arg, return_type, is_mutated_output: bool):
20212028
# None output is supported, but Optional return types are not yet supported
20222029
if output_arg is None:
20232030
continue
2024-
elif isinstance(output_arg, (list, tuple)):
2031+
elif isinstance(output_arg, list):
20252032
for out in output_arg:
2033+
assert out is not None, out
20262034
fill_output_arg(
20272035
out,
20282036
torch.TensorType.get(),
@@ -2041,73 +2049,73 @@ def generate_fallback_kernel_with_runtime_lookup(
20412049
self,
20422050
buf_name: str,
20432051
python_kernel_name: str,
2044-
cpp_kernel_name: str,
2045-
codegen_args: list[str],
2046-
op_overload: Optional[torch._ops.OpOverload] = None,
2047-
raw_args=None,
2048-
outputs=None,
2049-
):
2050-
def extract_output_name(out):
2052+
codegen_args: Sequence[str],
2053+
op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator],
2054+
raw_args: Sequence[Any],
2055+
outputs: Sequence[ir.Buffer],
2056+
) -> None:
2057+
"""Generate a call to a kernel not contained in the C-shim. This results in
2058+
different code paths for AOT Inductor vs cpp_wrapper Inductor mode."""
2059+
2060+
def extract_output_name(
2061+
out: Optional[Union[ir.Buffer, Sequence[ir.Buffer]]],
2062+
) -> Union[Optional[str], _OUTPUT_ARGS_TYPE]:
20512063
if out is None:
20522064
return None
2053-
elif isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)):
2065+
if isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)):
20542066
return out.get_name()
2055-
elif isinstance(out, ir.MutationOutput):
2067+
if isinstance(out, ir.MutationOutput):
20562068
mutated_buf_names = out.get_mutation_names()
20572069
assert (
20582070
isinstance(mutated_buf_names, list) and len(mutated_buf_names) == 1
20592071
), "Expect only one mutated buffer in MutationOutput"
20602072
return mutated_buf_names[0]
2061-
elif isinstance(out, (list, tuple)):
2062-
return type(out)(extract_output_name(o) for o in out)
2063-
else:
2064-
raise AssertionError(f"Unexpected output: {type(out)}")
2073+
if isinstance(out, (list, tuple)):
2074+
return [extract_output_name(o) for o in out] # type: ignore[misc]
2075+
raise AssertionError(f"Unexpected output: {type(out)}")
2076+
2077+
if isinstance(op_overload, torch._ops.HigherOrderOperator):
2078+
assert isinstance(
2079+
op_overload, torch._higher_order_ops.torchbind.CallTorchBind
2080+
), type(op_overload)
2081+
assert len(raw_args) > 1
2082+
obj = raw_args[0]
2083+
method = raw_args[1]
2084+
return_schema = op_overload.schema(obj, method).returns
2085+
else:
2086+
return_schema = op_overload._schema.returns
20652087

20662088
# output_args has the same pytree structure as outputs
2067-
2068-
return_schema = None
2069-
if op_overload:
2070-
if isinstance(op_overload, torch._higher_order_ops.torchbind.CallTorchBind):
2071-
assert raw_args is not None
2072-
assert len(raw_args) > 1
2073-
obj = raw_args[0]
2074-
method = raw_args[1]
2075-
return_schema = op_overload.schema(obj, method).returns
2076-
else:
2077-
return_schema = op_overload._schema.returns
2078-
if op_overload and not return_schema:
2089+
if not return_schema:
20792090
# kernel does not return a value
2080-
output_args = []
2081-
elif outputs is None:
2082-
# outputs is not specified, the default is to write to buf_name
2083-
output_args = [buf_name]
2091+
output_args: _OUTPUT_ARGS_TYPE = []
2092+
elif isinstance(output_name := extract_output_name(outputs), str):
2093+
output_args = [output_name]
20842094
else:
2085-
output_args = extract_output_name(outputs)
2086-
if isinstance(output_args, str):
2087-
output_args = [output_args]
2095+
# If the schema indicates a return value, we should have a non-None value by
2096+
# this point.
2097+
assert isinstance(output_name, list), type(output_name)
2098+
output_args = output_name
20882099

2100+
# In AOT mode, we use a ProxyExecutor to run fallback kernels.
20892101
if V.graph.aot_mode:
2090-
assert op_overload is not None
2091-
assert raw_args is not None
2092-
assert output_args is not None
2093-
2094-
return self.generate_fallback_kernel_with_runtime_lookup_aot(
2102+
self.generate_fallback_kernel_with_runtime_lookup_aot(
20952103
op_overload,
20962104
raw_args,
20972105
output_args,
20982106
outputs,
20992107
)
2100-
else:
2101-
return self.generate_fallback_kernel_with_runtime_lookup_jit(
2102-
buf_name,
2103-
python_kernel_name,
2104-
cpp_kernel_name,
2105-
codegen_args,
2106-
op_overload,
2107-
raw_args,
2108-
output_args, # type: ignore[arg-type]
2109-
outputs,
2110-
)
2108+
return
2109+
2110+
assert isinstance(op_overload, torch._ops.OpOverload), type(op_overload)
2111+
self.generate_fallback_kernel_with_runtime_lookup_jit(
2112+
buf_name,
2113+
python_kernel_name,
2114+
op_overload,
2115+
raw_args,
2116+
output_args, # type: ignore[arg-type]
2117+
outputs,
2118+
)
21112119

21122120
def generate_scoped_gil_acquire(self, declarations_before_scope, lines_in_scope):
21132121
scoped_lines = IndentedBuffer()
@@ -2256,19 +2264,19 @@ def generate_fallback_kernel_with_runtime_lookup_jit(
22562264
self,
22572265
buf_name: str,
22582266
python_kernel_name: str,
2259-
cpp_kernel_name: str,
2260-
codegen_args: list[str],
2261-
op_overload: Optional[torch._ops.OpOverload] = None,
2262-
raw_args=None,
2263-
output_args: Optional[list[Optional[str]]] = None,
2264-
raw_outputs: Optional[list[ir.Buffer]] = None,
2265-
):
2266-
# In the JIT mode, because of the ABI-compatible requirement, we can't directly call
2267-
# c10::Dispatcher to find the custom op and call it. Instead, we go back to Python
2268-
# to invoke this custom op.
2267+
op_overload: torch._ops.OpOverload,
2268+
raw_args: Sequence[Any],
2269+
output_args: Sequence[Optional[str]],
2270+
raw_outputs: Sequence[ir.Buffer],
2271+
) -> None:
2272+
"""Generate fallback kernel calls with runtime (non-AOT) dispatch. This can
2273+
only be called in cpp_wrapper mode, and assumes that the input is a non-None
2274+
OpOverload.
2275+
2276+
This function calls into Python to dispatch, which allows it to handle datatypes
2277+
that cannot be contained in StableIValue, at the cost of some performance."""
22692278
self.load_custom_op_wrapper()
22702279

2271-
assert output_args is not None, "output_args should not be None"
22722280
num_args = len(raw_args)
22732281
py_args_var = f"py_args_{next(self.arg_var_id)}"
22742282
# First arg is always the python op name
@@ -2282,8 +2290,6 @@ def generate_fallback_kernel_with_runtime_lookup_jit(
22822290
"""
22832291
)
22842292

2285-
assert op_overload is not None, "op_overload should not be None"
2286-
22872293
for idx, (raw_arg, schema_arg) in enumerate(
22882294
zip(raw_args, op_overload._schema.arguments)
22892295
):
@@ -2334,11 +2340,11 @@ def generate_fallback_kernel_with_runtime_lookup_jit(
23342340

23352341
def generate_fallback_kernel_with_runtime_lookup_aot(
23362342
self,
2337-
op_overload,
2338-
raw_args, # contains both args and flatten kwargs
2339-
output_args: Optional[list[str]] = None,
2340-
raw_outputs: Optional[list[ir.Buffer]] = None,
2341-
):
2343+
op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator],
2344+
raw_args: Sequence[Any],
2345+
output_args: _OUTPUT_ARGS_TYPE,
2346+
raw_outputs: Sequence[ir.Buffer],
2347+
) -> None:
23422348
(
23432349
tensor_call_args,
23442350
int_call_args,

torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py

Lines changed: 10 additions & 50 deletions
Original file line number · Diff line number · Diff line change
@@ -1,5 +1,6 @@
11
# mypy: allow-untyped-defs
2-
from typing import Callable, Optional
2+
from collections.abc import Sequence
3+
from typing import Any, Callable, Optional, Union
34

45
import sympy
56

@@ -749,57 +750,16 @@ def generate_fallback_kernel_with_runtime_lookup(
749750
self,
750751
buf_name: str,
751752
python_kernel_name: str,
752-
cpp_kernel_name: str,
753-
codegen_args: list[str],
754-
op_overload: Optional[torch._ops.OpOverload] = None,
755-
raw_args=None,
756-
outputs=None,
757-
):
753+
codegen_args: Sequence[str],
754+
op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator],
755+
raw_args: Sequence[Any],
756+
outputs: Sequence[ir.Buffer],
757+
) -> None:
758758
# No stack allocation when there is a fallback op
759759
self.allow_stack_allocation = False
760-
761-
def extract_output_name(out):
762-
if out is None:
763-
return None
764-
elif isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)):
765-
return out.get_name()
766-
elif isinstance(out, (list, tuple)):
767-
return type(out)(extract_output_name(o) for o in out)
768-
else:
769-
raise AssertionError(f"Unexpected output: {type(out)}")
770-
771-
# output_args has the same pytree structure as outputs
772-
output_args = None
773-
if outputs is None:
774-
# outputs is not specified, the default is to write to buf_name
775-
output_args = [buf_name]
776-
else:
777-
output_args = extract_output_name(outputs)
778-
if isinstance(output_args, str):
779-
output_args = [output_args]
780-
781-
if V.graph.aot_mode:
782-
assert op_overload is not None
783-
assert raw_args is not None
784-
assert outputs is not None
785-
786-
return self.generate_fallback_kernel_with_runtime_lookup_aot(
787-
op_overload,
788-
raw_args,
789-
output_args,
790-
outputs,
791-
)
792-
else:
793-
return self.generate_fallback_kernel_with_runtime_lookup_jit(
794-
buf_name,
795-
python_kernel_name,
796-
cpp_kernel_name,
797-
codegen_args,
798-
op_overload,
799-
raw_args,
800-
output_args, # type: ignore[arg-type]
801-
outputs,
802-
)
760+
super().generate_fallback_kernel_with_runtime_lookup(
761+
buf_name, python_kernel_name, codegen_args, op_overload, raw_args, outputs
762+
)
803763

804764
def codegen_device_copy(self, src, dst, non_blocking: bool):
805765
# aoti_torch_tensor_copy_ takes AtenTensorHandle as input,

torch/_inductor/codegen/wrapper.py

Lines changed: 5 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -1399,12 +1399,11 @@ def generate_fallback_kernel_with_runtime_lookup(
13991399
self,
14001400
buf_name: str,
14011401
python_kernel_name: str,
1402-
cpp_kernel_name: str,
1403-
codegen_args: list[str],
1404-
op_overload: Optional[torch._ops.OpOverload] = None,
1405-
raw_args=None,
1406-
outputs=None,
1407-
):
1402+
codegen_args: Sequence[str],
1403+
op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator],
1404+
raw_args: Sequence[Any],
1405+
outputs: Sequence[ir.Buffer],
1406+
) -> None:
14081407
self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(codegen_args)})")
14091408

14101409
def generate(self, is_inference):

0 commit comments

Comments (0)