Skip to content

Commit c3b01b4

Browse files
authored
[mlir][sparse] unify lib/codegen rewriting rules for sparse tensor concatenation. (llvm#68057)
1 parent 09b30f4 commit c3b01b4

File tree

5 files changed

+437
-1184
lines changed

5 files changed

+437
-1184
lines changed

mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -783,8 +783,11 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
783783
for (uint64_t l = 0; l < lvlRank; ++l) {
784784
const auto crd = lvlCoords[l];
785785
const auto cur = lvlCursor[l];
786-
if (crd > cur || (crd == cur && !isUniqueLvl(l)))
786+
if (crd > cur || (crd == cur && !isUniqueLvl(l)) ||
787+
(crd < cur && !isOrderedLvl(l))) {
787788
return l;
789+
}
790+
788791
if (crd < cur) {
789792
assert(false && "non-lexicographic insertion");
790793
return -1u;
@@ -900,8 +903,7 @@ class SparseTensorEnumeratorBase {
900903

901904
//===----------------------------------------------------------------------===//
902905
template <typename P, typename C, typename V>
903-
class SparseTensorEnumerator final
904-
: public SparseTensorEnumeratorBase<V> {
906+
class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
905907
using Base = SparseTensorEnumeratorBase<V>;
906908
using StorageImpl = SparseTensorStorage<P, C, V>;
907909

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Lines changed: 10 additions & 306 deletions
Original file line numberDiff line numberDiff line change
@@ -162,38 +162,6 @@ static SmallVector<Value> getDimShape(OpBuilder &builder, Location loc,
162162
return out;
163163
}
164164

165-
/// Populates the given sizes array for concatenation from type (for static
166-
/// sizes) and from an already-converted opaque pointer source (for dynamic
167-
/// sizes).
168-
static void concatDimSizesFromInputs(OpBuilder &builder, Location loc,
169-
SparseTensorType dstTp, ValueRange srcs,
170-
Dimension dim,
171-
SmallVectorImpl<Value> &dimSizes) {
172-
assert(dim < dstTp.getDimRank() && "Dimension is out of bounds");
173-
dimSizes.clear();
174-
175-
// We first fill the sizes from an input tensor, and then
176-
// compute the size of the concatenation dimension if necessary.
177-
const auto srcTp = getSparseTensorType(srcs[0]);
178-
if (srcTp.hasEncoding())
179-
// Reusing sizes from an arbitrary input tensor is fine.
180-
fillDimSizes(builder, loc, srcTp, srcs[0], dimSizes);
181-
else
182-
sizesFromSrc(builder, dimSizes, loc, srcs[0]);
183-
184-
if (const auto sz = dstTp.getStaticDimSize(dim)) {
185-
// Faithfully take the static size.
186-
dimSizes[dim] = constantIndex(builder, loc, *sz);
187-
} else {
188-
// Else, dynamically compute the size.
189-
for (const auto src : srcs.drop_front()) {
190-
const auto srcTp = getSparseTensorType(src);
191-
Value srcSz = createOrFoldDimCall(builder, loc, srcTp, src, dim);
192-
dimSizes[dim] = builder.create<arith::AddIOp>(loc, dimSizes[dim], srcSz);
193-
}
194-
}
195-
}
196-
197165
/// Generates an uninitialized buffer of the given size and type,
198166
/// but returns it as type `memref<? x $tp>` (rather than as type
199167
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
@@ -467,107 +435,6 @@ static bool canUseDirectConversion(ArrayRef<DimLevelType> dimTypes) {
467435
return true;
468436
}
469437

470-
// Generates a while loop that iterates over the COO list extracted
471-
// from `t`, using `bodyBuilder` to build the loop body.
472-
// while (elem = coo->getNext()) {
473-
// bodyBuilder
474-
// }
475-
// TODO: It can be reused by the conversion of other operators (ReshapeOp,
476-
// ConvertOp) to reduce code repetition!
477-
// TODO: rename to `genSparseIterationLoop`?
478-
static void genSparseCOOIterationLoop(
479-
ConversionPatternRewriter &rewriter, Location loc, Value t,
480-
SparseTensorType stt,
481-
function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuilder) {
482-
assert(stt.hasEncoding() &&
483-
"Generating Sparse Tensor COO Loop on a Dense Tensor!");
484-
const Dimension dimRank = stt.getDimRank();
485-
const Type elemTp = stt.getElementType();
486-
487-
// Start an iterator over the tensor (in coordinate order).
488-
const auto noPerm = stt.withoutDimToLvl();
489-
SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, noPerm, t);
490-
Value iter = NewCallParams(rewriter, loc)
491-
.genBuffers(noPerm, dimSizes)
492-
.genNewCall(Action::kToIterator, t);
493-
494-
// Construct a while loop over the iterator.
495-
const Type iTp = rewriter.getIndexType();
496-
Value srcDimCoords = genAlloca(rewriter, loc, dimRank, iTp);
497-
Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
498-
const SmallVector<Value> noArgs;
499-
const SmallVector<Type> noTypes;
500-
auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
501-
Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
502-
rewriter.setInsertionPointToEnd(before);
503-
Value cond = genGetNextCall(rewriter, loc, iter, srcDimCoords, elemPtr);
504-
rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
505-
Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
506-
rewriter.setInsertionPointToStart(after);
507-
508-
const bool hasDenseDim =
509-
llvm::any_of(stt.getEncoding().getLvlTypes(), isDenseDLT);
510-
if (hasDenseDim) {
511-
Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr);
512-
Value isZero = genIsNonzero(rewriter, loc, elemV);
513-
scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, isZero, /*else*/ false);
514-
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
515-
}
516-
// Callback here to build loop body.
517-
bodyBuilder(rewriter, loc, srcDimCoords, elemPtr);
518-
519-
// Exit the scope from the IfOp.
520-
if (hasDenseDim)
521-
rewriter.setInsertionPointToEnd(after);
522-
523-
rewriter.create<scf::YieldOp>(loc);
524-
// Finish generating loop.
525-
rewriter.setInsertionPointAfter(whileOp);
526-
527-
// Free memory for iterator.
528-
genDelIteratorCall(rewriter, loc, elemTp, iter);
529-
}
530-
531-
// Generate loop that iterates over a dense tensor.
532-
// for i1 in dim1
533-
// ..
534-
// for ik in dimk
535-
// val = a[i1,..,ik]
536-
// if val != 0
537-
// bodyBuilder(v, [i1, ..., ik])
538-
// TODO: It can be reused by the conversion of other operators (ReshapeOp,
539-
// ConvertOp) to reduce code repetition!
540-
static void genDenseTensorIterationLoop(
541-
ConversionPatternRewriter &rewriter, Location loc, Value t,
542-
SparseTensorType stt,
543-
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
544-
assert(!stt.hasEncoding() &&
545-
"Generating Dense Tensor Loop on a Sparse Tensor!");
546-
547-
const Dimension dimRank = stt.getDimRank();
548-
Value zero = constantIndex(rewriter, loc, 0);
549-
Value one = constantIndex(rewriter, loc, 1);
550-
551-
SmallVector<Value> lo;
552-
SmallVector<Value> hi;
553-
SmallVector<Value> st;
554-
555-
// Fill out loop iteration information.
556-
for (Dimension d = 0; d < dimRank; d++) {
557-
lo.push_back(zero);
558-
hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, t, d));
559-
st.push_back(one);
560-
}
561-
562-
scf::buildLoopNest(rewriter, loc, lo, hi, st, {},
563-
[&](OpBuilder &builder, Location loc, ValueRange ivs,
564-
ValueRange args) -> scf::ValueVector {
565-
// Invoke callback to build the body of the loop.
566-
bodyBuilder(builder, loc, ivs);
567-
return {};
568-
});
569-
}
570-
571438
//===----------------------------------------------------------------------===//
572439
// Conversion rules.
573440
//===----------------------------------------------------------------------===//
@@ -1198,168 +1065,6 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
11981065
}
11991066
};
12001067

1201-
/// Sparse conversion rule for the concatenate operator.
1202-
class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
1203-
public:
1204-
using OpConversionPattern::OpConversionPattern;
1205-
LogicalResult
1206-
matchAndRewrite(ConcatenateOp op, OpAdaptor adaptor,
1207-
ConversionPatternRewriter &rewriter) const override {
1208-
// The conversion works as follow:
1209-
// (1). When output is sparse and not all dims are dense, and mix of inputs:
1210-
// a_sparse = concat (b_dense, c_sparse, ....)
1211-
// =>
1212-
// coo_for_a = newSparseCOO(shapeOf(a))
1213-
// for i, j, k // dense input
1214-
// coo->add(adjustForOffset(i,j,k), b[i,j,k])
1215-
//
1216-
// for elem in sparse_input
1217-
// coo->add(adjustForOffset(elem.coords), elem.value)
1218-
// ...
1219-
// a = newSparseTensor(coo_for_a)
1220-
// return a
1221-
//
1222-
// (2). When output is dense or annotated all dense, and mix of inputs:
1223-
// a_dense = concat (b_dense, c_sparse, ....)
1224-
// =>
1225-
// a = malloc(shapeOf(a)) or newSparseAllDense(shapeOf(a))
1226-
// for i, j, k // dense input
1227-
// a[ adjustForOffset(i,j,k) ] = b[i,j,k]
1228-
//
1229-
// for elem in sparse_input
1230-
// a[ adjustForOffset(elem.coords) ] = elem.value
1231-
// return a
1232-
Location loc = op.getLoc();
1233-
const auto dstTp = getSparseTensorType(op);
1234-
const auto dstEnc = dstTp.getEncoding();
1235-
const Type elemTp = dstTp.getElementType();
1236-
const Dimension concatDim = op.getDimension();
1237-
const Dimension dimRank = dstTp.getDimRank();
1238-
1239-
Value dst; // destination tensor
1240-
Value dstDimToLvl; // destination tensor permutation (if sparse out)
1241-
// A pointer to the value being inserted (if dense => sparse)
1242-
Value elemPtr;
1243-
// Memory that holds the dim-coords for destination tensor (if sparse out)
1244-
Value dstDimCoords;
1245-
// The offset applied to the dimension to be concated (starting from 0)
1246-
Value offset = constantIndex(rewriter, loc, 0);
1247-
1248-
SmallVector<Value> dimSizes;
1249-
concatDimSizesFromInputs(rewriter, loc, dstTp, op.getInputs(), concatDim,
1250-
dimSizes);
1251-
1252-
NewCallParams params(rewriter, loc);
1253-
const bool allDense = dstTp.hasEncoding() && dstTp.isAllDense();
1254-
Value dstTensor;
1255-
if (dstTp.hasEncoding()) {
1256-
// Start a new COO or an initialized annotated all dense sparse tensor.
1257-
dst = params.genBuffers(dstTp, dimSizes)
1258-
.genNewCall(allDense ? Action::kEmpty : Action::kEmptyCOO);
1259-
dstDimCoords = genAlloca(rewriter, loc, dimRank, rewriter.getIndexType());
1260-
if (allDense) {
1261-
dstTensor = dst;
1262-
// Get the values buffer for the sparse tensor and reshape it to the
1263-
// corresponding dense tensor shape.
1264-
dst = genValuesCall(rewriter, loc,
1265-
MemRefType::get({ShapedType::kDynamic}, elemTp),
1266-
{dst});
1267-
// Pass the `dstDimCoords` buffer for `reshapeValuesToLevels`
1268-
// to reuse for storing level-sizes (yes, "level-sizes").
1269-
// This is safe to do because `dstTp` is a dense-tensor type,
1270-
// and therefore lvlRank == dimRank.
1271-
dst = reshapeValuesToLevels(rewriter, loc, dstEnc, dimSizes, dst,
1272-
dstDimCoords);
1273-
} else {
1274-
dstDimToLvl = params.getDimToLvl();
1275-
elemPtr = genAllocaScalar(rewriter, loc, elemTp);
1276-
}
1277-
} else {
1278-
// TODO: Dense buffers should be allocated/deallocated via the callback
1279-
// in BufferizationOptions.
1280-
dst = allocDenseTensor(rewriter, loc, dstTp, dimSizes);
1281-
}
1282-
const Level lvlRank = dstTp.getLvlRank();
1283-
const auto dcvs2lcvs = [&](ValueRange dcvs) -> SmallVector<Value> {
1284-
SmallVector<Value> lcvs;
1285-
lcvs.reserve(lvlRank);
1286-
for (Level l = 0; l < lvlRank; l++)
1287-
// FIXME: `toOrigDim` is deprecated
1288-
lcvs.push_back(dcvs[toOrigDim(dstEnc, l)]);
1289-
return lcvs;
1290-
};
1291-
for (const auto &it : llvm::zip(op.getInputs(), adaptor.getInputs())) {
1292-
Value orignalOp = std::get<0>(it); // Input (with encoding) from Op
1293-
Value adaptedOp = std::get<1>(it); // Input (type converted) from adaptor
1294-
const auto srcTp = getSparseTensorType(orignalOp);
1295-
if (srcTp.hasEncoding()) {
1296-
genSparseCOOIterationLoop(
1297-
rewriter, loc, adaptedOp, srcTp,
1298-
[&](OpBuilder &builder, Location loc, Value dimCoords,
1299-
Value elemPtr) -> void {
1300-
const auto dcvs =
1301-
loadAll(builder, loc, dimRank, dimCoords, concatDim, offset);
1302-
if (dstTp.hasEncoding() && !allDense) {
1303-
// Case: sparse => sparse, except for annotated all dense.
1304-
storeAll(builder, loc, dstDimCoords, dcvs);
1305-
genAddEltCall(builder, loc, elemTp, dst, elemPtr, dstDimCoords,
1306-
dstDimToLvl);
1307-
} else {
1308-
// Case: sparse => dense, or annotated all dense.
1309-
const auto lcvs = allDense ? dcvs2lcvs(dcvs) : dcvs;
1310-
insertScalarIntoDenseTensor(builder, loc, elemPtr, dst, lcvs);
1311-
}
1312-
});
1313-
} else {
1314-
genDenseTensorIterationLoop(
1315-
rewriter, loc, adaptedOp, srcTp,
1316-
[&](OpBuilder &builder, Location loc, ValueRange dcvs) -> void {
1317-
if (dstTp.hasEncoding() && !allDense) {
1318-
// Case: dense => sparse, except for annotated all dense.
1319-
assert(dcvs.size() == static_cast<size_t>(dimRank));
1320-
storeAll(builder, loc, dstDimCoords, dcvs, concatDim, offset);
1321-
Value val = genValueForDense(builder, loc, adaptedOp, dcvs);
1322-
builder.create<memref::StoreOp>(loc, val, elemPtr);
1323-
genAddEltCall(builder, loc, elemTp, dst, elemPtr, dstDimCoords,
1324-
dstDimToLvl);
1325-
} else {
1326-
// Case: dense => dense, or annotated all dense.
1327-
Value val = genValueForDense(builder, loc, adaptedOp, dcvs);
1328-
// Despite the name, this isn't actually level-cvs until
1329-
// after the `dcvs2lcvs` call.
1330-
SmallVector<Value> lcvs(dcvs);
1331-
// Apply offset.
1332-
lcvs[concatDim] =
1333-
builder.create<arith::AddIOp>(loc, lcvs[concatDim], offset);
1334-
if (allDense)
1335-
lcvs = dcvs2lcvs(lcvs);
1336-
builder.create<memref::StoreOp>(loc, val, dst, lcvs);
1337-
}
1338-
});
1339-
}
1340-
// Accumulate offset.
1341-
// TODO: avoid calling sparseDimSize multiple times by caching the result!
1342-
Value curDim =
1343-
createOrFoldDimCall(rewriter, loc, srcTp, adaptedOp, concatDim);
1344-
offset = rewriter.create<arith::AddIOp>(loc, offset, curDim);
1345-
}
1346-
if (!dstTp.hasEncoding()) {
1347-
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(
1348-
op, dstTp.getRankedTensorType(), dst);
1349-
} else if (allDense) {
1350-
rewriter.replaceOp(op, dstTensor);
1351-
} else {
1352-
// In sparse output case, the destination holds the COO.
1353-
Value coo = dst;
1354-
dst = params.genNewCall(Action::kFromCOO, coo);
1355-
// Release resources.
1356-
genDelCOOCall(rewriter, loc, elemTp, coo);
1357-
rewriter.replaceOp(op, dst);
1358-
}
1359-
return success();
1360-
}
1361-
};
1362-
13631068
/// Sparse conversion rule for the output operator.
13641069
class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
13651070
public:
@@ -1434,17 +1139,16 @@ mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
14341139
void mlir::populateSparseTensorConversionPatterns(
14351140
TypeConverter &typeConverter, RewritePatternSet &patterns,
14361141
const SparseTensorConversionOptions &options) {
1437-
patterns
1438-
.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
1439-
SparseCastConverter, SparseTensorNewConverter,
1440-
SparseTensorConcatConverter, SparseTensorAllocConverter,
1441-
SparseTensorEmptyConverter, SparseTensorDeallocConverter,
1442-
SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
1443-
SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
1444-
SparseTensorLoadConverter, SparseTensorInsertConverter,
1445-
SparseTensorExpandConverter, SparseTensorCompressConverter,
1446-
SparseTensorOutConverter, SparseTensorAssembleConverter>(
1447-
typeConverter, patterns.getContext());
1142+
patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
1143+
SparseCastConverter, SparseTensorNewConverter,
1144+
SparseTensorAllocConverter, SparseTensorEmptyConverter,
1145+
SparseTensorDeallocConverter, SparseTensorToPositionsConverter,
1146+
SparseTensorToCoordinatesConverter,
1147+
SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
1148+
SparseTensorLoadConverter, SparseTensorInsertConverter,
1149+
SparseTensorExpandConverter, SparseTensorCompressConverter,
1150+
SparseTensorOutConverter, SparseTensorAssembleConverter>(
1151+
typeConverter, patterns.getContext());
14481152
patterns.add<SparseTensorConvertConverter>(typeConverter,
14491153
patterns.getContext(), options);
14501154
}

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,17 +1474,17 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
14741474
bool enableRT,
14751475
bool enableForeach,
14761476
bool enableConvert) {
1477-
patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
1477+
patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
14781478
ReshapeRewriter<tensor::CollapseShapeOp>,
14791479
Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
14801480
Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
14811481
TensorReshapeRewriter>(patterns.getContext());
14821482
if (enableForeach)
14831483
patterns.add<ForeachRewriter>(patterns.getContext());
1484+
14841485
// TODO: If RT not enabled, rewrite concatenate ops, etc here.
14851486
if (!enableRT) {
1486-
patterns.add<ConcatenateRewriter, NewRewriter, OutRewriter>(
1487-
patterns.getContext());
1487+
patterns.add<NewRewriter, OutRewriter>(patterns.getContext());
14881488
if (enableConvert)
14891489
patterns.add<ConvertRewriter>(patterns.getContext());
14901490
}

0 commit comments

Comments
 (0)