@@ -34,6 +34,20 @@ convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
34
34
return success ();
35
35
}
36
36
37
+ static std::optional<LogicalResult>
38
+ convertIteratorType (IteratorType itTp, SmallVectorImpl<Type> &fields) {
39
+ // The actually Iterator Values (that are updated every iteration).
40
+ auto idxTp = IndexType::get (itTp.getContext ());
41
+ // TODO: handle batch dimension.
42
+ assert (itTp.getEncoding ().getBatchLvlRank () == 0 );
43
+ if (!itTp.isUnique ()) {
44
+ // Segment high for non-unique iterator.
45
+ fields.push_back (idxTp);
46
+ }
47
+ fields.push_back (idxTp);
48
+ return success ();
49
+ }
50
+
37
51
namespace {
38
52
39
53
// / Sparse codegen rule for number of entries operator.
@@ -57,10 +71,114 @@ class ExtractIterSpaceConverter
57
71
}
58
72
};
59
73
74
+ class SparseIterateOpConverter : public OneToNOpConversionPattern <IterateOp> {
75
+ public:
76
+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
77
+ LogicalResult
78
+ matchAndRewrite (IterateOp op, OpAdaptor adaptor,
79
+ OneToNPatternRewriter &rewriter) const override {
80
+ if (!op.getCrdUsedLvls ().empty ())
81
+ return rewriter.notifyMatchFailure (
82
+ op, " non-empty coordinates list not implemented." );
83
+
84
+ Location loc = op.getLoc ();
85
+
86
+ auto iterSpace = SparseIterationSpace::fromValues (
87
+ op.getIterSpace ().getType (), adaptor.getIterSpace (), 0 );
88
+
89
+ std::unique_ptr<SparseIterator> it =
90
+ iterSpace.extractIterator (rewriter, loc);
91
+
92
+ if (it->iteratableByFor ()) {
93
+ auto [lo, hi] = it->genForCond (rewriter, loc);
94
+ Value step = constantIndex (rewriter, loc, 1 );
95
+ SmallVector<Value> ivs;
96
+ for (ValueRange inits : adaptor.getInitArgs ())
97
+ llvm::append_range (ivs, inits);
98
+ scf::ForOp forOp = rewriter.create <scf::ForOp>(loc, lo, hi, step, ivs);
99
+
100
+ Block *loopBody = op.getBody ();
101
+ OneToNTypeMapping bodyTypeMapping (loopBody->getArgumentTypes ());
102
+ if (failed (typeConverter->convertSignatureArgs (
103
+ loopBody->getArgumentTypes (), bodyTypeMapping)))
104
+ return failure ();
105
+ rewriter.applySignatureConversion (loopBody, bodyTypeMapping);
106
+
107
+ rewriter.eraseBlock (forOp.getBody ());
108
+ Region &dstRegion = forOp.getRegion ();
109
+ rewriter.inlineRegionBefore (op.getRegion (), dstRegion, dstRegion.end ());
110
+
111
+ auto yieldOp =
112
+ llvm::cast<sparse_tensor::YieldOp>(forOp.getBody ()->getTerminator ());
113
+
114
+ rewriter.setInsertionPointToEnd (forOp.getBody ());
115
+ // replace sparse_tensor.yield with scf.yield.
116
+ rewriter.create <scf::YieldOp>(loc, yieldOp.getResults ());
117
+ rewriter.eraseOp (yieldOp);
118
+
119
+ const OneToNTypeMapping &resultMapping = adaptor.getResultMapping ();
120
+ rewriter.replaceOp (op, forOp.getResults (), resultMapping);
121
+ } else {
122
+ SmallVector<Value> ivs;
123
+ llvm::append_range (ivs, it->getCursor ());
124
+ for (ValueRange inits : adaptor.getInitArgs ())
125
+ llvm::append_range (ivs, inits);
126
+
127
+ assert (llvm::all_of (ivs, [](Value v) { return v != nullptr ; }));
128
+
129
+ TypeRange types = ValueRange (ivs).getTypes ();
130
+ auto whileOp = rewriter.create <scf::WhileOp>(loc, types, ivs);
131
+ SmallVector<Location> l (types.size (), op.getIterator ().getLoc ());
132
+
133
+ // Generates loop conditions.
134
+ Block *before = rewriter.createBlock (&whileOp.getBefore (), {}, types, l);
135
+ rewriter.setInsertionPointToStart (before);
136
+ ValueRange bArgs = before->getArguments ();
137
+ auto [whileCond, remArgs] = it->genWhileCond (rewriter, loc, bArgs);
138
+ assert (remArgs.size () == adaptor.getInitArgs ().size ());
139
+ rewriter.create <scf::ConditionOp>(loc, whileCond, before->getArguments ());
140
+
141
+ // Generates loop body.
142
+ Block *loopBody = op.getBody ();
143
+ OneToNTypeMapping bodyTypeMapping (loopBody->getArgumentTypes ());
144
+ if (failed (typeConverter->convertSignatureArgs (
145
+ loopBody->getArgumentTypes (), bodyTypeMapping)))
146
+ return failure ();
147
+ rewriter.applySignatureConversion (loopBody, bodyTypeMapping);
148
+
149
+ Region &dstRegion = whileOp.getAfter ();
150
+ // TODO: handle uses of coordinate!
151
+ rewriter.inlineRegionBefore (op.getRegion (), dstRegion, dstRegion.end ());
152
+ ValueRange aArgs = whileOp.getAfterArguments ();
153
+ auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
154
+ whileOp.getAfterBody ()->getTerminator ());
155
+
156
+ rewriter.setInsertionPointToEnd (whileOp.getAfterBody ());
157
+
158
+ aArgs = it->linkNewScope (aArgs);
159
+ ValueRange nx = it->forward (rewriter, loc);
160
+ SmallVector<Value> yields;
161
+ llvm::append_range (yields, nx);
162
+ llvm::append_range (yields, yieldOp.getResults ());
163
+
164
+ // replace sparse_tensor.yield with scf.yield.
165
+ rewriter.eraseOp (yieldOp);
166
+ rewriter.create <scf::YieldOp>(loc, yields);
167
+
168
+ const OneToNTypeMapping &resultMapping = adaptor.getResultMapping ();
169
+ rewriter.replaceOp (
170
+ op, whileOp.getResults ().drop_front (it->getCursor ().size ()),
171
+ resultMapping);
172
+ }
173
+ return success ();
174
+ }
175
+ };
176
+
60
177
} // namespace
61
178
62
179
mlir::SparseIterationTypeConverter::SparseIterationTypeConverter () {
63
180
addConversion ([](Type type) { return type; });
181
+ addConversion (convertIteratorType);
64
182
addConversion (convertIterSpaceType);
65
183
66
184
addSourceMaterialization ([](OpBuilder &builder, IterSpaceType spTp,
@@ -74,5 +192,6 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
74
192
75
193
void mlir::populateLowerSparseIterationToSCFPatterns (
76
194
TypeConverter &converter, RewritePatternSet &patterns) {
77
- patterns.add <ExtractIterSpaceConverter>(converter, patterns.getContext ());
195
+ patterns.add <ExtractIterSpaceConverter, SparseIterateOpConverter>(
196
+ converter, patterns.getContext ());
78
197
}
0 commit comments