Skip to content

Commit d6cc35f

Browse files
author
Peiming Liu
authored
Reapply "[mlir][sparse] implement lowering rules for IterateOp." (#95836)
1 parent 5b04b6f commit d6cc35f

File tree

4 files changed

+224
-17
lines changed

4 files changed

+224
-17
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,20 @@ convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
3434
return success();
3535
}
3636

37+
static std::optional<LogicalResult>
38+
convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
39+
// The actually Iterator Values (that are updated every iteration).
40+
auto idxTp = IndexType::get(itTp.getContext());
41+
// TODO: handle batch dimension.
42+
assert(itTp.getEncoding().getBatchLvlRank() == 0);
43+
if (!itTp.isUnique()) {
44+
// Segment high for non-unique iterator.
45+
fields.push_back(idxTp);
46+
}
47+
fields.push_back(idxTp);
48+
return success();
49+
}
50+
3751
namespace {
3852

3953
/// Sparse codegen rule for number of entries operator.
@@ -57,10 +71,114 @@ class ExtractIterSpaceConverter
5771
}
5872
};
5973

74+
class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
75+
public:
76+
using OneToNOpConversionPattern::OneToNOpConversionPattern;
77+
LogicalResult
78+
matchAndRewrite(IterateOp op, OpAdaptor adaptor,
79+
OneToNPatternRewriter &rewriter) const override {
80+
if (!op.getCrdUsedLvls().empty())
81+
return rewriter.notifyMatchFailure(
82+
op, "non-empty coordinates list not implemented.");
83+
84+
Location loc = op.getLoc();
85+
86+
auto iterSpace = SparseIterationSpace::fromValues(
87+
op.getIterSpace().getType(), adaptor.getIterSpace(), 0);
88+
89+
std::unique_ptr<SparseIterator> it =
90+
iterSpace.extractIterator(rewriter, loc);
91+
92+
if (it->iteratableByFor()) {
93+
auto [lo, hi] = it->genForCond(rewriter, loc);
94+
Value step = constantIndex(rewriter, loc, 1);
95+
SmallVector<Value> ivs;
96+
for (ValueRange inits : adaptor.getInitArgs())
97+
llvm::append_range(ivs, inits);
98+
scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs);
99+
100+
Block *loopBody = op.getBody();
101+
OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
102+
if (failed(typeConverter->convertSignatureArgs(
103+
loopBody->getArgumentTypes(), bodyTypeMapping)))
104+
return failure();
105+
rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
106+
107+
rewriter.eraseBlock(forOp.getBody());
108+
Region &dstRegion = forOp.getRegion();
109+
rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
110+
111+
auto yieldOp =
112+
llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator());
113+
114+
rewriter.setInsertionPointToEnd(forOp.getBody());
115+
// replace sparse_tensor.yield with scf.yield.
116+
rewriter.create<scf::YieldOp>(loc, yieldOp.getResults());
117+
rewriter.eraseOp(yieldOp);
118+
119+
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
120+
rewriter.replaceOp(op, forOp.getResults(), resultMapping);
121+
} else {
122+
SmallVector<Value> ivs;
123+
llvm::append_range(ivs, it->getCursor());
124+
for (ValueRange inits : adaptor.getInitArgs())
125+
llvm::append_range(ivs, inits);
126+
127+
assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
128+
129+
TypeRange types = ValueRange(ivs).getTypes();
130+
auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
131+
SmallVector<Location> l(types.size(), op.getIterator().getLoc());
132+
133+
// Generates loop conditions.
134+
Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
135+
rewriter.setInsertionPointToStart(before);
136+
ValueRange bArgs = before->getArguments();
137+
auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
138+
assert(remArgs.size() == adaptor.getInitArgs().size());
139+
rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
140+
141+
// Generates loop body.
142+
Block *loopBody = op.getBody();
143+
OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
144+
if (failed(typeConverter->convertSignatureArgs(
145+
loopBody->getArgumentTypes(), bodyTypeMapping)))
146+
return failure();
147+
rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
148+
149+
Region &dstRegion = whileOp.getAfter();
150+
// TODO: handle uses of coordinate!
151+
rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
152+
ValueRange aArgs = whileOp.getAfterArguments();
153+
auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
154+
whileOp.getAfterBody()->getTerminator());
155+
156+
rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
157+
158+
aArgs = it->linkNewScope(aArgs);
159+
ValueRange nx = it->forward(rewriter, loc);
160+
SmallVector<Value> yields;
161+
llvm::append_range(yields, nx);
162+
llvm::append_range(yields, yieldOp.getResults());
163+
164+
// replace sparse_tensor.yield with scf.yield.
165+
rewriter.eraseOp(yieldOp);
166+
rewriter.create<scf::YieldOp>(loc, yields);
167+
168+
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
169+
rewriter.replaceOp(
170+
op, whileOp.getResults().drop_front(it->getCursor().size()),
171+
resultMapping);
172+
}
173+
return success();
174+
}
175+
};
176+
60177
} // namespace
61178

62179
mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
63180
addConversion([](Type type) { return type; });
181+
addConversion(convertIteratorType);
64182
addConversion(convertIterSpaceType);
65183

66184
addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
@@ -74,5 +192,6 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
74192

75193
void mlir::populateLowerSparseIterationToSCFPatterns(
76194
TypeConverter &converter, RewritePatternSet &patterns) {
77-
patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext());
195+
patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>(
196+
converter, patterns.getContext());
78197
}

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,13 @@ class TrivialIterator : public ConcreteIterator {
331331
TrivialIterator(const SparseTensorLevel &stl)
332332
: ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
333333

334+
TrivialIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
335+
Value posLo, Value posHi)
336+
: ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1), posLo(posLo),
337+
posHi(posHi) {
338+
seek(posLo);
339+
}
340+
334341
std::string getDebugInterfacePrefix() const override {
335342
return std::string("trivial<") + stl.toString() + ">";
336343
}
@@ -420,6 +427,14 @@ class DedupIterator : public ConcreteIterator {
420427
: ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) {
421428
assert(!stl.isUnique());
422429
}
430+
431+
DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
432+
Value posLo, Value posHi)
433+
: ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2), posHi(posHi) {
434+
assert(!stl.isUnique());
435+
seek({posLo, genSegmentHigh(b, l, posLo)});
436+
}
437+
423438
// For LLVM-style RTTI.
424439
static bool classof(const SparseIterator *from) {
425440
return from->kind == IterKind::kDedup;
@@ -1532,6 +1547,11 @@ SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
15321547
return space;
15331548
}
15341549

1550+
std::unique_ptr<SparseIterator>
1551+
SparseIterationSpace::extractIterator(OpBuilder &b, Location l) const {
1552+
return makeSimpleIterator(b, l, *this);
1553+
}
1554+
15351555
//===----------------------------------------------------------------------===//
15361556
// SparseIterator factory functions.
15371557
//===----------------------------------------------------------------------===//
@@ -1590,6 +1610,26 @@ sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
15901610
return std::make_pair(std::move(stl), std::move(it));
15911611
}
15921612

1613+
std::unique_ptr<SparseIterator>
1614+
sparse_tensor::makeSimpleIterator(OpBuilder &b, Location l,
1615+
const SparseIterationSpace &iterSpace) {
1616+
// assert(iterSpace.getSpaceDim() == 1);
1617+
std::unique_ptr<SparseIterator> ret;
1618+
if (!iterSpace.isUnique()) {
1619+
// We always dedupliate the non-unique level, but we should optimize it away
1620+
// if possible.
1621+
ret = std::make_unique<DedupIterator>(b, l, iterSpace.getLastLvl(),
1622+
iterSpace.getBoundLo(),
1623+
iterSpace.getBoundHi());
1624+
} else {
1625+
ret = std::make_unique<TrivialIterator>(b, l, iterSpace.getLastLvl(),
1626+
iterSpace.getBoundLo(),
1627+
iterSpace.getBoundHi());
1628+
}
1629+
ret->setSparseEmitStrategy(SparseEmitStrategy::kFunctional);
1630+
return ret;
1631+
}
1632+
15931633
std::unique_ptr<SparseIterator>
15941634
sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
15951635
SparseEmitStrategy strategy) {

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,10 @@ class SparseIterationSpace {
132132
Value getBoundLo() const { return bound.first; }
133133
Value getBoundHi() const { return bound.second; }
134134

135+
// Extract an iterator to iterate over the sparse iteration space.
136+
std::unique_ptr<SparseIterator> extractIterator(OpBuilder &b,
137+
Location l) const;
138+
135139
private:
136140
SmallVector<std::unique_ptr<SparseTensorLevel>> lvls;
137141
std::pair<Value, Value> bound;
@@ -192,6 +196,13 @@ class SparseIterator {
192196
crd = nullptr;
193197
}
194198

199+
// Reconstructs a iteration space directly from the provided ValueRange.
200+
static std::unique_ptr<SparseIterator>
201+
fromValues(IteratorType dstTp, ValueRange values, unsigned tid);
202+
203+
// The inverse operation of `fromValues`.
204+
SmallVector<Value> toValues() const { llvm_unreachable("Not implemented"); }
205+
195206
//
196207
// Iterator properties.
197208
//
@@ -345,12 +356,21 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
345356
unsigned tid,
346357
Level lvl);
347358

348-
/// Helper function to create a TensorLevel object from given `tensor`.
359+
/// Helper function to create a TensorLevel object from given ValueRange.
349360
std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
350361
ValueRange buffers,
351362
unsigned tid, Level l);
352-
/// Helper function to create a simple SparseIterator object that iterates
353-
/// over the SparseTensorLevel.
363+
364+
/// Helper function to create a simple SparseIterator object that iterate
365+
/// over the entire iteration space.
366+
std::unique_ptr<SparseIterator>
367+
makeSimpleIterator(OpBuilder &b, Location l,
368+
const SparseIterationSpace &iterSpace);
369+
370+
/// Helper function to create a simple SparseIterator object that iterate
371+
/// over the sparse tensor level.
372+
/// TODO: switch to `SparseIterationSpace` (which support N-D iterator) when
373+
/// feature complete.
354374
std::unique_ptr<SparseIterator> makeSimpleIterator(
355375
const SparseTensorLevel &stl,
356376
SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional);
Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt %s --lower-sparse-iteration-to-scf | FileCheck %s
2+
// RUN: mlir-opt %s --sparse-space-collapse --lower-sparse-iteration-to-scf | FileCheck %s --check-prefix COLLAPSED
23

34
#COO = #sparse_tensor.encoding<{
45
map = (i, j) -> (
@@ -7,17 +8,44 @@
78
)
89
}>
910

10-
// CHECK-LABEL: func.func @sparse_1D_space(
11-
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>) -> !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> {
12-
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
13-
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
14-
// CHECK-DAG: %[[LVL_SIZE:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[C0]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
15-
// CHECK: %[[POS_MEM:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
16-
// CHECK: %[[CRD_MEM:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
17-
// CHECK: %[[POS_LO:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C0]]] : memref<?xindex>
18-
// CHECK: %[[POS_HI:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C1]]] : memref<?xindex>
19-
// CHECK: %[[ITER_SPACE:.*]] = builtin.unrealized_conversion_cast %[[POS_MEM]], %[[CRD_MEM]], %[[LVL_SIZE]], %[[POS_LO]], %[[POS_HI]]
20-
func.func @sparse_1D_space(%sp : tensor<?x?xf32, #COO>) -> !sparse_tensor.iter_space<#COO, lvls = 0> {
21-
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
22-
return %l1 : !sparse_tensor.iter_space<#COO, lvls = 0>
11+
// CHECK-LABEL: @sparse_iteration_to_scf
12+
// // deduplication
13+
// CHECK: scf.while {{.*}} {
14+
// CHECK: } do {
15+
// CHECK: }
16+
// CHECK: scf.while {{.*}} {
17+
// CHECK: } do {
18+
// // actual computation
19+
// CHECK: scf.for {{.*}} {
20+
// CHECK: arith.addi
21+
// CHECK: }
22+
// // deduplication
23+
// CHECK: scf.while {{.*}} {
24+
// CHECK: } do {
25+
// CHECK: }
26+
// CHECK: scf.yield
27+
// CHECK: }
28+
// CHECK: return
29+
30+
// COLLAPSED-LABEL: @sparse_iteration_to_scf
31+
// COLLAPSED: %[[RET:.*]] = scf.for {{.*}} {
32+
// COLLAPSED: %[[VAL:.*]] = arith.addi
33+
// COLLAPSED: scf.yield %[[VAL]] : index
34+
// COLLAPSED: }
35+
// COLLAPSED: return %[[RET]] : index
36+
func.func @sparse_iteration_to_scf(%sp : tensor<4x8xf32, #COO>) -> index {
37+
%i = arith.constant 0 : index
38+
%c1 = arith.constant 1 : index
39+
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
40+
: tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
41+
%r1 = sparse_tensor.iterate %it1 in %l1 iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
42+
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
43+
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1> -> !sparse_tensor.iter_space<#COO, lvls = 1>
44+
%r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
45+
%k = arith.addi %inner, %c1 : index
46+
sparse_tensor.yield %k : index
47+
}
48+
sparse_tensor.yield %r2 : index
49+
}
50+
return %r1 : index
2351
}

0 commit comments

Comments
 (0)