@@ -179,22 +179,17 @@ def make_return_consumer(kernel_func):
179179# This function will be called KERNEL_NAME_capi_wrapper and will have a {llvm.emit_c_interface} attribute.
180180# Note there might be other such functions in the final module (gpu-lower-to-nvvm-pipeline somehow also inserts some like this).
181181def make_kernel_wrapper (kernel_func , return_consumer = None ):
182- c_api_compatible_types = [
183- T .memref (element_type = t .element_type ) if MemRefType .isinstance (t ) else t
184- for t in kernel_func .function_type .value .results
185- ]
186-
187182 input_types = kernel_func .function_type .value .inputs
188183
189184 @FuncOp .from_py_func (* input_types , name = f"{ kernel_func .name .value } _capi_wrapper" )
190185 def wrapper (* args , ** _kwargs ):
191186 results = CallOp (kernel_func , list (args )).results
192- c_api_compatible_results = []
193- for i , a in enumerate (results ):
194- if MemRefType .isinstance (a .type ):
195- a = cast (c_api_compatible_types [i ], a )
196- c_api_compatible_results .append (a )
197187 if return_consumer is not None :
188+ c_api_compatible_results = []
189+ for i , a in enumerate (results ):
190+ if MemRefType .isinstance (a .type ):
191+ a = cast (T .memref (element_type = a .type .element_type ), a )
192+ c_api_compatible_results .append (a )
198193 CallOp (return_consumer , c_api_compatible_results )
199194
200195 wrapper_func_op = wrapper .func_op
0 commit comments