Skip to content

Commit 79ff3f2

Browse files
committed
Address review comments.
1 parent 424ce0f commit 79ff3f2

File tree

8 files changed

+92
-26
lines changed

8 files changed

+92
-26
lines changed

mlir/include/mlir/Dialect/Async/IR/AsyncOps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,18 @@ def Async_FuncOp : Async_Op<"func",
168168
ArrayRef<Type> getCallableResults() { return getFunctionType()
169169
.getResults(); }
170170

171+
/// Returns the argument attributes for all callable region arguments or
172+
/// null if there are none.
173+
::mlir::ArrayAttr getCallableArgAttrs() {
174+
return getArgAttrs().value_or(nullptr);
175+
}
176+
177+
/// Returns the result attributes for all callable region results or
178+
/// null if there are none.
179+
::mlir::ArrayAttr getCallableResAttrs() {
180+
return getResAttrs().value_or(nullptr);
181+
}
182+
171183
//===------------------------------------------------------------------===//
172184
// FunctionOpInterface Methods
173185
//===------------------------------------------------------------------===//

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1583,6 +1583,10 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
15831583
/// Returns the result types of this function.
15841584
ArrayRef<Type> getResultTypes() { return getFunctionType().getReturnTypes(); }
15851585

1586+
//===------------------------------------------------------------------===//
1587+
// CallableOpInterface
1588+
//===------------------------------------------------------------------===//
1589+
15861590
/// Returns the callable region, which is the function body. If the function
15871591
/// is external, returns null.
15881592
Region *getCallableRegion();
@@ -1596,6 +1600,17 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
15961600
return getFunctionType().getReturnTypes();
15971601
}
15981602

1603+
/// Returns the argument attributes for all callable region arguments or
1604+
/// null if there are none.
1605+
::mlir::ArrayAttr getCallableArgAttrs() {
1606+
return getArgAttrs().value_or(nullptr);
1607+
}
1608+
1609+
/// Returns the result attributes for all callable region results or
1610+
/// null if there are none.
1611+
::mlir::ArrayAttr getCallableResAttrs() {
1612+
return getResAttrs().value_or(nullptr);
1613+
}
15991614
}];
16001615

16011616
let hasCustomAssemblyFormat = 1;

mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,18 @@ def MLProgram_FuncOp : MLProgram_Op<"func", [
7373
/// executed.
7474
ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
7575

76+
/// Returns the argument attributes for all callable region arguments or
77+
/// null if there are none.
78+
::mlir::ArrayAttr getCallableArgAttrs() {
79+
return getArgAttrs().value_or(nullptr);
80+
}
81+
82+
/// Returns the result attributes for all callable region results or
83+
/// null if there are none.
84+
::mlir::ArrayAttr getCallableResAttrs() {
85+
return getResAttrs().value_or(nullptr);
86+
}
87+
7688
//===------------------------------------------------------------------===//
7789
// FunctionOpInterface Methods
7890
//===------------------------------------------------------------------===//
@@ -422,6 +434,18 @@ def MLProgram_SubgraphOp : MLProgram_Op<"subgraph", [
422434
/// executed.
423435
ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
424436

437+
/// Returns the argument attributes for all callable region arguments or
438+
/// null if there are none.
439+
::mlir::ArrayAttr getCallableArgAttrs() {
440+
return getArgAttrs().value_or(nullptr);
441+
}
442+
443+
/// Returns the result attributes for all callable region results or
444+
/// null if there are none.
445+
::mlir::ArrayAttr getCallableResAttrs() {
446+
return getResAttrs().value_or(nullptr);
447+
}
448+
425449
//===------------------------------------------------------------------===//
426450
// FunctionOpInterface Methods
427451
//===------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,6 +1149,18 @@ def Shape_FuncOp : Shape_Op<"func",
11491149
return getFunctionType().getResults();
11501150
}
11511151

1152+
/// Returns the argument attributes for all callable region arguments or
1153+
/// null if there are none.
1154+
::mlir::ArrayAttr getCallableArgAttrs() {
1155+
return getArgAttrs().value_or(nullptr);
1156+
}
1157+
1158+
/// Returns the result attributes for all callable region results or
1159+
/// null if there are none.
1160+
::mlir::ArrayAttr getCallableResAttrs() {
1161+
return getResAttrs().value_or(nullptr);
1162+
}
1163+
11521164
//===------------------------------------------------------------------===//
11531165
// FunctionOpInterface Methods
11541166
//===------------------------------------------------------------------===//

mlir/include/mlir/Interfaces/CallInterfaces.td

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -84,28 +84,18 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
8484
}],
8585
"::llvm::ArrayRef<::mlir::Type>", "getCallableResults"
8686
>,
87-
InterfaceMethod<
88-
/*desc=*/[{
87+
InterfaceMethod<[{
8988
Returns the argument attributes for all callable region arguments or
9089
null if there are none.
9190
}],
92-
/*retType=*/"::mlir::ArrayAttr",
93-
/*methodName=*/"getCallableArgAttrs",
94-
/*args=*/(ins),
95-
/*methodBody=*/"",
96-
/*defaultImplementation=*/[{ return {}; }]
91+
"::mlir::ArrayAttr", "getCallableArgAttrs"
9792
>,
98-
InterfaceMethod<
99-
/*desc=*/[{
93+
InterfaceMethod<[{
10094
Returns the result attributes for all callable region results or null
10195
if there are none.
10296
}],
103-
/*retType=*/"::mlir::ArrayAttr",
104-
/*methodName=*/"getCallableResAttrs",
105-
/*args=*/(ins),
106-
/*methodBody=*/"",
107-
/*defaultImplementation=*/[{ return {}; }]
108-
>,
97+
"::mlir::ArrayAttr", "getCallableResAttrs"
98+
>
10999
];
110100
}
111101

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2469,6 +2469,16 @@ ArrayRef<Type> spirv::FuncOp::getCallableResults() {
24692469
return getFunctionType().getResults();
24702470
}
24712471

2472+
// CallableOpInterface
2473+
::mlir::ArrayAttr spirv::FuncOp::getCallableArgAttrs() {
2474+
return getArgAttrs().value_or(nullptr);
2475+
}
2476+
2477+
// CallableOpInterface
2478+
::mlir::ArrayAttr spirv::FuncOp::getCallableResAttrs() {
2479+
return getResAttrs().value_or(nullptr);
2480+
}
2481+
24722482
//===----------------------------------------------------------------------===//
24732483
// spirv.FunctionCall
24742484
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/Utils/InliningUtils.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,6 @@ static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder,
165165
CallOpInterface call,
166166
CallableOpInterface callable,
167167
IRMapping &mapper) {
168-
if (!call || !callable)
169-
return;
170-
171168
// Unpack the argument attributes if there are any.
172169
SmallVector<DictionaryAttr> argAttrs(
173170
callable.getCallableRegion()->getNumArguments(),
@@ -195,9 +192,6 @@ static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder,
195192
static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
196193
CallOpInterface call, CallableOpInterface callable,
197194
ValueRange results) {
198-
if (!call || !callable)
199-
return;
200-
201195
// Unpack the result attributes if there are any.
202196
SmallVector<DictionaryAttr> resAttrs(results.size(),
203197
builder.getDictionaryAttr({}));
@@ -260,7 +254,8 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
260254
// Run the argument attribute handler before inlining the callable region.
261255
OpBuilder builder(inlineBlock, inlinePoint);
262256
auto callable = dyn_cast<CallableOpInterface>(src->getParentOp());
263-
handleArgumentImpl(interface, builder, call, callable, mapper);
257+
if (call && callable)
258+
handleArgumentImpl(interface, builder, call, callable, mapper);
264259

265260
// Check to see if the region is being cloned, or moved inline. In either
266261
// case, move the new blocks after the 'insertBlock' to improve IR
@@ -298,8 +293,9 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
298293
// Run the result attribute handler on the terminator operands.
299294
Operation *firstBlockTerminator = firstNewBlock->getTerminator();
300295
builder.setInsertionPoint(firstBlockTerminator);
301-
handleResultImpl(interface, builder, call, callable,
302-
firstBlockTerminator->getOperands());
296+
if (call && callable)
297+
handleResultImpl(interface, builder, call, callable,
298+
firstBlockTerminator->getOperands());
303299

304300
// Have the interface handle the terminator of this block.
305301
interface.handleTerminator(firstBlockTerminator,
@@ -321,8 +317,9 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
321317

322318
// Run the result attribute handler on the post insertion block arguments.
323319
builder.setInsertionPointToStart(postInsertBlock);
324-
handleResultImpl(interface, builder, call, callable,
325-
postInsertBlock->getArguments());
320+
if (call && callable)
321+
handleResultImpl(interface, builder, call, callable,
322+
postInsertBlock->getArguments());
326323

327324
/// Handle the terminators for each of the new blocks.
328325
for (auto &newBlock : newBlocks)

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,12 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op",
492492
::llvm::ArrayRef<::mlir::Type> getCallableResults() {
493493
return getType().cast<::mlir::FunctionType>().getResults();
494494
}
495+
::mlir::ArrayAttr getCallableArgAttrs() {
496+
return nullptr;
497+
}
498+
::mlir::ArrayAttr getCallableResAttrs() {
499+
return nullptr;
500+
}
495501
}];
496502
}
497503

0 commit comments

Comments
 (0)