Skip to content

Commit 327d627

Browse files
authored
[mlir] share argument attributes interface between calls and callables (#123176)
This patch shares core interface methods dealing with argument and result attributes from CallableOpInterface with the CallOpInterface and makes them mandatory to gives more consistent guarantees about concrete operations using these interfaces. This allows adding argument attributes on call like operations, which is sometimes required to get proper ABI, like with llvm.call (and llvm.invoke). The patch adds optional `arg_attrs` and `res_attrs` attributes to operations using these interfaces that did not have that already. They can then re-use the common "rich function signature" printing/parsing helpers if they want (for the LLVM dialect, this is done in the next patch). Part of RFC: https://discourse.llvm.org/t/mlir-rfc-adding-argument-and-result-attributes-to-llvm-call/84107
1 parent 8f025f2 commit 327d627

File tree

32 files changed

+452
-256
lines changed

32 files changed

+452
-256
lines changed

flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,9 @@ def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface,
207207
I32:$block_z,
208208
Optional<I32>:$bytes,
209209
Optional<I32>:$stream,
210-
Variadic<AnyType>:$args
210+
Variadic<AnyType>:$args,
211+
OptionalAttr<DictArrayAttr>:$arg_attrs,
212+
OptionalAttr<DictArrayAttr>:$res_attrs
211213
);
212214

213215
let assemblyFormat = [{

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2432,6 +2432,8 @@ def fir_CallOp : fir_Op<"call",
24322432
let arguments = (ins
24332433
OptionalAttr<SymbolRefAttr>:$callee,
24342434
Variadic<AnyType>:$args,
2435+
OptionalAttr<DictArrayAttr>:$arg_attrs,
2436+
OptionalAttr<DictArrayAttr>:$res_attrs,
24352437
OptionalAttr<fir_FortranProcedureFlagsAttr>:$procedure_attrs,
24362438
DefaultValuedAttr<Arith_FastMathAttr,
24372439
"::mlir::arith::FastMathFlags::none">:$fastmath
@@ -2518,6 +2520,8 @@ def fir_DispatchOp : fir_Op<"dispatch", []> {
25182520
fir_ClassType:$object,
25192521
Variadic<AnyType>:$args,
25202522
OptionalAttr<I32Attr>:$pass_arg_pos,
2523+
OptionalAttr<DictArrayAttr>:$arg_attrs,
2524+
OptionalAttr<DictArrayAttr>:$res_attrs,
25212525
OptionalAttr<fir_FortranProcedureFlagsAttr>:$procedure_attrs
25222526
);
25232527

flang/lib/Lower/ConvertCall.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,8 @@ Fortran::lower::genCallOpAndResult(
594594

595595
builder.create<cuf::KernelLaunchOp>(
596596
loc, funcType.getResults(), funcSymbolAttr, grid_x, grid_y, grid_z,
597-
block_x, block_y, block_z, bytes, stream, operands);
597+
block_x, block_y, block_z, bytes, stream, operands,
598+
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr);
598599
callNumResults = 0;
599600
} else if (caller.requireDispatchCall()) {
600601
// Procedure call requiring a dynamic dispatch. Call is created with
@@ -621,7 +622,8 @@ Fortran::lower::genCallOpAndResult(
621622
dispatch = builder.create<fir::DispatchOp>(
622623
loc, funcType.getResults(), builder.getStringAttr(procName),
623624
caller.getInputs()[*passArg], operands,
624-
builder.getI32IntegerAttr(*passArg), procAttrs);
625+
builder.getI32IntegerAttr(*passArg), /*arg_attrs=*/nullptr,
626+
/*res_attrs=*/nullptr, procAttrs);
625627
} else {
626628
// NOPASS
627629
const Fortran::evaluate::Component *component =
@@ -636,15 +638,17 @@ Fortran::lower::genCallOpAndResult(
636638
passObject = builder.create<fir::LoadOp>(loc, passObject);
637639
dispatch = builder.create<fir::DispatchOp>(
638640
loc, funcType.getResults(), builder.getStringAttr(procName),
639-
passObject, operands, nullptr, procAttrs);
641+
passObject, operands, nullptr, /*arg_attrs=*/nullptr,
642+
/*res_attrs=*/nullptr, procAttrs);
640643
}
641644
callNumResults = dispatch.getNumResults();
642645
if (callNumResults != 0)
643646
callResult = dispatch.getResult(0);
644647
} else {
645648
// Standard procedure call with fir.call.
646649
auto call = builder.create<fir::CallOp>(
647-
loc, funcType.getResults(), funcSymbolAttr, operands, procAttrs);
650+
loc, funcType.getResults(), funcSymbolAttr, operands,
651+
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, procAttrs);
648652

649653
callNumResults = call.getNumResults();
650654
if (callNumResults != 0)

flang/lib/Optimizer/CodeGen/TargetRewrite.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
518518
newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
519519

520520
llvm::SmallVector<mlir::Value, 1> newCallResults;
521+
// TODO propagate/update call argument and result attributes.
521522
if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>) {
522523
auto newCall = rewriter->create<A>(
523524
loc, callOp.getKernel(), callOp.getGridSizeOperandValues(),
@@ -557,6 +558,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
557558
loc, newResTys, rewriter->getStringAttr(callOp.getMethod()),
558559
callOp.getOperands()[0], newOpers,
559560
rewriter->getI32IntegerAttr(*callOp.getPassArgPos() + passArgShift),
561+
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
560562
callOp.getProcedureAttrsAttr());
561563
if (wrap)
562564
newCallResults.push_back((*wrap)(dispatchOp.getOperation()));

flang/lib/Optimizer/Transforms/AbstractResult.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ class CallConversion : public mlir::OpRewritePattern<Op> {
147147
newResultTypes.emplace_back(getVoidPtrType(result.getContext()));
148148

149149
Op newOp;
150+
// TODO: propagate argument and result attributes (need to be shifted).
150151
// fir::CallOp specific handling.
151152
if constexpr (std::is_same_v<Op, fir::CallOp>) {
152153
if (op.getCallee()) {
@@ -189,9 +190,11 @@ class CallConversion : public mlir::OpRewritePattern<Op> {
189190
if (op.getPassArgPos())
190191
passArgPos =
191192
rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift);
193+
// TODO: propagate argument and result attributes (need to be shifted).
192194
newOp = rewriter.create<fir::DispatchOp>(
193195
loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
194196
op.getOperands()[0], newOperands, passArgPos,
197+
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
195198
op.getProcedureAttrsAttr());
196199
}
197200

flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,9 @@ struct DispatchOpConv : public OpConversionPattern<fir::DispatchOp> {
205205
// Make the call.
206206
llvm::SmallVector<mlir::Value> args{funcPtr};
207207
args.append(dispatch.getArgs().begin(), dispatch.getArgs().end());
208-
rewriter.replaceOpWithNewOp<fir::CallOp>(dispatch, resTypes, nullptr, args,
209-
dispatch.getProcedureAttrsAttr());
208+
rewriter.replaceOpWithNewOp<fir::CallOp>(
209+
dispatch, resTypes, nullptr, args, dispatch.getArgAttrsAttr(),
210+
dispatch.getResAttrsAttr(), dispatch.getProcedureAttrsAttr());
210211
return mlir::success();
211212
}
212213

mlir/docs/Interfaces.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -753,10 +753,15 @@ interface section goes as follows:
753753
- (`C++ class` -- `ODS class`(if applicable))
754754

755755
##### CallInterfaces
756-
757756
* `CallOpInterface` - Used to represent operations like 'call'
758757
- `CallInterfaceCallable getCallableForCallee()`
759758
- `void setCalleeFromCallable(CallInterfaceCallable)`
759+
- `ArrayAttr getArgAttrsAttr()`
760+
- `ArrayAttr getResAttrsAttr()`
761+
- `void setArgAttrsAttr(ArrayAttr)`
762+
- `void setResAttrsAttr(ArrayAttr)`
763+
- `Attribute removeArgAttrsAttr()`
764+
- `Attribute removeResAttrsAttr()`
760765
* `CallableOpInterface` - Used to represent the target callee of call.
761766
- `Region * getCallableRegion()`
762767
- `ArrayRef<Type> getArgumentTypes()`

mlir/examples/toy/Ch4/include/toy/Ops.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,12 @@ def GenericCallOp : Toy_Op<"generic_call",
215215

216216
// The generic call operation takes a symbol reference attribute as the
217217
// callee, and inputs for the call.
218-
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
218+
let arguments = (ins
219+
FlatSymbolRefAttr:$callee,
220+
Variadic<F64Tensor>:$inputs,
221+
OptionalAttr<DictArrayAttr>:$arg_attrs,
222+
OptionalAttr<DictArrayAttr>:$res_attrs
223+
);
219224

220225
// The generic call operation returns a single value of TensorType.
221226
let results = (outs F64Tensor);

mlir/examples/toy/Ch5/include/toy/Ops.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,12 @@ def GenericCallOp : Toy_Op<"generic_call",
214214

215215
// The generic call operation takes a symbol reference attribute as the
216216
// callee, and inputs for the call.
217-
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
217+
let arguments = (ins
218+
FlatSymbolRefAttr:$callee,
219+
Variadic<F64Tensor>:$inputs,
220+
OptionalAttr<DictArrayAttr>:$arg_attrs,
221+
OptionalAttr<DictArrayAttr>:$res_attrs
222+
);
218223

219224
// The generic call operation returns a single value of TensorType.
220225
let results = (outs F64Tensor);

mlir/examples/toy/Ch6/include/toy/Ops.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,12 @@ def GenericCallOp : Toy_Op<"generic_call",
214214

215215
// The generic call operation takes a symbol reference attribute as the
216216
// callee, and inputs for the call.
217-
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
217+
let arguments = (ins
218+
FlatSymbolRefAttr:$callee,
219+
Variadic<F64Tensor>:$inputs,
220+
OptionalAttr<DictArrayAttr>:$arg_attrs,
221+
OptionalAttr<DictArrayAttr>:$res_attrs
222+
);
218223

219224
// The generic call operation returns a single value of TensorType.
220225
let results = (outs F64Tensor);

mlir/examples/toy/Ch7/include/toy/Ops.td

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,12 @@ def GenericCallOp : Toy_Op<"generic_call",
237237

238238
// The generic call operation takes a symbol reference attribute as the
239239
// callee, and inputs for the call.
240-
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<Toy_Type>:$inputs);
240+
let arguments = (ins
241+
FlatSymbolRefAttr:$callee,
242+
Variadic<Toy_Type>:$inputs,
243+
OptionalAttr<DictArrayAttr>:$arg_attrs,
244+
OptionalAttr<DictArrayAttr>:$res_attrs
245+
);
241246

242247
// The generic call operation returns a single value of TensorType or
243248
// StructType.
@@ -250,7 +255,8 @@ def GenericCallOp : Toy_Op<"generic_call",
250255

251256
// Add custom build methods for the generic call operation.
252257
let builders = [
253-
OpBuilder<(ins "StringRef":$callee, "ArrayRef<Value>":$arguments)>
258+
OpBuilder<(ins "Type":$result_type, "StringRef":$callee,
259+
"ArrayRef<Value>":$arguments)>
254260
];
255261
}
256262

mlir/examples/toy/Ch7/mlir/Dialect.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -350,9 +350,9 @@ void FuncOp::print(mlir::OpAsmPrinter &p) {
350350
//===----------------------------------------------------------------------===//
351351

352352
void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
353-
StringRef callee, ArrayRef<mlir::Value> arguments) {
354-
// Generic call always returns an unranked Tensor initially.
355-
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
353+
mlir::Type resultType, StringRef callee,
354+
ArrayRef<mlir::Value> arguments) {
355+
state.addTypes(resultType);
356356
state.addOperands(arguments);
357357
state.addAttribute("callee",
358358
mlir::SymbolRefAttr::get(builder.getContext(), callee));

mlir/examples/toy/Ch7/mlir/MLIRGen.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -535,8 +535,7 @@ class MLIRGenImpl {
535535
}
536536
mlir::toy::FuncOp calledFunc = calledFuncIt->second;
537537
return builder.create<GenericCallOp>(
538-
location, calledFunc.getFunctionType().getResult(0),
539-
mlir::SymbolRefAttr::get(builder.getContext(), callee), operands);
538+
location, calledFunc.getFunctionType().getResult(0), callee, operands);
540539
}
541540

542541
/// Emit a print expression. It emits specific operations for two builtins:

mlir/include/mlir/Dialect/Async/IR/AsyncOps.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,13 @@ def Async_CallOp : Async_Op<"call",
208208
```
209209
}];
210210

211-
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
211+
let arguments = (ins
212+
FlatSymbolRefAttr:$callee,
213+
Variadic<AnyType>:$operands,
214+
OptionalAttr<DictArrayAttr>:$arg_attrs,
215+
OptionalAttr<DictArrayAttr>:$res_attrs
216+
);
217+
212218
let results = (outs Variadic<Async_AnyValueOrTokenType>);
213219

214220
let builders = [

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,13 @@ def EmitC_CallOp : EmitC_Op<"call",
551551
%2 = emitc.call @my_add(%0, %1) : (f32, f32) -> f32
552552
```
553553
}];
554-
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<EmitCType>:$operands);
554+
let arguments = (ins
555+
FlatSymbolRefAttr:$callee,
556+
Variadic<EmitCType>:$operands,
557+
OptionalAttr<DictArrayAttr>:$arg_attrs,
558+
OptionalAttr<DictArrayAttr>:$res_attrs
559+
);
560+
555561
let results = (outs Variadic<EmitCType>);
556562

557563
let builders = [

mlir/include/mlir/Dialect/Func/IR/FuncOps.td

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,14 @@ def CallOp : Func_Op<"call",
4949
```
5050
}];
5151

52-
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands,
53-
UnitAttr:$no_inline);
52+
let arguments = (ins
53+
FlatSymbolRefAttr:$callee,
54+
Variadic<AnyType>:$operands,
55+
OptionalAttr<DictArrayAttr>:$arg_attrs,
56+
OptionalAttr<DictArrayAttr>:$res_attrs,
57+
UnitAttr:$no_inline
58+
);
59+
5460
let results = (outs Variadic<AnyType>);
5561

5662
let builders = [
@@ -73,6 +79,18 @@ def CallOp : Func_Op<"call",
7379
CArg<"ValueRange", "{}">:$operands), [{
7480
build($_builder, $_state, StringAttr::get($_builder.getContext(), callee),
7581
results, operands);
82+
}]>,
83+
OpBuilder<(ins "TypeRange":$results, "FlatSymbolRefAttr":$callee,
84+
CArg<"ValueRange", "{}">:$operands), [{
85+
build($_builder, $_state, callee, results, operands);
86+
}]>,
87+
OpBuilder<(ins "TypeRange":$results, "StringAttr":$callee,
88+
CArg<"ValueRange", "{}">:$operands), [{
89+
build($_builder, $_state, callee, results, operands);
90+
}]>,
91+
OpBuilder<(ins "TypeRange":$results, "StringRef":$callee,
92+
CArg<"ValueRange", "{}">:$operands), [{
93+
build($_builder, $_state, callee, results, operands);
7694
}]>];
7795

7896
let extraClassDeclaration = [{
@@ -136,8 +154,13 @@ def CallIndirectOp : Func_Op<"call_indirect", [
136154
```
137155
}];
138156

139-
let arguments = (ins FunctionType:$callee,
140-
Variadic<AnyType>:$callee_operands);
157+
let arguments = (ins
158+
FunctionType:$callee,
159+
Variadic<AnyType>:$callee_operands,
160+
OptionalAttr<DictArrayAttr>:$arg_attrs,
161+
OptionalAttr<DictArrayAttr>:$res_attrs
162+
);
163+
141164
let results = (outs Variadic<AnyType>:$results);
142165

143166
let builders = [

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,8 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
633633
OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
634634
OptionalAttr<FlatSymbolRefAttr>:$callee,
635635
Variadic<LLVM_Type>:$callee_operands,
636+
OptionalAttr<DictArrayAttr>:$arg_attrs,
637+
OptionalAttr<DictArrayAttr>:$res_attrs,
636638
Variadic<LLVM_Type>:$normalDestOperands,
637639
Variadic<LLVM_Type>:$unwindDestOperands,
638640
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
@@ -755,7 +757,9 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
755757
VariadicOfVariadic<LLVM_Type,
756758
"op_bundle_sizes">:$op_bundle_operands,
757759
DenseI32ArrayAttr:$op_bundle_sizes,
758-
OptionalAttr<ArrayAttr>:$op_bundle_tags);
760+
OptionalAttr<ArrayAttr>:$op_bundle_tags,
761+
OptionalAttr<DictArrayAttr>:$arg_attrs,
762+
OptionalAttr<DictArrayAttr>:$res_attrs);
759763
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
760764
let arguments = !con(args, aliasAttrs);
761765
let results = (outs Optional<LLVM_Type>:$result);

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,13 +214,24 @@ def SPIRV_FunctionCallOp : SPIRV_Op<"FunctionCall", [
214214

215215
let arguments = (ins
216216
FlatSymbolRefAttr:$callee,
217-
Variadic<SPIRV_Type>:$arguments
217+
Variadic<SPIRV_Type>:$arguments,
218+
OptionalAttr<DictArrayAttr>:$arg_attrs,
219+
OptionalAttr<DictArrayAttr>:$res_attrs
218220
);
219221

220222
let results = (outs
221223
Optional<SPIRV_Type>:$return_value
222224
);
223225

226+
let builders = [
227+
OpBuilder<(ins "Type":$returnType, "FlatSymbolRefAttr":$callee,
228+
"ValueRange":$arguments),
229+
[{
230+
build($_builder, $_state, returnType, callee, arguments,
231+
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr);
232+
}]>
233+
];
234+
224235
let autogenSerialization = 0;
225236

226237
let assemblyFormat = [{

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -886,7 +886,9 @@ def IncludeOp : TransformDialectOp<"include",
886886

887887
let arguments = (ins SymbolRefAttr:$target,
888888
FailurePropagationMode:$failure_propagation_mode,
889-
Variadic<Transform_AnyHandleOrParamType>:$operands);
889+
Variadic<Transform_AnyHandleOrParamType>:$operands,
890+
OptionalAttr<DictArrayAttr>:$arg_attrs,
891+
OptionalAttr<DictArrayAttr>:$res_attrs);
890892
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
891893

892894
let assemblyFormat =

0 commit comments

Comments
 (0)