diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp index e8a40b1e033dd..9e9fea76416b9 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp @@ -7,11 +7,17 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" -#include "mlir/Transforms/OneToNTypeConversion.h" +#include "mlir/Transforms/DialectConversion.h" using namespace mlir; using namespace mlir::sparse_tensor; +/// Assert that the given value range contains a single value and return it. +static Value getSingleValue(ValueRange values) { + assert(values.size() == 1 && "expected single value"); + return values.front(); +} + static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl, SmallVectorImpl &fields) { // Position and coordinate buffer in the sparse structure. @@ -54,14 +60,17 @@ static ValueRange genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op, Value loopCrd, ArrayRef> iters, - ArrayRef subCases, ArrayRef userReduc) { - if (subCases.empty()) + ArrayRef newBlocks, ArrayRef oldBlocks, + ArrayRef userReduc) { + if (newBlocks.empty()) return userReduc; // The current branch that we are handling. - Region *b = subCases.front(); + Block *newBlock = newBlocks.front(); + Block *oldBlock = oldBlocks.front(); Value casePred = constantI1(rewriter, loc, true); - I64BitSet caseBits = op.getRegionDefinedSpace(b->getRegionNumber()); + I64BitSet caseBits = + op.getRegionDefinedSpace(newBlock->getParent()->getRegionNumber()); for (unsigned i : caseBits.bits()) { SparseIterator *it = iters[i].get(); Value pred = rewriter.create(loc, arith::CmpIPredicate::eq, @@ -80,16 +89,20 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op, for (unsigned idx : caseBits.bits()) llvm::append_range(blockArgs, iters[idx]->getCursor()); + // Map the old block arguments, because the dialect conversion driver does + // not immediately perform SSA value replacements. This function is still + // seeing the old uses. IRMapping mapping; - for (auto [from, to] : - llvm::zip_equal(b->front().getArguments(), blockArgs)) { + for (auto [from, to] : llvm::zip_equal(oldBlock->getArguments(), blockArgs)) { mapping.map(from, to); } // Clone the region, we can not erase the region now because the same region // might be a subcase for multiple lattice point. - rewriter.cloneRegionBefore(*b, ifOp.getThenRegion(), + rewriter.cloneRegionBefore(*newBlock->getParent(), ifOp.getThenRegion(), ifOp.getThenRegion().begin(), mapping); + // Remove the block arguments, they were already replaced via `mapping`. + ifOp.getThenRegion().front().eraseArguments(0, blockArgs.size()); // replace sparse_tensor::YieldOp -> scf::YieldOp auto spY = cast(&ifOp.getThenRegion().front().back()); @@ -101,7 +114,8 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op, // Generates remaining case recursively. rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters, - subCases.drop_front(), userReduc); + newBlocks.drop_front(), + oldBlocks.drop_front(), userReduc); if (!res.empty()) rewriter.create(loc, res); @@ -119,15 +133,13 @@ static ValueRange genLoopWithIterator( if (it->iteratableByFor()) { auto [lo, hi] = it->genForCond(rewriter, loc); Value step = constantIndex(rewriter, loc, 1); - scf::ForOp forOp = rewriter.create(loc, lo, hi, step, reduc); + scf::ForOp forOp = rewriter.create( + loc, lo, hi, step, reduc, + [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) { + // Empty builder function to ensure that no terminator is created. + }); { OpBuilder::InsertionGuard guard(rewriter); - // Erase the implicit yield operation created by ForOp when there is no - // yielding values. - if (!forOp.getBody()->empty()) - rewriter.eraseOp(&forOp.getBody()->front()); - assert(forOp.getBody()->empty()); - it->linkNewScope(forOp.getInductionVar()); rewriter.setInsertionPointToStart(forOp.getBody()); SmallVector ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(), @@ -178,46 +190,47 @@ namespace { /// Sparse codegen rule for number of entries operator. class ExtractIterSpaceConverter - : public OneToNOpConversionPattern { + : public OpConversionPattern { public: - using OneToNOpConversionPattern::OneToNOpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(ExtractIterSpaceOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); // Construct the iteration space. - SparseIterationSpace space(loc, rewriter, op.getTensor(), 0, + SparseIterationSpace space(loc, rewriter, + getSingleValue(adaptor.getTensor()), 0, op.getLvlRange(), adaptor.getParentIter()); SmallVector result = space.toValues(); - rewriter.replaceOp(op, result, resultMapping); + rewriter.replaceOpWithMultiple(op, {result}); return success(); } }; /// Sparse codegen rule for number of entries operator. -class ExtractValOpConverter : public OneToNOpConversionPattern { +class ExtractValOpConverter : public OpConversionPattern { public: - using OneToNOpConversionPattern::OneToNOpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ExtractValOp op, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(ExtractValOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value pos = adaptor.getIterator().back(); - Value valBuf = rewriter.create(loc, op.getTensor()); + Value valBuf = + rewriter.create(loc, getSingleValue(adaptor.getTensor())); rewriter.replaceOpWithNewOp(op, valBuf, pos); return success(); } }; -class SparseIterateOpConverter : public OneToNOpConversionPattern { +class SparseIterateOpConverter : public OpConversionPattern { public: - using OneToNOpConversionPattern::OneToNOpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(IterateOp op, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(IterateOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { if (!op.getCrdUsedLvls().empty()) return rewriter.notifyMatchFailure( op, "non-empty coordinates list not implemented."); @@ -235,14 +248,15 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern { llvm::append_range(ivs, inits); // Type conversion on iterate op block. - OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes()); + unsigned numOrigArgs = op.getBody()->getArgumentTypes().size(); + TypeConverter::SignatureConversion signatureConversion(numOrigArgs); if (failed(typeConverter->convertSignatureArgs( - op.getBody()->getArgumentTypes(), blockTypeMapping))) + op.getBody()->getArgumentTypes(), signatureConversion))) return rewriter.notifyMatchFailure( op, "failed to convert iterate region argurment types"); - rewriter.applySignatureConversion(op.getBody(), blockTypeMapping); - Block *block = op.getBody(); + Block *block = rewriter.applySignatureConversion( + op.getBody(), signatureConversion, getTypeConverter()); ValueRange ret = genLoopWithIterator( rewriter, loc, it.get(), ivs, [block](PatternRewriter &rewriter, Location loc, Region &loopBody, @@ -263,19 +277,17 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern { return result; }); - const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); - rewriter.replaceOp(op, ret, resultMapping); + rewriter.replaceOp(op, ret); return success(); } }; -class SparseCoIterateOpConverter - : public OneToNOpConversionPattern { - using OneToNOpConversionPattern::OneToNOpConversionPattern; +class SparseCoIterateOpConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CoIterateOp op, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(CoIterateOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { assert(op.getSpaceDim() == 1 && "Not implemented"); Location loc = op.getLoc(); @@ -299,18 +311,23 @@ class SparseCoIterateOpConverter assert(!needUniv && "Not implemented"); (void)needUniv; + SmallVector newBlocks; + DenseMap newToOldBlockMap; for (Region ®ion : op.getCaseRegions()) { // Do a one-shot type conversion on all region blocks, since the same // region might be used multiple time. Block *block = ®ion.getBlocks().front(); - OneToNTypeMapping blockTypeMapping(block->getArgumentTypes()); + TypeConverter::SignatureConversion blockTypeMapping( + block->getArgumentTypes().size()); if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), blockTypeMapping))) { return rewriter.notifyMatchFailure( op, "failed to convert coiterate region argurment types"); } - rewriter.applySignatureConversion(block, blockTypeMapping); + newBlocks.push_back(rewriter.applySignatureConversion( + block, blockTypeMapping, getTypeConverter())); + newToOldBlockMap[newBlocks.back()] = block; } SmallVector spaces; @@ -343,7 +360,7 @@ class SparseCoIterateOpConverter // Generates a loop sequence, one loop per case. for (auto [r, caseBits] : - llvm::zip_equal(op.getCaseRegions(), op.getRegionDefinedSpaces())) { + llvm::zip_equal(newBlocks, op.getRegionDefinedSpaces())) { assert(caseBits.count() > 0 && "Complement space not implemented"); // Retrives a vector of pointers to the iterators used in the case. @@ -359,11 +376,17 @@ class SparseCoIterateOpConverter // The subcases are never empty, it must contains at least the current // region itself. // TODO: these cases should be sorted. - SmallVector subCases = op.getSubCasesOf(r.getRegionNumber()); + SmallVector subCases = + op.getSubCasesOf(r->getParent()->getRegionNumber()); + SmallVector newBlocks, oldBlocks; + for (Region *r : subCases) { + newBlocks.push_back(&r->front()); + oldBlocks.push_back(newToOldBlockMap[newBlocks.back()]); + } assert(!subCases.empty()); - ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, - iters, subCases, userReduc); + ValueRange res = genCoIterateBranchNest( + rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks, userReduc); SmallVector nextIterYields(res); // 2nd. foward the loop. @@ -388,7 +411,7 @@ class SparseCoIterateOpConverter // This is a simple iteration loop. assert(caseBits.count() == 1); - Block *block = &r.getBlocks().front(); + Block *block = r; ValueRange curResult = genLoopWithIterator( rewriter, loc, validIters.front(), userReduc, /*bodyBuilder=*/ diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp index 1cac949b68c79..153b9b170e5d3 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -172,11 +172,16 @@ struct LowerSparseIterationToSCFPass ConversionTarget target(*ctx); // The actual conversion. - target.addIllegalOp(); + target.addLegalDialect(); + target.addIllegalOp(); + target.addLegalOp(); populateLowerSparseIterationToSCFPatterns(converter, patterns); - if (failed(applyPartialOneToNConversion(getOperation(), converter, - std::move(patterns)))) + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) signalPassFailure(); } };