Skip to content

Commit 4e87453

Browse files
committed
Select interfaces using the function dialect.
1 parent 79ff3f2 commit 4e87453

File tree

4 files changed

+97
-6
lines changed

4 files changed

+97
-6
lines changed

mlir/lib/Transforms/Utils/InliningUtils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ Value InlinerInterface::handleArgument(OpBuilder &builder, Operation *call,
107107
Operation *callable, Value argument,
108108
Type targetType,
109109
DictionaryAttr argumentAttrs) const {
110-
auto *handler = getInterfaceFor(call);
110+
auto *handler = getInterfaceFor(callable);
111111
assert(handler && "expected valid dialect handler");
112112
return handler->handleArgument(builder, call, callable, argument, targetType,
113113
argumentAttrs);
@@ -117,7 +117,7 @@ Value InlinerInterface::handleResult(OpBuilder &builder, Operation *call,
117117
Operation *callable, Value result,
118118
Type targetType,
119119
DictionaryAttr resultAttrs) const {
120-
auto *handler = getInterfaceFor(call);
120+
auto *handler = getInterfaceFor(callable);
121121
assert(handler && "expected valid dialect handler");
122122
return handler->handleResult(builder, call, callable, result, targetType,
123123
resultAttrs);

mlir/test/Transforms/inlining.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,13 +228,13 @@ func.func @func_with_block_args_location_callee2(%arg0 : i32) {
228228
}
229229

230230
// Check that we can handle argument and result attributes.
231-
func.func @handle_attr_callee_fn_multi_arg(%arg0 : i16, %arg1 : i16 {"test.handle_argument"}) -> (i16 {"test.handle_result"}, i16) {
231+
test.conversion_func_op @handle_attr_callee_fn_multi_arg(%arg0 : i16, %arg1 : i16 {"test.handle_argument"}) -> (i16 {"test.handle_result"}, i16) {
232232
%0 = arith.addi %arg0, %arg1 : i16
233233
%1 = arith.subi %arg0, %arg1 : i16
234-
return %0, %1 : i16, i16
234+
"test.return"(%0, %1) : (i16, i16) -> ()
235235
}
236-
func.func @handle_attr_callee_fn(%arg0 : i32 {"test.handle_argument"}) -> (i32 {"test.handle_result"}) {
237-
return %arg0 : i32
236+
test.conversion_func_op @handle_attr_callee_fn(%arg0 : i32 {"test.handle_argument"}) -> (i32 {"test.handle_result"}) {
237+
"test.return"(%arg0) : (i32) -> ()
238238
}
239239

240240
// CHECK-LABEL: func @inline_handle_attr_call

mlir/test/lib/Dialect/Test/TestDialect.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/IR/BuiltinOps.h"
2020
#include "mlir/IR/Diagnostics.h"
2121
#include "mlir/IR/ExtensibleDialect.h"
22+
// #include "mlir/IR/FunctionImplementation.h"
2223
#include "mlir/IR/MLIRContext.h"
2324
#include "mlir/IR/OperationSupport.h"
2425
#include "mlir/IR/PatternMatch.h"
@@ -668,6 +669,29 @@ LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
668669
return success();
669670
}
670671

672+
//===----------------------------------------------------------------------===//
673+
// ConversionFuncOp
674+
//===----------------------------------------------------------------------===//
675+
676+
ParseResult ConversionFuncOp::parse(OpAsmParser &parser,
677+
OperationState &result) {
678+
auto buildFuncType =
679+
[](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
680+
function_interface_impl::VariadicFlag,
681+
std::string &) { return builder.getFunctionType(argTypes, results); };
682+
683+
return function_interface_impl::parseFunctionOp(
684+
parser, result, /*allowVariadic=*/false,
685+
getFunctionTypeAttrName(result.name), buildFuncType,
686+
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
687+
}
688+
689+
void ConversionFuncOp::print(OpAsmPrinter &p) {
690+
function_interface_impl::printFunctionOp(
691+
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
692+
getArgAttrsAttrName(), getResAttrsAttrName());
693+
}
694+
671695
//===----------------------------------------------------------------------===//
672696
// TestFoldToCallOp
673697
//===----------------------------------------------------------------------===//

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ include "TestInterfaces.td"
1414
include "mlir/Dialect/DLTI/DLTIBase.td"
1515
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
1616
include "mlir/IR/EnumAttr.td"
17+
include "mlir/IR/FunctionInterfaces.td"
1718
include "mlir/IR/OpBase.td"
1819
include "mlir/IR/OpAsmInterface.td"
1920
include "mlir/IR/PatternBase.td"
@@ -482,6 +483,72 @@ def ConversionCallOp : TEST_Op<"conversion_call_op",
482483
}];
483484
}
484485

486+
def ConversionFuncOp : TEST_Op<"conversion_func_op", [CallableOpInterface,
487+
FunctionOpInterface]> {
488+
let arguments = (ins SymbolNameAttr:$sym_name,
489+
TypeAttrOf<FunctionType>:$function_type,
490+
OptionalAttr<DictArrayAttr>:$arg_attrs,
491+
OptionalAttr<DictArrayAttr>:$res_attrs,
492+
OptionalAttr<StrAttr>:$sym_visibility);
493+
let regions = (region AnyRegion:$body);
494+
495+
let extraClassDeclaration = [{
496+
//===------------------------------------------------------------------===//
497+
// CallableOpInterface
498+
//===------------------------------------------------------------------===//
499+
500+
/// Returns the region on the current operation that is callable. This may
501+
/// return null in the case of an external callable object, e.g. an external
502+
/// function.
503+
::mlir::Region *getCallableRegion() {
504+
return isExternal() ? nullptr : &getBody();
505+
}
506+
507+
/// Returns the results types that the callable region produces when
508+
/// executed.
509+
::mlir::ArrayRef<::mlir::Type> getCallableResults() {
510+
return getFunctionType().getResults();
511+
}
512+
513+
/// Returns the argument attributes for all callable region arguments or
514+
/// null if there are none.
515+
::mlir::ArrayAttr getCallableArgAttrs() {
516+
return getArgAttrs().value_or(nullptr);
517+
}
518+
519+
/// Returns the result attributes for all callable region results or
520+
/// null if there are none.
521+
::mlir::ArrayAttr getCallableResAttrs() {
522+
return getResAttrs().value_or(nullptr);
523+
}
524+
525+
//===------------------------------------------------------------------===//
526+
// FunctionOpInterface Methods
527+
//===------------------------------------------------------------------===//
528+
529+
/// Returns the argument types of this async function.
530+
::mlir::ArrayRef<::mlir::Type> getArgumentTypes() {
531+
return getFunctionType().getInputs();
532+
}
533+
534+
/// Returns the result types of this async function.
535+
::mlir::ArrayRef<::mlir::Type> getResultTypes() {
536+
return getFunctionType().getResults();
537+
}
538+
539+
/// Returns the number of results of this async function
540+
unsigned getNumResults() {return getResultTypes().size();}
541+
542+
//===------------------------------------------------------------------===//
543+
// SymbolOpInterface Methods
544+
//===------------------------------------------------------------------===//
545+
546+
bool isDeclaration() { return isExternal(); }
547+
}];
548+
549+
let hasCustomAssemblyFormat = 1;
550+
}
551+
485552
def FunctionalRegionOp : TEST_Op<"functional_region_op",
486553
[CallableOpInterface]> {
487554
let regions = (region AnyRegion:$body);

0 commit comments

Comments
 (0)