@@ -7,11 +7,12 @@
 import sys
 import textwrap
 from itertools import chain, count
-from typing import Callable, Optional, Protocol, TYPE_CHECKING, Union
+from typing import Any, Callable, Optional, Protocol, TYPE_CHECKING, Union

 import sympy

 import torch
+import torch._higher_order_ops.torchbind
 import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
 import torch._ops
 from torch._inductor.runtime.runtime_utils import dynamo_timed
@@ -38,6 +39,9 @@

     from ..graph import GraphLowering

+# At most, the list nesting can go one layer deep.
+_OUTPUT_ARGS_TYPE = list[Union[Optional[str], list[Optional[str]]]]
+

 class HasWriteLine(Protocol):
     def writeline(self, line: Union[LineContext, DeferredLineBase, str]) -> None: ...
@@ -1880,17 +1884,18 @@ def codegen_while_loop(self, while_loop):

     def generate_extern_kernel_args_decl_if_needed(
         self,
-        op_overload,
-        raw_args,
-        output_args: Optional[list[str]] = None,
-        raw_outputs: Optional[list[ir.Buffer]] = None,
+        op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator],
+        raw_args: Sequence[Any],
+        output_args: _OUTPUT_ARGS_TYPE,
+        raw_outputs: Sequence[ir.Buffer],
     ):
         schema = None
         if isinstance(op_overload, torch._higher_order_ops.torchbind.CallTorchBind):
             obj = raw_args[0]
             method = raw_args[1]
             schema = op_overload.schema(obj, method)
         else:
+            assert isinstance(op_overload, torch._ops.OpOverload), type(op_overload)
             schema = op_overload._schema
         assert schema is not None
         arg_types = [x.real_type for x in schema.arguments]
@@ -1986,7 +1991,9 @@ def fill_args(arg, arg_type):
             else:
                 fill_args(arg, arg_type)

-        def fill_output_arg(arg, return_type, is_mutated_output: bool):
+        def fill_output_arg(
+            arg: str, return_type: torch.JitType, is_mutated_output: bool
+        ) -> None:
             if isinstance(return_type, torch.TensorType):
                 if not is_mutated_output:
                     self.writeline(f"AtenTensorHandle {arg}_handle; // output buffer")
@@ -2021,8 +2028,9 @@ def fill_output_arg(arg, return_type, is_mutated_output: bool):
             # None output is supported, but Optional return types are not yet supported
             if output_arg is None:
                 continue
-            elif isinstance(output_arg, (list, tuple)):
+            elif isinstance(output_arg, list):
                 for out in output_arg:
+                    assert out is not None, out
                     fill_output_arg(
                         out,
                         torch.TensorType.get(),
@@ -2041,73 +2049,73 @@ def generate_fallback_kernel_with_runtime_lookup(
         self,
         buf_name: str,
         python_kernel_name: str,
-        cpp_kernel_name: str,
-        codegen_args: list[str],
-        op_overload: Optional[torch._ops.OpOverload] = None,
-        raw_args=None,
-        outputs=None,
-    ):
-        def extract_output_name(out):
+        codegen_args: Sequence[str],
+        op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator],
+        raw_args: Sequence[Any],
+        outputs: Sequence[ir.Buffer],
+    ) -> None:
+        """Generate a call to a kernel not contained in the C-shim. This results in
+        different code paths for AOT Inductor vs cpp_wrapper Inductor mode."""
+
+        def extract_output_name(
+            out: Optional[Union[ir.Buffer, Sequence[ir.Buffer]]],
+        ) -> Union[Optional[str], _OUTPUT_ARGS_TYPE]:
             if out is None:
                 return None
-            elif isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)):
+            if isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)):
                 return out.get_name()
-            elif isinstance(out, ir.MutationOutput):
+            if isinstance(out, ir.MutationOutput):
                 mutated_buf_names = out.get_mutation_names()
                 assert (
                     isinstance(mutated_buf_names, list) and len(mutated_buf_names) == 1
                 ), "Expect only one mutated buffer in MutationOutput"
                 return mutated_buf_names[0]
-            elif isinstance(out, (list, tuple)):
-                return type(out)(extract_output_name(o) for o in out)
-            else:
-                raise AssertionError(f"Unexpected output: {type(out)}")
+            if isinstance(out, (list, tuple)):
+                return [extract_output_name(o) for o in out]  # type: ignore[misc]
+            raise AssertionError(f"Unexpected output: {type(out)}")
+
+        if isinstance(op_overload, torch._ops.HigherOrderOperator):
+            assert isinstance(
+                op_overload, torch._higher_order_ops.torchbind.CallTorchBind
+            ), type(op_overload)
+            assert len(raw_args) > 1
+            obj = raw_args[0]
+            method = raw_args[1]
+            return_schema = op_overload.schema(obj, method).returns
+        else:
+            return_schema = op_overload._schema.returns

         # output_args has the same pytree structure as outputs
-
-        return_schema = None
-        if op_overload:
-            if isinstance(op_overload, torch._higher_order_ops.torchbind.CallTorchBind):
-                assert raw_args is not None
-                assert len(raw_args) > 1
-                obj = raw_args[0]
-                method = raw_args[1]
-                return_schema = op_overload.schema(obj, method).returns
-            else:
-                return_schema = op_overload._schema.returns
-        if op_overload and not return_schema:
+        if not return_schema:
             # kernel does not return a value
-            output_args = []
-        elif outputs is None:
-            # outputs is not specified, the default is to write to buf_name
-            output_args = [buf_name]
+            output_args: _OUTPUT_ARGS_TYPE = []
+        elif isinstance(output_name := extract_output_name(outputs), str):
+            output_args = [output_name]
         else:
-            output_args = extract_output_name(outputs)
-            if isinstance(output_args, str):
-                output_args = [output_args]
+            # If the schema indicates a return value, we should have a non-None value by
+            # this point.
+            assert isinstance(output_name, list), type(output_name)
+            output_args = output_name

+        # In AOT mode, we use a ProxyExecutor to run fallback kernels.
         if V.graph.aot_mode:
-            assert op_overload is not None
-            assert raw_args is not None
-            assert output_args is not None
-
-            return self.generate_fallback_kernel_with_runtime_lookup_aot(
+            self.generate_fallback_kernel_with_runtime_lookup_aot(
                 op_overload,
                 raw_args,
                 output_args,
                 outputs,
             )
-        else:
-            return self.generate_fallback_kernel_with_runtime_lookup_jit(
-                buf_name,
-                python_kernel_name,
-                cpp_kernel_name,
-                codegen_args,
-                op_overload,
-                raw_args,
-                output_args,  # type: ignore[arg-type]
-                outputs,
-            )
+            return
+
+        assert isinstance(op_overload, torch._ops.OpOverload), type(op_overload)
+        self.generate_fallback_kernel_with_runtime_lookup_jit(
+            buf_name,
+            python_kernel_name,
+            op_overload,
+            raw_args,
+            output_args,  # type: ignore[arg-type]
+            outputs,
+        )

     def generate_scoped_gil_acquire(self, declarations_before_scope, lines_in_scope):
         scoped_lines = IndentedBuffer()
@@ -2256,19 +2264,19 @@ def generate_fallback_kernel_with_runtime_lookup_jit(
         self,
         buf_name: str,
         python_kernel_name: str,
-        cpp_kernel_name: str,
-        codegen_args: list[str],
-        op_overload: Optional[torch._ops.OpOverload] = None,
-        raw_args=None,
-        output_args: Optional[list[Optional[str]]] = None,
-        raw_outputs: Optional[list[ir.Buffer]] = None,
-    ):
-        # In the JIT mode, because of the ABI-compatible requirement, we can't directly call
-        # c10::Dispatcher to find the custom op and call it. Instead, we go back to Python
-        # to invoke this custom op.
+        op_overload: torch._ops.OpOverload,
+        raw_args: Sequence[Any],
+        output_args: Sequence[Optional[str]],
+        raw_outputs: Sequence[ir.Buffer],
+    ) -> None:
+        """Generate fallback kernel calls with runtime (non-AOT) dispatch. This can
+        only be called in cpp_wrapper mode, and assumes that the input is a non-None
+        OpOverload.
+
+        This function calls into Python to dispatch, which allows it to handle datatypes
+        that cannot be contained in StableIValue, at the cost of some performance."""
         self.load_custom_op_wrapper()

-        assert output_args is not None, "output_args should not be None"
         num_args = len(raw_args)
         py_args_var = f"py_args_{next(self.arg_var_id)}"
         # First arg is always the python op name
@@ -2282,8 +2290,6 @@ def generate_fallback_kernel_with_runtime_lookup_jit(
            """
         )

-        assert op_overload is not None, "op_overload should not be None"
-
         for idx, (raw_arg, schema_arg) in enumerate(
             zip(raw_args, op_overload._schema.arguments)
         ):
@@ -2334,11 +2340,11 @@

     def generate_fallback_kernel_with_runtime_lookup_aot(
         self,
-        op_overload,
-        raw_args,  # contains both args and flatten kwargs
-        output_args: Optional[list[str]] = None,
-        raw_outputs: Optional[list[ir.Buffer]] = None,
-    ):
+        op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator],
+        raw_args: Sequence[Any],
+        output_args: _OUTPUT_ARGS_TYPE,
+        raw_outputs: Sequence[ir.Buffer],
+    ) -> None:
         (
             tensor_call_args,
             int_call_args,
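
For reference, here is a minimal standalone sketch (not part of the commit) of the output-name flattening behavior that the new _OUTPUT_ARGS_TYPE alias encodes: nesting goes at most one list deep. FakeBuffer is a hypothetical stand-in for the ir buffer classes handled by the real extract_output_name, which also covers ir.MutationOutput and None entries.

# Illustrative sketch only -- not part of the commit.
from typing import Optional, Union

_OUTPUT_ARGS_TYPE = list[Union[Optional[str], list[Optional[str]]]]


class FakeBuffer:
    # Hypothetical stand-in for ir.MultiOutput / ir.Buffer.
    def __init__(self, name: str) -> None:
        self.name = name

    def get_name(self) -> str:
        return self.name


def extract_output_name(out):
    # None outputs pass through; buffers yield their names; a list of buffers
    # yields a single-level list of names, matching _OUTPUT_ARGS_TYPE.
    if out is None:
        return None
    if isinstance(out, FakeBuffer):
        return out.get_name()
    if isinstance(out, (list, tuple)):
        return [extract_output_name(o) for o in out]
    raise AssertionError(f"Unexpected output: {type(out)}")


outputs = [FakeBuffer("buf0"), [FakeBuffer("buf1"), FakeBuffer("buf2")], None]
result: _OUTPUT_ARGS_TYPE = extract_output_name(outputs)
assert result == ["buf0", ["buf1", "buf2"], None]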