Skip to content

Commit 5020e49

Browse files
committed
Add getters for multi dim loop variables in LoopLikeOpInterface
1 parent 68f4e46 commit 5020e49

File tree

6 files changed

+97
-107
lines changed

6 files changed

+97
-107
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

+2-2
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ def AffineForOp : Affine_Op<"for",
118118
[AttrSizedOperandSegments, AutomaticAllocationScope,
119119
ImplicitAffineTerminator, ConditionallySpeculatable,
120120
RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
121-
["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
122-
"getSingleUpperBound", "getYieldedValuesMutable",
121+
["getInductionVars", "getMixedLowerBound", "getMixedStep",
122+
"getMixedUpperBound", "getYieldedValuesMutable",
123123
"replaceWithAdditionalYields"]>,
124124
DeclareOpInterfaceMethods<RegionBranchOpInterface,
125125
["getEntrySuccessorOperands"]>]> {

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

+6-31
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
136136
def ForOp : SCF_Op<"for",
137137
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
138138
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
139-
"getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
140-
"getSingleUpperBound", "getYieldedValuesMutable",
139+
"getInductionVars", "getMixedLowerBound", "getMixedStep",
140+
"getMixedUpperBound", "getYieldedValuesMutable",
141141
"promoteIfSingleIteration", "replaceWithAdditionalYields",
142142
"yieldTiledValuesAndReplace"]>,
143143
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
@@ -301,8 +301,8 @@ def ForallOp : SCF_Op<"forall", [
301301
AttrSizedOperandSegments,
302302
AutomaticAllocationScope,
303303
DeclareOpInterfaceMethods<LoopLikeOpInterface,
304-
["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar",
305-
"getSingleLowerBound", "getSingleUpperBound", "getSingleStep",
304+
["getInitsMutable", "getRegionIterArgs", "getInductionVars",
305+
"getMixedLowerBound", "getMixedUpperBound", "getMixedStep",
306306
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
307307
RecursiveMemoryEffects,
308308
SingleBlockImplicitTerminator<"scf::InParallelOp">,
@@ -510,24 +510,6 @@ def ForallOp : SCF_Op<"forall", [
510510
];
511511

512512
let extraClassDeclaration = [{
513-
// Get lower bounds as OpFoldResult.
514-
SmallVector<OpFoldResult> getMixedLowerBound() {
515-
Builder b(getOperation()->getContext());
516-
return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
517-
}
518-
519-
// Get upper bounds as OpFoldResult.
520-
SmallVector<OpFoldResult> getMixedUpperBound() {
521-
Builder b(getOperation()->getContext());
522-
return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
523-
}
524-
525-
// Get steps as OpFoldResult.
526-
SmallVector<OpFoldResult> getMixedStep() {
527-
Builder b(getOperation()->getContext());
528-
return getMixedValues(getStaticStep(), getDynamicStep(), b);
529-
}
530-
531513
/// Get lower bounds as values.
532514
SmallVector<Value> getLowerBound(OpBuilder &b) {
533515
return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedLowerBound());
@@ -584,10 +566,6 @@ def ForallOp : SCF_Op<"forall", [
584566
getNumDynamicControlOperands() + getRank());
585567
}
586568

587-
::mlir::ValueRange getInductionVars() {
588-
return getBody()->getArguments().take_front(getRank());
589-
}
590-
591569
::mlir::Value getInductionVar(int64_t idx) {
592570
return getInductionVars()[idx];
593571
}
@@ -765,8 +743,8 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
765743
def ParallelOp : SCF_Op<"parallel",
766744
[AutomaticAllocationScope,
767745
AttrSizedOperandSegments,
768-
DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getSingleInductionVar",
769-
"getSingleLowerBound", "getSingleUpperBound", "getSingleStep"]>,
746+
DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getInductionVars",
747+
"getMixedLowerBound", "getMixedUpperBound", "getMixedStep"]>,
770748
RecursiveMemoryEffects,
771749
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
772750
SingleBlockImplicitTerminator<"scf::ReduceOp">,
@@ -846,9 +824,6 @@ def ParallelOp : SCF_Op<"parallel",
846824
];
847825

848826
let extraClassDeclaration = [{
849-
ValueRange getInductionVars() {
850-
return getBody()->getArguments();
851-
}
852827
unsigned getNumLoops() { return getStep().size(); }
853828
unsigned getNumReductions() { return getInitVals().size(); }
854829
}];

mlir/include/mlir/Interfaces/LoopLikeInterface.td

+45-20
Original file line numberDiff line numberDiff line change
@@ -93,51 +93,47 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
9393
}]
9494
>,
9595
InterfaceMethod<[{
96-
If there is a single induction variable return it, otherwise return
97-
std::nullopt.
96+
Return all induction variables.
9897
}],
99-
/*retTy=*/"::std::optional<::mlir::Value>",
100-
/*methodName=*/"getSingleInductionVar",
98+
/*retTy=*/"::mlir::ValueRange",
99+
/*methodName=*/"getInductionVars",
101100
/*args=*/(ins),
102101
/*methodBody=*/"",
103102
/*defaultImplementation=*/[{
104-
return std::nullopt;
103+
return {};
105104
}]
106105
>,
107106
InterfaceMethod<[{
108-
Return the single lower bound value or attribute if it exists, otherwise
109-
return std::nullopt.
107+
Return all lower bounds.
110108
}],
111-
/*retTy=*/"::std::optional<::mlir::OpFoldResult>",
112-
/*methodName=*/"getSingleLowerBound",
109+
/*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
110+
/*methodName=*/"getMixedLowerBound",
113111
/*args=*/(ins),
114112
/*methodBody=*/"",
115113
/*defaultImplementation=*/[{
116-
return std::nullopt;
114+
return {};
117115
}]
118116
>,
119117
InterfaceMethod<[{
120-
Return the single step value or attribute if it exists, otherwise
121-
return std::nullopt.
118+
Return all steps.
122119
}],
123-
/*retTy=*/"::std::optional<::mlir::OpFoldResult>",
124-
/*methodName=*/"getSingleStep",
120+
/*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
121+
/*methodName=*/"getMixedStep",
125122
/*args=*/(ins),
126123
/*methodBody=*/"",
127124
/*defaultImplementation=*/[{
128-
return std::nullopt;
125+
return {};
129126
}]
130127
>,
131128
InterfaceMethod<[{
132-
Return the single upper bound value or attribute if it exists, otherwise
133-
return std::nullopt.
129+
Return all upper bounds.
134130
}],
135-
/*retTy=*/"::std::optional<::mlir::OpFoldResult>",
136-
/*methodName=*/"getSingleUpperBound",
131+
/*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
132+
/*methodName=*/"getMixedUpperBound",
137133
/*args=*/(ins),
138134
/*methodBody=*/"",
139135
/*defaultImplementation=*/[{
140-
return std::nullopt;
136+
return {};
141137
}]
142138
>,
143139
InterfaceMethod<[{
@@ -235,6 +231,35 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
235231
}];
236232

237233
let extraSharedClassDeclaration = [{
234+
/// If there is a single induction variable return it, otherwise return
235+
/// std::nullopt.
236+
::std::optional<::mlir::Value> getSingleInductionVar() {
237+
if (this->getInductionVars().size() == 1)
238+
return this->getInductionVars()[0];
239+
return std::nullopt;
240+
}
241+
/// Return the single lower bound value or attribute if it exists, otherwise
242+
/// return std::nullopt.
243+
::std::optional<::mlir::OpFoldResult> getSingleLowerBound() {
244+
if (this->getMixedLowerBound().size() == 1)
245+
return this->getMixedLowerBound()[0];
246+
return std::nullopt;
247+
}
248+
/// Return the single step value or attribute if it exists, otherwise
249+
/// return std::nullopt.
250+
::std::optional<::mlir::OpFoldResult> getSingleStep() {
251+
if (this->getMixedStep().size() == 1)
252+
return this->getMixedStep()[0];
253+
return std::nullopt;
254+
}
255+
/// Return the single upper bound value or attribute if it exists, otherwise
256+
/// return std::nullopt.
257+
::std::optional<::mlir::OpFoldResult> getSingleUpperBound() {
258+
if (this->getMixedUpperBound().size() == 1)
259+
return this->getMixedUpperBound()[0];
260+
return std::nullopt;
261+
}
262+
238263
/// Append the specified additional "init" operands: replace this loop with
239264
/// a new loop that has the additional init operands. The loop body of this
240265
/// loop is moved over to the new loop.

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

+9-11
Original file line numberDiff line numberDiff line change
@@ -2454,27 +2454,25 @@ bool AffineForOp::matchingBoundOperandList() {
24542454

24552455
SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
24562456

2457-
std::optional<Value> AffineForOp::getSingleInductionVar() {
2458-
return getInductionVar();
2459-
}
2457+
ValueRange AffineForOp::getInductionVars() { return {getInductionVar()}; }
24602458

2461-
std::optional<OpFoldResult> AffineForOp::getSingleLowerBound() {
2459+
SmallVector<OpFoldResult> AffineForOp::getMixedLowerBound() {
24622460
if (!hasConstantLowerBound())
2463-
return std::nullopt;
2461+
return {};
24642462
OpBuilder b(getContext());
2465-
return OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()));
2463+
return {OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
24662464
}
24672465

2468-
std::optional<OpFoldResult> AffineForOp::getSingleStep() {
2466+
SmallVector<OpFoldResult> AffineForOp::getMixedStep() {
24692467
OpBuilder b(getContext());
2470-
return OpFoldResult(b.getI64IntegerAttr(getStepAsInt()));
2468+
return {OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
24712469
}
24722470

2473-
std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
2471+
SmallVector<OpFoldResult> AffineForOp::getMixedUpperBound() {
24742472
if (!hasConstantUpperBound())
2475-
return std::nullopt;
2473+
return {};
24762474
OpBuilder b(getContext());
2477-
return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
2475+
return {OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
24782476
}
24792477

24802478
FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(

mlir/lib/Dialect/SCF/IR/SCF.cpp

+27-43
Original file line numberDiff line numberDiff line change
@@ -378,20 +378,18 @@ LogicalResult ForOp::verifyRegions() {
378378
return success();
379379
}
380380

381-
std::optional<Value> ForOp::getSingleInductionVar() {
382-
return getInductionVar();
383-
}
381+
ValueRange ForOp::getInductionVars() { return {getInductionVar()}; }
384382

385-
std::optional<OpFoldResult> ForOp::getSingleLowerBound() {
386-
return OpFoldResult(getLowerBound());
383+
SmallVector<OpFoldResult> ForOp::getMixedLowerBound() {
384+
return {OpFoldResult(getLowerBound())};
387385
}
388386

389-
std::optional<OpFoldResult> ForOp::getSingleStep() {
390-
return OpFoldResult(getStep());
387+
SmallVector<OpFoldResult> ForOp::getMixedStep() {
388+
return {OpFoldResult(getStep())};
391389
}
392390

393-
std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
394-
return OpFoldResult(getUpperBound());
391+
SmallVector<OpFoldResult> ForOp::getMixedUpperBound() {
392+
return {OpFoldResult(getUpperBound())};
395393
}
396394

397395
std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
@@ -1428,28 +1426,26 @@ SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
14281426
return storeOps;
14291427
}
14301428

1431-
std::optional<Value> ForallOp::getSingleInductionVar() {
1432-
if (getRank() != 1)
1433-
return std::nullopt;
1434-
return getInductionVar(0);
1429+
ValueRange ForallOp::getInductionVars() {
1430+
return getBody()->getArguments().take_front(getRank());
14351431
}
14361432

1437-
std::optional<OpFoldResult> ForallOp::getSingleLowerBound() {
1438-
if (getRank() != 1)
1439-
return std::nullopt;
1440-
return getMixedLowerBound()[0];
1433+
// Get lower bounds as OpFoldResult.
1434+
SmallVector<OpFoldResult> ForallOp::getMixedLowerBound() {
1435+
Builder b(getOperation()->getContext());
1436+
return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
14411437
}
14421438

1443-
std::optional<OpFoldResult> ForallOp::getSingleUpperBound() {
1444-
if (getRank() != 1)
1445-
return std::nullopt;
1446-
return getMixedUpperBound()[0];
1439+
// Get upper bounds as OpFoldResult.
1440+
SmallVector<OpFoldResult> ForallOp::getMixedUpperBound() {
1441+
Builder b(getOperation()->getContext());
1442+
return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
14471443
}
14481444

1449-
std::optional<OpFoldResult> ForallOp::getSingleStep() {
1450-
if (getRank() != 1)
1451-
return std::nullopt;
1452-
return getMixedStep()[0];
1445+
// Get steps as OpFoldResult.
1446+
SmallVector<OpFoldResult> ForallOp::getMixedStep() {
1447+
Builder b(getOperation()->getContext());
1448+
return getMixedValues(getStaticStep(), getDynamicStep(), b);
14531449
}
14541450

14551451
ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
@@ -3008,29 +3004,17 @@ void ParallelOp::print(OpAsmPrinter &p) {
30083004

30093005
SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
30103006

3011-
std::optional<Value> ParallelOp::getSingleInductionVar() {
3012-
if (getNumLoops() != 1)
3013-
return std::nullopt;
3014-
return getBody()->getArgument(0);
3015-
}
3007+
ValueRange ParallelOp::getInductionVars() { return getBody()->getArguments(); }
30163008

3017-
std::optional<OpFoldResult> ParallelOp::getSingleLowerBound() {
3018-
if (getNumLoops() != 1)
3019-
return std::nullopt;
3020-
return getLowerBound()[0];
3009+
SmallVector<OpFoldResult> ParallelOp::getMixedLowerBound() {
3010+
return getLowerBound();
30213011
}
30223012

3023-
std::optional<OpFoldResult> ParallelOp::getSingleUpperBound() {
3024-
if (getNumLoops() != 1)
3025-
return std::nullopt;
3026-
return getUpperBound()[0];
3013+
SmallVector<OpFoldResult> ParallelOp::getMixedUpperBound() {
3014+
return getUpperBound();
30273015
}
30283016

3029-
std::optional<OpFoldResult> ParallelOp::getSingleStep() {
3030-
if (getNumLoops() != 1)
3031-
return std::nullopt;
3032-
return getStep()[0];
3033-
}
3017+
SmallVector<OpFoldResult> ParallelOp::getMixedStep() { return getStep(); }
30343018

30353019
ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
30363020
auto ivArg = llvm::dyn_cast<BlockArgument>(val);

mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ class SCFLoopLikeTest : public ::testing::Test {
3636
std::optional<OpFoldResult> maybeIndVar =
3737
loopLikeOp.getSingleInductionVar();
3838
EXPECT_TRUE(maybeIndVar.has_value());
39+
EXPECT_EQ(loopLikeOp.getInductionVars().size(), 1u);
40+
EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 1u);
41+
EXPECT_EQ(loopLikeOp.getMixedStep().size(), 1u);
42+
EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 1u);
3943
}
4044

4145
void checkMultidimensional(LoopLikeOpInterface loopLikeOp) {
@@ -48,6 +52,10 @@ class SCFLoopLikeTest : public ::testing::Test {
4852
std::optional<OpFoldResult> maybeIndVar =
4953
loopLikeOp.getSingleInductionVar();
5054
EXPECT_FALSE(maybeIndVar.has_value());
55+
EXPECT_EQ(loopLikeOp.getInductionVars().size(), 2u);
56+
EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 2u);
57+
EXPECT_EQ(loopLikeOp.getMixedStep().size(), 2u);
58+
EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 2u);
5159
}
5260

5361
MLIRContext context;

0 commit comments

Comments
 (0)