[mlir][sparse] unify lib/codegen rewriting rules for sparse tensor concatenation. #68057

Merged 2 commits on Oct 3, 2023
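For context, below is a minimal sketch of sparse_tensor.concatenate, the operation whose library ("lib") and codegen lowering paths this PR unifies. The encoding, shapes, and function name are illustrative assumptions, not taken from this patch; the syntax follows the sparse_tensor dialect documentation.

#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>

// Concatenate two CSR matrices along dimension 0; the result size along that
// dimension is the sum of the input sizes (3 + 5 = 8).
func.func @concat_sparse(%a: tensor<3x4xf64, #CSR>, %b: tensor<5x4xf64, #CSR>)
    -> tensor<8x4xf64, #CSR> {
  %0 = sparse_tensor.concatenate %a, %b {dimension = 0 : index}
       : tensor<3x4xf64, #CSR>, tensor<5x4xf64, #CSR> to tensor<8x4xf64, #CSR>
  return %0 : tensor<8x4xf64, #CSR>
}

With this patch, both the runtime-library path and the direct-codegen path lower this op through the same ConcatenateRewriter pattern, now registered unconditionally in populatePostSparsificationRewriting, so the conversion pass no longer carries its own SparseTensorConcatConverter.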
8 changes: 5 additions & 3 deletions mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -783,8 +783,11 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
for (uint64_t l = 0; l < lvlRank; ++l) {
const auto crd = lvlCoords[l];
const auto cur = lvlCursor[l];
if (crd > cur || (crd == cur && !isUniqueLvl(l)))
if (crd > cur || (crd == cur && !isUniqueLvl(l)) ||
(crd < cur && !isOrderedLvl(l))) {
return l;
}

if (crd < cur) {
assert(false && "non-lexicographic insertion");
return -1u;
@@ -900,8 +903,7 @@ class SparseTensorEnumeratorBase {

//===----------------------------------------------------------------------===//
template <typename P, typename C, typename V>
class SparseTensorEnumerator final
: public SparseTensorEnumeratorBase<V> {
class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
using Base = SparseTensorEnumeratorBase<V>;
using StorageImpl = SparseTensorStorage<P, C, V>;

316 changes: 10 additions & 306 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -162,38 +162,6 @@ static SmallVector<Value> getDimShape(OpBuilder &builder, Location loc,
return out;
}

/// Populates the given sizes array for concatenation from type (for static
/// sizes) and from an already-converted opaque pointer source (for dynamic
/// sizes).
static void concatDimSizesFromInputs(OpBuilder &builder, Location loc,
SparseTensorType dstTp, ValueRange srcs,
Dimension dim,
SmallVectorImpl<Value> &dimSizes) {
assert(dim < dstTp.getDimRank() && "Dimension is out of bounds");
dimSizes.clear();

// We first fill the sizes from an input tensor, and then
// compute the size of the concatenation dimension if necessary.
const auto srcTp = getSparseTensorType(srcs[0]);
if (srcTp.hasEncoding())
// Reusing sizes from an arbitrary input tensor is fine.
fillDimSizes(builder, loc, srcTp, srcs[0], dimSizes);
else
sizesFromSrc(builder, dimSizes, loc, srcs[0]);

if (const auto sz = dstTp.getStaticDimSize(dim)) {
// Faithfully take the static size.
dimSizes[dim] = constantIndex(builder, loc, *sz);
} else {
// Else, dynamically compute the size.
for (const auto src : srcs.drop_front()) {
const auto srcTp = getSparseTensorType(src);
Value srcSz = createOrFoldDimCall(builder, loc, srcTp, src, dim);
dimSizes[dim] = builder.create<arith::AddIOp>(loc, dimSizes[dim], srcSz);
}
}
}

/// Generates an uninitialized buffer of the given size and type,
/// but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
@@ -467,107 +435,6 @@ static bool canUseDirectConversion(ArrayRef<DimLevelType> dimTypes) {
return true;
}

// Generates a while loop that iterates over the COO list extracted
// from `t`, using `bodyBuilder` to build the loop body.
// while (elem = coo->getNext()) {
// bodyBuilder
// }
// TODO: This could be reused by the conversions of other operators (ReshapeOp,
// ConvertOp) to reduce code repetition.
// TODO: rename to `genSparseIterationLoop`?
static void genSparseCOOIterationLoop(
ConversionPatternRewriter &rewriter, Location loc, Value t,
SparseTensorType stt,
function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuilder) {
assert(stt.hasEncoding() &&
"Generating Sparse Tensor COO Loop on a Dense Tensor!");
const Dimension dimRank = stt.getDimRank();
const Type elemTp = stt.getElementType();

// Start an iterator over the tensor (in coordinate order).
const auto noPerm = stt.withoutDimToLvl();
SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, noPerm, t);
Value iter = NewCallParams(rewriter, loc)
.genBuffers(noPerm, dimSizes)
.genNewCall(Action::kToIterator, t);

// Construct a while loop over the iterator.
const Type iTp = rewriter.getIndexType();
Value srcDimCoords = genAlloca(rewriter, loc, dimRank, iTp);
Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
const SmallVector<Value> noArgs;
const SmallVector<Type> noTypes;
auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
rewriter.setInsertionPointToEnd(before);
Value cond = genGetNextCall(rewriter, loc, iter, srcDimCoords, elemPtr);
rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
rewriter.setInsertionPointToStart(after);

const bool hasDenseDim =
llvm::any_of(stt.getEncoding().getLvlTypes(), isDenseDLT);
if (hasDenseDim) {
Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr);
Value isZero = genIsNonzero(rewriter, loc, elemV);
scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, isZero, /*else*/ false);
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
}
// Callback here to build loop body.
bodyBuilder(rewriter, loc, srcDimCoords, elemPtr);

// Exit the scope from the IfOp.
if (hasDenseDim)
rewriter.setInsertionPointToEnd(after);

rewriter.create<scf::YieldOp>(loc);
// Finish generating loop.
rewriter.setInsertionPointAfter(whileOp);

// Free memory for iterator.
genDelIteratorCall(rewriter, loc, elemTp, iter);
}

// Generate loop that iterates over a dense tensor.
// for i1 in dim1
// ..
// for ik in dimk
// val = a[i1,..,ik]
// if val != 0
// bodyBuilder(v, [i1, ..., ik])
// TODO: This could be reused by the conversions of other operators (ReshapeOp,
// ConvertOp) to reduce code repetition.
static void genDenseTensorIterationLoop(
ConversionPatternRewriter &rewriter, Location loc, Value t,
SparseTensorType stt,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
assert(!stt.hasEncoding() &&
"Generating Dense Tensor Loop on a Sparse Tensor!");

const Dimension dimRank = stt.getDimRank();
Value zero = constantIndex(rewriter, loc, 0);
Value one = constantIndex(rewriter, loc, 1);

SmallVector<Value> lo;
SmallVector<Value> hi;
SmallVector<Value> st;

// Fill out loop iteration information.
for (Dimension d = 0; d < dimRank; d++) {
lo.push_back(zero);
hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, t, d));
st.push_back(one);
}

scf::buildLoopNest(rewriter, loc, lo, hi, st, {},
[&](OpBuilder &builder, Location loc, ValueRange ivs,
ValueRange args) -> scf::ValueVector {
// Invoke callback to build the body of the loop.
bodyBuilder(builder, loc, ivs);
return {};
});
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//
@@ -1198,168 +1065,6 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
}
};

/// Sparse conversion rule for the concatenate operator.
class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ConcatenateOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// The conversion works as follows:
// (1). When the output is sparse (not all dimensions dense), with mixed inputs:
// a_sparse = concat (b_dense, c_sparse, ....)
// =>
// coo_for_a = newSparseCOO(shapeOf(a))
// for i, j, k // dense input
// coo->add(adjustForOffset(i,j,k), b[i,j,k])
//
// for elem in sparse_input
// coo->add(adjustForOffset(elem.coords), elem.value)
// ...
// a = newSparseTensor(coo_for_a)
// return a
//
// (2). When the output is dense or annotated all-dense, with mixed inputs:
// a_dense = concat (b_dense, c_sparse, ....)
// =>
// a = malloc(shapeOf(a)) or newSparseAllDense(shapeOf(a))
// for i, j, k // dense input
// a[ adjustForOffset(i,j,k) ] = b[i,j,k]
//
// for elem in sparse_input
// a[ adjustForOffset(elem.coords) ] = elem.value
// return a
Location loc = op.getLoc();
const auto dstTp = getSparseTensorType(op);
const auto dstEnc = dstTp.getEncoding();
const Type elemTp = dstTp.getElementType();
const Dimension concatDim = op.getDimension();
const Dimension dimRank = dstTp.getDimRank();

Value dst; // destination tensor
Value dstDimToLvl; // destination tensor permutation (if sparse out)
// A pointer to the value being inserted (if dense => sparse)
Value elemPtr;
// Memory that holds the dim-coords for destination tensor (if sparse out)
Value dstDimCoords;
// The offset applied to the dimension being concatenated (starting from 0).
Value offset = constantIndex(rewriter, loc, 0);

SmallVector<Value> dimSizes;
concatDimSizesFromInputs(rewriter, loc, dstTp, op.getInputs(), concatDim,
dimSizes);

NewCallParams params(rewriter, loc);
const bool allDense = dstTp.hasEncoding() && dstTp.isAllDense();
Value dstTensor;
if (dstTp.hasEncoding()) {
// Start a new COO or an initialized annotated all dense sparse tensor.
dst = params.genBuffers(dstTp, dimSizes)
.genNewCall(allDense ? Action::kEmpty : Action::kEmptyCOO);
dstDimCoords = genAlloca(rewriter, loc, dimRank, rewriter.getIndexType());
if (allDense) {
dstTensor = dst;
// Get the values buffer for the sparse tensor and reshape it to the
// corresponding dense tensor shape.
dst = genValuesCall(rewriter, loc,
MemRefType::get({ShapedType::kDynamic}, elemTp),
{dst});
// Pass the `dstDimCoords` buffer for `reshapeValuesToLevels`
// to reuse for storing level-sizes (yes, "level-sizes").
// This is safe to do because `dstTp` is a dense-tensor type,
// and therefore lvlRank == dimRank.
dst = reshapeValuesToLevels(rewriter, loc, dstEnc, dimSizes, dst,
dstDimCoords);
} else {
dstDimToLvl = params.getDimToLvl();
elemPtr = genAllocaScalar(rewriter, loc, elemTp);
}
} else {
// TODO: Dense buffers should be allocated/deallocated via the callback
// in BufferizationOptions.
dst = allocDenseTensor(rewriter, loc, dstTp, dimSizes);
}
const Level lvlRank = dstTp.getLvlRank();
const auto dcvs2lcvs = [&](ValueRange dcvs) -> SmallVector<Value> {
SmallVector<Value> lcvs;
lcvs.reserve(lvlRank);
for (Level l = 0; l < lvlRank; l++)
// FIXME: `toOrigDim` is deprecated
lcvs.push_back(dcvs[toOrigDim(dstEnc, l)]);
return lcvs;
};
for (const auto &it : llvm::zip(op.getInputs(), adaptor.getInputs())) {
Value originalOp = std::get<0>(it); // Input (with encoding) from Op
Value adaptedOp = std::get<1>(it); // Input (type converted) from adaptor
const auto srcTp = getSparseTensorType(originalOp);
if (srcTp.hasEncoding()) {
genSparseCOOIterationLoop(
rewriter, loc, adaptedOp, srcTp,
[&](OpBuilder &builder, Location loc, Value dimCoords,
Value elemPtr) -> void {
const auto dcvs =
loadAll(builder, loc, dimRank, dimCoords, concatDim, offset);
if (dstTp.hasEncoding() && !allDense) {
// Case: sparse => sparse, except for annotated all dense.
storeAll(builder, loc, dstDimCoords, dcvs);
genAddEltCall(builder, loc, elemTp, dst, elemPtr, dstDimCoords,
dstDimToLvl);
} else {
// Case: sparse => dense, or annotated all dense.
const auto lcvs = allDense ? dcvs2lcvs(dcvs) : dcvs;
insertScalarIntoDenseTensor(builder, loc, elemPtr, dst, lcvs);
}
});
} else {
genDenseTensorIterationLoop(
rewriter, loc, adaptedOp, srcTp,
[&](OpBuilder &builder, Location loc, ValueRange dcvs) -> void {
if (dstTp.hasEncoding() && !allDense) {
// Case: dense => sparse, except for annotated all dense.
assert(dcvs.size() == static_cast<size_t>(dimRank));
storeAll(builder, loc, dstDimCoords, dcvs, concatDim, offset);
Value val = genValueForDense(builder, loc, adaptedOp, dcvs);
builder.create<memref::StoreOp>(loc, val, elemPtr);
genAddEltCall(builder, loc, elemTp, dst, elemPtr, dstDimCoords,
dstDimToLvl);
} else {
// Case: dense => dense, or annotated all dense.
Value val = genValueForDense(builder, loc, adaptedOp, dcvs);
// Despite the name, this isn't actually level-cvs until
// after the `dcvs2lcvs` call.
SmallVector<Value> lcvs(dcvs);
// Apply offset.
lcvs[concatDim] =
builder.create<arith::AddIOp>(loc, lcvs[concatDim], offset);
if (allDense)
lcvs = dcvs2lcvs(lcvs);
builder.create<memref::StoreOp>(loc, val, dst, lcvs);
}
});
}
// Accumulate offset.
// TODO: avoid calling sparseDimSize multiple times by caching the result!
Value curDim =
createOrFoldDimCall(rewriter, loc, srcTp, adaptedOp, concatDim);
offset = rewriter.create<arith::AddIOp>(loc, offset, curDim);
}
if (!dstTp.hasEncoding()) {
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(
op, dstTp.getRankedTensorType(), dst);
} else if (allDense) {
rewriter.replaceOp(op, dstTensor);
} else {
// In the sparse output case, the destination holds the COO.
Value coo = dst;
dst = params.genNewCall(Action::kFromCOO, coo);
// Release resources.
genDelCOOCall(rewriter, loc, elemTp, coo);
rewriter.replaceOp(op, dst);
}
return success();
}
};

/// Sparse conversion rule for the output operator.
class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
public:
@@ -1434,17 +1139,16 @@ mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
void mlir::populateSparseTensorConversionPatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns,
const SparseTensorConversionOptions &options) {
patterns
.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
SparseCastConverter, SparseTensorNewConverter,
SparseTensorConcatConverter, SparseTensorAllocConverter,
SparseTensorEmptyConverter, SparseTensorDeallocConverter,
SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
SparseTensorLoadConverter, SparseTensorInsertConverter,
SparseTensorExpandConverter, SparseTensorCompressConverter,
SparseTensorOutConverter, SparseTensorAssembleConverter>(
typeConverter, patterns.getContext());
patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
SparseCastConverter, SparseTensorNewConverter,
SparseTensorAllocConverter, SparseTensorEmptyConverter,
SparseTensorDeallocConverter, SparseTensorToPositionsConverter,
SparseTensorToCoordinatesConverter,
SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
SparseTensorLoadConverter, SparseTensorInsertConverter,
SparseTensorExpandConverter, SparseTensorCompressConverter,
SparseTensorOutConverter, SparseTensorAssembleConverter>(
typeConverter, patterns.getContext());
patterns.add<SparseTensorConvertConverter>(typeConverter,
patterns.getContext(), options);
}
@@ -1474,17 +1474,17 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
bool enableRT,
bool enableForeach,
bool enableConvert) {
patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
ReshapeRewriter<tensor::CollapseShapeOp>,
Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
TensorReshapeRewriter>(patterns.getContext());
if (enableForeach)
patterns.add<ForeachRewriter>(patterns.getContext());

// TODO: If RT not enabled, rewrite concatenate ops, etc here.
if (!enableRT) {
patterns.add<ConcatenateRewriter, NewRewriter, OutRewriter>(
patterns.getContext());
patterns.add<NewRewriter, OutRewriter>(patterns.getContext());
if (enableConvert)
patterns.add<ConvertRewriter>(patterns.getContext());
}