diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index f3bf5b66398e0..86ea87b55af1c 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -143,6 +143,8 @@ template class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { public: using OpAdaptor = typename SourceOp::Adaptor; + using OneToNOpAdaptor = + typename SourceOp::template GenericAdaptor>; explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1) @@ -153,8 +155,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { /// Wrappers around the RewritePattern methods that pass the derived op type. void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - rewrite(cast(op), OpAdaptor(operands, cast(op)), - rewriter); + auto sourceOp = cast(op); + rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); + } + void rewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto sourceOp = cast(op); + rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter); } LogicalResult match(Operation *op) const final { return match(cast(op)); @@ -162,8 +169,15 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - return matchAndRewrite(cast(op), - OpAdaptor(operands, cast(op)), rewriter); + auto sourceOp = cast(op); + return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); + } + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto sourceOp = cast(op); + return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), + rewriter); } /// Rewrite and Match methods that operate on the SourceOp type. These must be @@ -175,6 +189,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { ConversionPatternRewriter &rewriter) const { llvm_unreachable("must override rewrite or matchAndRewrite"); } + virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + SmallVector oneToOneOperands = + getOneToOneAdaptorOperands(adaptor.getOperands()); + rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + } virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -183,6 +203,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { rewrite(op, adaptor, rewriter); return success(); } + virtual LogicalResult + matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + SmallVector oneToOneOperands = + getOneToOneAdaptorOperands(adaptor.getOperands()); + return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + } private: using ConvertToLLVMPattern::match; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index aac6b7c03548a..28150e886913e 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -538,8 +538,15 @@ class ConversionPattern : public RewritePattern { ConversionPatternRewriter &rewriter) const { llvm_unreachable("unimplemented rewrite"); } + virtual void rewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + rewrite(op, getOneToOneAdaptorOperands(operands), rewriter); + } /// Hook for derived classes to implement combined matching and rewriting. + /// This overload supports only 1:1 replacements. The 1:N overload is called + /// by the driver. By default, it calls this 1:1 overload or reports a fatal + /// error if 1:N replacements were found. virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { @@ -549,6 +556,14 @@ class ConversionPattern : public RewritePattern { return success(); } + /// Hook for derived classes to implement combined matching and rewriting. + /// This overload supports 1:N replacements. + virtual LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter); + } + /// Attempt to match and rewrite the IR root at the specified operation. LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final; @@ -575,6 +590,15 @@ class ConversionPattern : public RewritePattern { : RewritePattern(std::forward(args)...), typeConverter(&typeConverter) {} + /// Given an array of value ranges, which are the inputs to a 1:N adaptor, + /// try to extract the single value of each range to construct a the inputs + /// for a 1:1 adaptor. + /// + /// This function produces a fatal error if at least one range has 0 or + /// more than 1 value: "pattern 'name' does not support 1:N conversion" + SmallVector + getOneToOneAdaptorOperands(ArrayRef operands) const; + protected: /// An optional type converter for use by this pattern. const TypeConverter *typeConverter = nullptr; @@ -590,6 +614,8 @@ template class OpConversionPattern : public ConversionPattern { public: using OpAdaptor = typename SourceOp::Adaptor; + using OneToNOpAdaptor = + typename SourceOp::template GenericAdaptor>; OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) : ConversionPattern(SourceOp::getOperationName(), benefit, context) {} @@ -608,12 +634,24 @@ class OpConversionPattern : public ConversionPattern { auto sourceOp = cast(op); rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); } + void rewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto sourceOp = cast(op); + rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter); + } LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto sourceOp = cast(op); return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); } + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto sourceOp = cast(op); + return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), + rewriter); + } /// Rewrite and Match methods that operate on the SourceOp type. These must be /// overridden by the derived pattern class. @@ -624,6 +662,12 @@ class OpConversionPattern : public ConversionPattern { ConversionPatternRewriter &rewriter) const { llvm_unreachable("must override matchAndRewrite or a rewrite method"); } + virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + SmallVector oneToOneOperands = + getOneToOneAdaptorOperands(adaptor.getOperands()); + rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + } virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -632,6 +676,13 @@ class OpConversionPattern : public ConversionPattern { rewrite(op, adaptor, rewriter); return success(); } + virtual LogicalResult + matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + SmallVector oneToOneOperands = + getOneToOneAdaptorOperands(adaptor.getOperands()); + return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + } private: using ConversionPattern::matchAndRewrite; @@ -657,11 +708,20 @@ class OpInterfaceConversionPattern : public ConversionPattern { ConversionPatternRewriter &rewriter) const final { rewrite(cast(op), operands, rewriter); } + void rewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + rewrite(cast(op), operands, rewriter); + } LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { return matchAndRewrite(cast(op), operands, rewriter); } + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + return matchAndRewrite(cast(op), operands, rewriter); + } /// Rewrite and Match methods that operate on the SourceOp type. These must be /// overridden by the derived pattern class. @@ -669,6 +729,10 @@ class OpInterfaceConversionPattern : public ConversionPattern { ConversionPatternRewriter &rewriter) const { llvm_unreachable("must override matchAndRewrite or a rewrite method"); } + virtual void rewrite(SourceOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + rewrite(op, getOneToOneAdaptorOperands(operands), rewriter); + } virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { @@ -677,6 +741,11 @@ class OpInterfaceConversionPattern : public ConversionPattern { rewrite(op, operands, rewriter); return success(); } + virtual LogicalResult + matchAndRewrite(SourceOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter); + } private: using ConversionPattern::matchAndRewrite; diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp index a08764326a80b..03be00328bda3 100644 --- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp +++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp @@ -13,40 +13,6 @@ using namespace mlir; using namespace mlir::func; -//===----------------------------------------------------------------------===// -// Helper functions -//===----------------------------------------------------------------------===// - -/// If the given value can be decomposed with the type converter, decompose it. -/// Otherwise, return the given value. -// TODO: Value decomposition should happen automatically through a 1:N adaptor. -// This function will disappear when the 1:1 and 1:N drivers are merged. -static SmallVector decomposeValue(OpBuilder &builder, Location loc, - Value value, - const TypeConverter *converter) { - // Try to convert the given value's type. If that fails, just return the - // given value. - SmallVector convertedTypes; - if (failed(converter->convertType(value.getType(), convertedTypes))) - return {value}; - if (convertedTypes.empty()) - return {}; - - // If the given value's type is already legal, just return the given value. - TypeRange convertedTypeRange(convertedTypes); - if (convertedTypeRange == TypeRange(value.getType())) - return {value}; - - // Try to materialize a target conversion. If the materialization did not - // produce values of the requested type, the materialization failed. Just - // return the given value in that case. - SmallVector result = converter->materializeTargetConversion( - builder, loc, convertedTypeRange, value); - if (result.empty()) - return {value}; - return result; -} - //===----------------------------------------------------------------------===// // DecomposeCallGraphTypesForFuncArgs //===----------------------------------------------------------------------===// @@ -102,16 +68,11 @@ struct DecomposeCallGraphTypesForReturnOp using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ReturnOp op, OpAdaptor adaptor, + matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { SmallVector newOperands; - for (Value operand : adaptor.getOperands()) { - // TODO: We can directly take the values from the adaptor once this is a - // 1:N conversion pattern. - llvm::append_range(newOperands, - decomposeValue(rewriter, operand.getLoc(), operand, - getTypeConverter())); - } + for (ValueRange operand : adaptor.getOperands()) + llvm::append_range(newOperands, operand); rewriter.replaceOpWithNewOp(op, newOperands); return success(); } @@ -128,18 +89,13 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CallOp op, OpAdaptor adaptor, + matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { // Create the operands list of the new `CallOp`. SmallVector newOperands; - for (Value operand : adaptor.getOperands()) { - // TODO: We can directly take the values from the adaptor once this is a - // 1:N conversion pattern. - llvm::append_range(newOperands, - decomposeValue(rewriter, operand.getLoc(), operand, - getTypeConverter())); - } + for (ValueRange operand : adaptor.getOperands()) + llvm::append_range(newOperands, operand); // Create the new result types for the new `CallOp` and track the number of // replacement types for each original op result. diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp index b1cde6ca5d2fc..9e7759bef6d8f 100644 --- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp +++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp @@ -13,6 +13,14 @@ using namespace mlir; using namespace mlir::func; +/// Flatten the given value ranges into a single vector of values. +static SmallVector flattenValues(ArrayRef values) { + SmallVector result; + for (const auto &vals : values) + llvm::append_range(result, vals); + return result; +} + namespace { /// Converts the operand and result types of the CallOp, used together with the /// FuncOpSignatureConversion. @@ -21,7 +29,7 @@ struct CallOpSignatureConversion : public OpConversionPattern { /// Hook for derived classes to implement combined matching and rewriting. LogicalResult - matchAndRewrite(CallOp callOp, OpAdaptor adaptor, + matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Convert the original function results. Keep track of how many result // types an original result type is converted into. @@ -38,9 +46,9 @@ struct CallOpSignatureConversion : public OpConversionPattern { // Substitute with the new result types from the corresponding FuncType // conversion. - auto newCallOp = - rewriter.create(callOp.getLoc(), callOp.getCallee(), - convertedResults, adaptor.getOperands()); + auto newCallOp = rewriter.create( + callOp.getLoc(), callOp.getCallee(), convertedResults, + flattenValues(adaptor.getOperands())); SmallVector replacements; size_t offset = 0; for (int i = 0, e = callOp->getNumResults(); i < e; ++i) { diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp index 93a78056db194..c0589044c26ec 100644 --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -16,20 +16,18 @@ using namespace mlir::scf; namespace { -// Unpacks the single unrealized_conversion_cast using the list of inputs -// e.g., return [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d) -static void unpackUnrealizedConversionCast(Value v, - SmallVectorImpl &unpacked) { - if (auto cast = - dyn_cast_or_null(v.getDefiningOp())) { - if (cast.getInputs().size() != 1) { - // 1 : N type conversion. - unpacked.append(cast.getInputs().begin(), cast.getInputs().end()); - return; - } - } - // 1 : 1 type conversion. - unpacked.push_back(v); +/// Flatten the given value ranges into a single vector of values. +static SmallVector flattenValues(ArrayRef values) { + SmallVector result; + for (const auto &vals : values) + llvm::append_range(result, vals); + return result; +} + +/// Assert that the given value range contains a single value and return it. +static Value getSingleValue(ValueRange values) { + assert(values.size() == 1 && "expected single value"); + return values.front(); } // CRTP @@ -40,19 +38,21 @@ class Structural1ToNConversionPattern : public OpConversionPattern { public: using OpConversionPattern::typeConverter; using OpConversionPattern::OpConversionPattern; - using OpAdaptor = typename OpConversionPattern::OpAdaptor; + using OneToNOpAdaptor = + typename OpConversionPattern::OneToNOpAdaptor; // // Derived classes should provide the following method which performs the // actual conversion. It should return std::nullopt upon conversion failure // and return the converted operation upon success. // - // std::optional convertSourceOp(SourceOp op, OpAdaptor adaptor, - // ConversionPatternRewriter &rewriter, - // TypeRange dstTypes) const; + // std::optional convertSourceOp( + // SourceOp op, OneToNOpAdaptor adaptor, + // ConversionPatternRewriter &rewriter, + // TypeRange dstTypes) const; LogicalResult - matchAndRewrite(SourceOp op, OpAdaptor adaptor, + matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector dstTypes; SmallVector offsets; @@ -73,28 +73,15 @@ class Structural1ToNConversionPattern : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "could not convert operation"); // Packs the return value. - SmallVector packedRets; + SmallVector packedRets; for (unsigned i = 1, e = offsets.size(); i < e; i++) { unsigned start = offsets[i - 1], end = offsets[i]; unsigned len = end - start; ValueRange mappedValue = newOp->getResults().slice(start, len); - if (len != 1) { - // 1 : N type conversion. - Type origType = op.getResultTypes()[i - 1]; - Value mat = typeConverter->materializeSourceConversion( - rewriter, op.getLoc(), origType, mappedValue); - if (!mat) { - return rewriter.notifyMatchFailure( - op, "Failed to materialize 1:N type conversion"); - } - packedRets.push_back(mat); - } else { - // 1 : 1 type conversion. - packedRets.push_back(mappedValue.front()); - } + packedRets.push_back(mappedValue); } - rewriter.replaceOp(op, packedRets); + rewriter.replaceOpWithMultiple(op, packedRets); return success(); } }; @@ -105,7 +92,7 @@ class ConvertForOpTypes using Structural1ToNConversionPattern::Structural1ToNConversionPattern; // The callback required by CRTP. - std::optional convertSourceOp(ForOp op, OpAdaptor adaptor, + std::optional convertSourceOp(ForOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter, TypeRange dstTypes) const { // Create a empty new op and inline the regions from the old op. @@ -129,16 +116,13 @@ class ConvertForOpTypes if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter))) return std::nullopt; - // Unpacked the iteration arguments. - SmallVector flatArgs; - for (Value arg : adaptor.getInitArgs()) - unpackUnrealizedConversionCast(arg, flatArgs); - // We can not do clone as the number of result types after conversion // might be different. - ForOp newOp = rewriter.create(op.getLoc(), adaptor.getLowerBound(), - adaptor.getUpperBound(), - adaptor.getStep(), flatArgs); + ForOp newOp = rewriter.create( + op.getLoc(), getSingleValue(adaptor.getLowerBound()), + getSingleValue(adaptor.getUpperBound()), + getSingleValue(adaptor.getStep()), + flattenValues(adaptor.getInitArgs())); // Reserve whatever attributes in the original op. newOp->setAttrs(op->getAttrs()); @@ -160,12 +144,12 @@ class ConvertIfOpTypes public: using Structural1ToNConversionPattern::Structural1ToNConversionPattern; - std::optional convertSourceOp(IfOp op, OpAdaptor adaptor, + std::optional convertSourceOp(IfOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter, TypeRange dstTypes) const { - IfOp newOp = rewriter.create(op.getLoc(), dstTypes, - adaptor.getCondition(), true); + IfOp newOp = rewriter.create( + op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true); newOp->setAttrs(op->getAttrs()); // We do not need the empty blocks created by rewriter. @@ -189,15 +173,11 @@ class ConvertWhileOpTypes public: using Structural1ToNConversionPattern::Structural1ToNConversionPattern; - std::optional convertSourceOp(WhileOp op, OpAdaptor adaptor, + std::optional convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter, TypeRange dstTypes) const { - // Unpacked the iteration arguments. - SmallVector flatArgs; - for (Value arg : adaptor.getOperands()) - unpackUnrealizedConversionCast(arg, flatArgs); - - auto newOp = rewriter.create(op.getLoc(), dstTypes, flatArgs); + auto newOp = rewriter.create(op.getLoc(), dstTypes, + flattenValues(adaptor.getOperands())); for (auto i : {0u, 1u}) { if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter))) @@ -218,13 +198,10 @@ class ConvertYieldOpTypes : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, + matchAndRewrite(scf::YieldOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector unpackedYield; - for (Value operand : adaptor.getOperands()) - unpackUnrealizedConversionCast(operand, unpackedYield); - - rewriter.replaceOpWithNewOp(op, unpackedYield); + rewriter.replaceOpWithNewOp( + op, flattenValues(adaptor.getOperands())); return success(); } }; @@ -235,13 +212,10 @@ class ConvertConditionOpTypes : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ConditionOp op, OpAdaptor adaptor, + matchAndRewrite(ConditionOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector unpackedYield; - for (Value operand : adaptor.getOperands()) - unpackUnrealizedConversionCast(operand, unpackedYield); - - rewriter.modifyOpInPlace(op, [&]() { op->setOperands(unpackedYield); }); + rewriter.modifyOpInPlace( + op, [&]() { op->setOperands(flattenValues(adaptor.getOperands())); }); return success(); } }; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index 25fca49cb0154..20d46f7ca00c5 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -39,25 +39,18 @@ using namespace mlir::sparse_tensor; // Helper methods. //===----------------------------------------------------------------------===// -/// Flattens a list of operands that may contain sparse tensors. -static void flattenOperands(ValueRange operands, - SmallVectorImpl &flattened) { - // In case of - // sparse_tensor, c, sparse_tensor - // ==> - // memref ..., c, memref ... - for (auto operand : operands) { - if (getSparseTensorEncoding(operand.getType())) { - auto tuple = getTuple(operand); - // An unrealized_conversion_cast will be inserted by type converter to - // inter-mix the gap between 1:N conversion between sparse tensors and - // fields. In this case, take the operands in the cast and replace the - // sparse tensor output with the flattened type array. - flattened.append(tuple.getOperands().begin(), tuple.getOperands().end()); - } else { - flattened.push_back(operand); - } - } +/// Flatten the given value ranges into a single vector of values. +static SmallVector flattenValues(ArrayRef values) { + SmallVector result; + for (const auto &vals : values) + llvm::append_range(result, vals); + return result; +} + +/// Assert that the given value range contains a single value and return it. +static Value getSingleValue(ValueRange values) { + assert(values.size() == 1 && "expected single value"); + return values.front(); } /// Generates a load with proper `index` typing. @@ -567,12 +560,11 @@ class SparseReturnConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, + matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector flattened; - flattenOperands(adaptor.getOperands(), flattened); // Create a return with the flattened value extracted from sparse tensors. - rewriter.replaceOpWithNewOp(op, flattened); + rewriter.replaceOpWithNewOp( + op, flattenValues(adaptor.getOperands())); return success(); } }; @@ -583,7 +575,7 @@ class SparseCallConverter : public OpConversionPattern { // The default CallOp converter can not handle 1:N type conversion. using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(func::CallOp op, OpAdaptor adaptor, + matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); // In case of: @@ -596,10 +588,8 @@ class SparseCallConverter : public OpConversionPattern { return failure(); // (1) Generates new call with flattened return value. - SmallVector flattened; - flattenOperands(adaptor.getOperands(), flattened); - auto newCall = rewriter.create(loc, op.getCallee(), - finalRetTy, flattened); + auto newCall = rewriter.create( + loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands())); // (2) Gather sparse tensor returns. SmallVector> packedResultVals; // Tracks the offset of current return value (of the original call) @@ -643,7 +633,7 @@ class SparseLvlOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(LvlOp op, OpAdaptor adaptor, + matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { std::optional lvl = op.getConstantLvlIndex(); RankedTensorType srcType = op.getSource().getType(); @@ -662,7 +652,7 @@ class SparseLvlOpConverter : public OpConversionPattern { struct SparseReorderCOOConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor, + matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); MLIRContext *ctx = op.getContext(); @@ -693,7 +683,7 @@ struct SparseReorderCOOConverter : public OpConversionPattern { // Since we do in-place sorting, the destinate tensor will have the same set // of memrefs as the source tensor. - rewriter.replaceOp(op, adaptor.getInputCoo()); + rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()}); return success(); } }; @@ -702,8 +692,10 @@ template class SparseSliceGetterOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; + using typename OpConversionPattern::OneToNOpAdaptor; + LogicalResult - matchAndRewrite(Op op, typename Op::Adaptor adaptor, + matchAndRewrite(Op op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Simply lowers to specifer.get operation. auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(), @@ -721,14 +713,14 @@ class SparseCastConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, + matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only rewrite identically annotated source/dest. auto encDst = getSparseTensorEncoding(op.getType()); auto encSrc = getSparseTensorEncoding(op.getSource().getType()); if (!encDst || encDst != encSrc) return failure(); - rewriter.replaceOp(op, adaptor.getOperands()); + rewriter.replaceOpWithMultiple(op, {adaptor.getSource()}); return success(); } }; @@ -737,10 +729,10 @@ class SparseReMapConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor, + matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Simply fold the operation. - rewriter.replaceOp(op, adaptor.getSource()); + rewriter.replaceOpWithMultiple(op, {adaptor.getSource()}); return success(); } }; @@ -756,7 +748,7 @@ class SparseTensorAllocConverter enableBufferInitialization(enableInit) {} LogicalResult - matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor, + matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { const auto resType = getSparseTensorType(op); if (!resType.hasEncoding()) @@ -791,7 +783,8 @@ class SparseTensorAllocConverter } // Level size equals to dimension size since lvl2dim map is an identity map. SmallVector lvlSizesValues; - createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(), + createDimSizes(rewriter, loc, resType, + flattenValues(adaptor.getDynamicSizes()), /*dimSizesValues=*/lvlSizesValues); // Construct allocation for each field. @@ -861,7 +854,7 @@ class SparseTensorDeallocConverter createDeallocs(createDeallocs) {} LogicalResult - matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor, + matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto enc = getSparseTensorEncoding(op.getTensor().getType()); if (!enc) @@ -892,7 +885,7 @@ class SparseTensorLoadConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(LoadOp op, OpAdaptor adaptor, + matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Prepare descriptor. auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), @@ -911,7 +904,7 @@ class SparseExpandConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ExpandOp op, OpAdaptor adaptor, + matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (!getSparseTensorEncoding(op.getTensor().getType())) return failure(); @@ -963,16 +956,16 @@ class SparseCompressConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CompressOp op, OpAdaptor adaptor, + matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); SmallVector fields; auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields, op.getTensor().getType()); - Value values = adaptor.getValues(); - Value filled = adaptor.getFilled(); - Value added = adaptor.getAdded(); - Value count = adaptor.getCount(); + Value values = getSingleValue(adaptor.getValues()); + Value filled = getSingleValue(adaptor.getFilled()); + Value added = getSingleValue(adaptor.getAdded()); + Value count = getSingleValue(adaptor.getCount()); const SparseTensorType dstType(desc.getRankedTensorType()); Type eltType = dstType.getElementType(); @@ -1005,7 +998,8 @@ class SparseCompressConverter : public OpConversionPattern { SmallVector params(desc.getFields().begin(), desc.getFields().end()); SmallVector flatSpTensorTps = llvm::to_vector( llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); })); - params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end()); + SmallVector flatLvlCoords = flattenValues(adaptor.getLvlCoords()); + params.append(flatLvlCoords.begin(), flatLvlCoords.end()); params.push_back(crd); params.push_back(value); SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps, @@ -1033,9 +1027,9 @@ class SparseInsertConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor, + matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto stt = getSparseTensorType(adaptor.getDest()); + auto stt = getSparseTensorType(op.getDest()); if (!stt.hasEncoding()) return failure(); assert(stt.isIdentity() && "Run reinterpret-map before conversion."); @@ -1045,8 +1039,9 @@ class SparseInsertConverter : public OpConversionPattern { getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType()); TypeRange flatSpTensorTps = desc.getFields().getTypes(); SmallVector params = llvm::to_vector(desc.getFields()); - params.append(adaptor.getIndices().begin(), adaptor.getIndices().end()); - params.push_back(adaptor.getScalar()); + SmallVector flatIndices = flattenValues(adaptor.getIndices()); + params.append(flatIndices.begin(), flatIndices.end()); + params.push_back(getSingleValue(adaptor.getScalar())); SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps, params, /*genCall=*/true); SmallVector ret = insertGen.genCallOrInline(rewriter, loc); @@ -1062,7 +1057,7 @@ class SparseToPositionsConverter : public OpConversionPattern { using OpAdaptor = typename ToPositionsOp::Adaptor; using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor, + matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Replace the requested position access with corresponding field. // The view is restricted to the actual size to ensure clients @@ -1085,7 +1080,7 @@ class SparseToCoordinatesConverter using OpAdaptor = typename ToCoordinatesOp::Adaptor; using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor, + matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Replace the requested coordinates access with corresponding field. // The view is restricted to the actual size to ensure clients @@ -1111,7 +1106,7 @@ class SparseToCoordinatesBufferConverter using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor; using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor, + matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Replace the requested coordinates access with corresponding field. // The view is restricted to the actual size to ensure clients @@ -1133,7 +1128,7 @@ class SparseToValuesConverter : public OpConversionPattern { using OpAdaptor = typename ToValuesOp::Adaptor; using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ToValuesOp op, OpAdaptor adaptor, + matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Replace the requested values access with corresponding field. // The view is restricted to the actual size to ensure clients @@ -1153,7 +1148,7 @@ class SparseConvertConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ConvertOp op, OpAdaptor adaptor, + matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType()); SparseTensorEncodingAttr encSrc = @@ -1173,7 +1168,7 @@ class SparseConvertConverter : public OpConversionPattern { Type srcElemTp = op.getSource().getType().getElementType(); // Fold the trivial cases. if (retElemTp == srcElemTp && encDst == encSrc) { - rewriter.replaceOp(op, adaptor.getSource()); + rewriter.replaceOpWithMultiple(op, {adaptor.getSource()}); return success(); } // @@ -1239,7 +1234,7 @@ class SparseExtractSliceConverter public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor, + matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); MLIRContext *ctx = op.getContext(); @@ -1296,7 +1291,7 @@ class SparseNumberOfEntriesConverter public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor, + matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Query memSizes for the actually stored values. // FIXME: the nse value computed in this way might be wrong when there is @@ -1430,7 +1425,7 @@ struct SparseDisassembleOpConverter : OpConversionPattern(typeConverter, context) {} LogicalResult - matchAndRewrite(DisassembleOp op, OpAdaptor adaptor, + matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), op.getTensor().getType()); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h index 89858546e37e1..869c7864d7535 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h @@ -228,11 +228,6 @@ class MutSparseTensorDescriptor } }; -/// Returns the "tuple" value of the adapted tensor. -inline UnrealizedConversionCastOp getTuple(Value tensor) { - return llvm::cast(tensor.getDefiningOp()); -} - /// Packs the given values as a "tuple" value. inline Value genTuple(OpBuilder &builder, Location loc, Type tp, ValueRange values) { @@ -246,16 +241,15 @@ inline Value genTuple(OpBuilder &builder, Location loc, } inline SparseTensorDescriptor -getDescriptorFromTensorTuple(Value tensor, RankedTensorType type) { - auto tuple = getTuple(tensor); - return SparseTensorDescriptor(SparseTensorType(type), tuple.getInputs()); +getDescriptorFromTensorTuple(ValueRange adaptorValues, RankedTensorType type) { + return SparseTensorDescriptor(SparseTensorType(type), adaptorValues); } inline MutSparseTensorDescriptor -getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl &fields, +getMutDescriptorFromTensorTuple(ValueRange adaptorValues, + SmallVectorImpl &fields, RankedTensorType type) { - auto tuple = getTuple(tensor); - fields.assign(tuple.getInputs().begin(), tuple.getInputs().end()); + fields.assign(adaptorValues.begin(), adaptorValues.end()); return MutSparseTensorDescriptor(SparseTensorType(type), fields); } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 1424c4974f2d4..613fd6d9d74b1 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -67,10 +67,6 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) { // ConversionValueMapping //===----------------------------------------------------------------------===// -/// A list of replacement SSA values. Optimized for the common case of a single -/// SSA value. -using ReplacementValues = SmallVector; - namespace { /// This class wraps a IRMapping to provide recursive lookup /// functionality, i.e. we will traverse if the mapped value also has a mapping. @@ -783,7 +779,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { LogicalResult remapValues(StringRef valueDiagTag, std::optional inputLoc, PatternRewriter &rewriter, ValueRange values, - SmallVectorImpl &remapped); + SmallVector> &remapped); /// Return "true" if the given operation is ignored, and does not need to be /// converted. @@ -817,17 +813,31 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { // Materializations //===--------------------------------------------------------------------===// - /// Build an unresolved materialization operation given an output type and set - /// of input operands. + /// Build an unresolved materialization operation given a range of output + /// types and a list of input operands. Returns the inputs if they their + /// types match the output types. + /// + /// If a cast op was built, it can optionally be returned with the `castOp` + /// output argument. /// /// If `valueToMap` is set to a non-null Value, then that value is mapped to - /// the result of the unresolved materialization in the conversion value + /// the results of the unresolved materialization in the conversion value /// mapping. - Value buildUnresolvedMaterialization(MaterializationKind kind, - OpBuilder::InsertPoint ip, Location loc, - Value valueToMap, ValueRange inputs, - Type outputType, Type originalType, - const TypeConverter *converter); + ValueRange buildUnresolvedMaterialization( + MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, + Value valueToMap, ValueRange inputs, TypeRange outputTypes, + Type originalType, const TypeConverter *converter, + UnrealizedConversionCastOp *castOp = nullptr); + Value buildUnresolvedMaterialization( + MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, + Value valueToMap, ValueRange inputs, Type outputType, Type originalType, + const TypeConverter *converter, + UnrealizedConversionCastOp *castOp = nullptr) { + return buildUnresolvedMaterialization(kind, ip, loc, valueToMap, inputs, + TypeRange(outputType), originalType, + converter, castOp) + .front(); + } /// Build an N:1 materialization for the given original value that was /// replaced with the given replacement values. @@ -853,6 +863,16 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { Value findOrBuildReplacementValue(Value value, const TypeConverter *converter); + /// Unpack an N:1 materialization and return the inputs of the + /// materialization. This function unpacks only those materializations that + /// were built with `insertNTo1Materialization`. + /// + /// This is a workaround around incomplete 1:N support in the dialect + /// conversion driver. It allows us to write 1:N conversion patterns while + /// 1:N support is still missing in the conversion value mapping. This + /// function will be deleted when full 1:N support has been added. + SmallVector unpackNTo1Materialization(Value value); + //===--------------------------------------------------------------------===// // Rewriter Notification Hooks //===--------------------------------------------------------------------===// @@ -862,7 +882,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { OpBuilder::InsertPoint previous) override; /// Notifies that an op is about to be replaced with the given values. - void notifyOpReplaced(Operation *op, ArrayRef newValues); + void notifyOpReplaced(Operation *op, ArrayRef newValues); /// Notifies that a block is about to be erased. void notifyBlockIsBeingErased(Block *block); @@ -955,6 +975,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { DenseMap unresolvedMaterializations; + /// A set of all N:1 materializations that were added to work around + /// incomplete 1:N support in the dialect conversion driver. + DenseSet nTo1TempMaterializations; + /// The current type converter, or nullptr if no type converter is currently /// active. const TypeConverter *currentTypeConverter = nullptr; @@ -1091,6 +1115,7 @@ void UnresolvedMaterializationRewrite::rollback() { if (mappedValue) rewriterImpl.mapping.erase(mappedValue); rewriterImpl.unresolvedMaterializations.erase(getOperation()); + rewriterImpl.nTo1TempMaterializations.erase(getOperation()); op->erase(); } @@ -1136,7 +1161,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) { LogicalResult ConversionPatternRewriterImpl::remapValues( StringRef valueDiagTag, std::optional inputLoc, PatternRewriter &rewriter, ValueRange values, - SmallVectorImpl &remapped) { + SmallVector> &remapped) { remapped.reserve(llvm::size(values)); for (const auto &it : llvm::enumerate(values)) { @@ -1144,11 +1169,18 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( Type origType = operand.getType(); Location operandLoc = inputLoc ? *inputLoc : operand.getLoc(); + // Find the most recently mapped value. Unpack all temporary N:1 + // materializations. Such conversions are a workaround around missing + // 1:N support in the ConversionValueMapping. (The conversion patterns + // already support 1:N replacements.) + Value repl = mapping.lookupOrDefault(operand); + SmallVector unpacked = unpackNTo1Materialization(repl); + if (!currentTypeConverter) { // The current pattern does not have a type converter. I.e., it does not // distinguish between legal and illegal types. For each operand, simply // pass through the most recently mapped value. - remapped.push_back(mapping.lookupOrDefault(operand)); + remapped.push_back(std::move(unpacked)); continue; } @@ -1162,15 +1194,29 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( return failure(); } + // If a type is converted to 0 types, there is nothing to do. + if (legalTypes.empty()) { + remapped.push_back({}); + continue; + } + if (legalTypes.size() != 1) { - // TODO: Parts of the dialect conversion infrastructure do not support - // 1->N type conversions yet. Therefore, if a type is converted to 0 or - // multiple types, the only thing that we can do for now is passing - // through the most recently mapped value. Fixing this requires - // improvements to the `ConversionValueMapping` (to be able to store 1:N - // mappings) and to the `ConversionPattern` adaptor handling (to be able - // to pass multiple remapped values for a single operand to the adaptor). - remapped.push_back(mapping.lookupOrDefault(operand)); + // TODO: This is a 1:N conversion. The conversion value mapping does not + // store such materializations yet. If the types of the most recently + // mapped values do not match, build a target materialization. + if (TypeRange(unpacked) == legalTypes) { + remapped.push_back(std::move(unpacked)); + continue; + } + + // Insert a target materialization if the current pattern expects + // different legalized types. + ValueRange targetMat = buildUnresolvedMaterialization( + MaterializationKind::Target, computeInsertPoint(repl), operandLoc, + /*valueToMap=*/Value(), /*inputs=*/unpacked, + /*outputType=*/legalTypes, /*originalType=*/origType, + currentTypeConverter); + remapped.push_back(targetMat); continue; } @@ -1182,15 +1228,15 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( if (newOperand.getType() != desiredType) { // If the looked up value's type does not have the desired type, it means // that the value was replaced with a value of different type and no - // source materialization was created yet. + // target materialization was created yet. Value castValue = buildUnresolvedMaterialization( MaterializationKind::Target, computeInsertPoint(newOperand), - operandLoc, /*valueToMap=*/newOperand, /*inputs=*/newOperand, + operandLoc, /*valueToMap=*/newOperand, /*inputs=*/unpacked, /*outputType=*/desiredType, /*originalType=*/origType, currentTypeConverter); newOperand = castValue; } - remapped.push_back(newOperand); + remapped.push_back({newOperand}); } return success(); } @@ -1347,31 +1393,38 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /// Build an unresolved materialization operation given an output type and set /// of input operands. -Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( +ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, - Value valueToMap, ValueRange inputs, Type outputType, Type originalType, - const TypeConverter *converter) { + Value valueToMap, ValueRange inputs, TypeRange outputTypes, + Type originalType, const TypeConverter *converter, + UnrealizedConversionCastOp *castOp) { assert((!originalType || kind == MaterializationKind::Target) && "original type is valid only for target materializations"); // Avoid materializing an unnecessary cast. - if (inputs.size() == 1 && inputs.front().getType() == outputType) { - if (valueToMap) + if (TypeRange(inputs) == outputTypes) { + if (valueToMap) { + assert(inputs.size() == 1 && "1:N mapping is not supported"); mapping.map(valueToMap, inputs.front()); - return inputs.front(); + } + return inputs; } // Create an unresolved materialization. We use a new OpBuilder to avoid // tracking the materialization like we do for other operations. - OpBuilder builder(outputType.getContext()); + OpBuilder builder(outputTypes.front().getContext()); builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); auto convertOp = - builder.create(loc, outputType, inputs); - if (valueToMap) + builder.create(loc, outputTypes, inputs); + if (valueToMap) { + assert(outputTypes.size() == 1 && "1:N mapping is not supported"); mapping.map(valueToMap, convertOp.getResult(0)); + } + if (castOp) + *castOp = convertOp; appendRewrite(convertOp, converter, kind, originalType, valueToMap); - return convertOp.getResult(0); + return convertOp.getResults(); } void ConversionPatternRewriterImpl::insertNTo1Materialization( @@ -1379,10 +1432,13 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization( Value originalValue, const TypeConverter *converter) { // Insert argument materialization back to the original type. Type originalType = originalValue.getType(); + UnrealizedConversionCastOp argCastOp; Value argMat = buildUnresolvedMaterialization( MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue, - /*inputs=*/replacements, originalType, /*originalType=*/Type(), - converter); + /*inputs=*/replacements, originalType, /*originalType=*/Type(), converter, + &argCastOp); + if (argCastOp) + nTo1TempMaterializations.insert(argCastOp); // Insert target materialization to the legalized type. Type legalOutputType; @@ -1398,11 +1454,14 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization( legalOutputType = replacements[0].getType(); } if (legalOutputType && legalOutputType != originalType) { - buildUnresolvedMaterialization(MaterializationKind::Target, - computeInsertPoint(argMat), loc, - /*valueToMap=*/argMat, /*inputs=*/argMat, - /*outputType=*/legalOutputType, - /*originalType=*/originalType, converter); + UnrealizedConversionCastOp targetCastOp; + buildUnresolvedMaterialization( + MaterializationKind::Target, computeInsertPoint(argMat), loc, + /*valueToMap=*/argMat, /*inputs=*/argMat, + /*outputType=*/legalOutputType, /*originalType=*/originalType, + converter, &targetCastOp); + if (targetCastOp) + nTo1TempMaterializations.insert(targetCastOp); } } @@ -1438,9 +1497,32 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(), /*valueToMap=*/value, /*inputs=*/repl, /*outputType=*/value.getType(), /*originalType=*/Type(), converter); + mapping.map(value, castValue); return castValue; } +SmallVector +ConversionPatternRewriterImpl::unpackNTo1Materialization(Value value) { + // Unpack unrealized_conversion_cast ops that were inserted as a N:1 + // workaround. + auto castOp = value.getDefiningOp(); + if (!castOp) + return {value}; + if (!nTo1TempMaterializations.contains(castOp)) + return {value}; + assert(castOp->getNumResults() == 1 && "expected single result"); + + SmallVector result; + for (Value v : castOp.getOperands()) { + // Keep unpacking if possible. This is needed because during block + // signature conversions and 1:N op replacements, the driver may have + // inserted two materializations back-to-back: first an argument + // materialization, then a target materialization. + llvm::append_range(result, unpackNTo1Materialization(v)); + } + return result; +} + //===----------------------------------------------------------------------===// // Rewriter Notification Hooks @@ -1465,7 +1547,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( } void ConversionPatternRewriterImpl::notifyOpReplaced( - Operation *op, ArrayRef newValues) { + Operation *op, ArrayRef newValues) { assert(newValues.size() == op->getNumResults()); assert(!ignoredOps.contains(op) && "operation was already replaced"); @@ -1477,8 +1559,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced( isUnresolvedMaterialization = true; // Create mappings for each of the new result values. - for (auto [n, result] : llvm::zip_equal(newValues, op->getResults())) { - ReplacementValues repl = n; + for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults())) { if (repl.empty()) { // This result was dropped and no replacement value was provided. if (isUnresolvedMaterialization) { @@ -1488,12 +1569,12 @@ void ConversionPatternRewriterImpl::notifyOpReplaced( } // Materialize a replacement value "out of thin air". - Value sourceMat = buildUnresolvedMaterialization( + buildUnresolvedMaterialization( MaterializationKind::Source, computeInsertPoint(result), - result.getLoc(), /*valueToMap=*/Value(), /*inputs=*/ValueRange(), + result.getLoc(), /*valueToMap=*/result, /*inputs=*/ValueRange(), /*outputType=*/result.getType(), /*originalType=*/Type(), currentTypeConverter); - repl.push_back(sourceMat); + continue; } else { // Make sure that the user does not mess with unresolved materializations // that were inserted by the conversion driver. We keep track of these @@ -1595,10 +1676,9 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); - SmallVector newVals(newValues.size()); - for (auto [index, val] : llvm::enumerate(newValues)) - if (val) - newVals[index].push_back(val); + SmallVector newVals; + for (int i = 0; i < newValues.size(); ++i) + newVals.push_back(newValues.slice(i, 1)); impl->notifyOpReplaced(op, newVals); } @@ -1610,10 +1690,7 @@ void ConversionPatternRewriter::replaceOpWithMultiple( impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); - SmallVector newVals(newValues.size(), {}); - for (auto [index, val] : llvm::enumerate(newValues)) - llvm::append_range(newVals[index], val); - impl->notifyOpReplaced(op, newVals); + impl->notifyOpReplaced(op, newValues); } void ConversionPatternRewriter::eraseOp(Operation *op) { @@ -1621,7 +1698,7 @@ void ConversionPatternRewriter::eraseOp(Operation *op) { impl->logger.startLine() << "** Erase : '" << op->getName() << "'(" << op << ")\n"; }); - SmallVector nullRepls(op->getNumResults(), {}); + SmallVector nullRepls(op->getNumResults(), {}); impl->notifyOpReplaced(op, nullRepls); } @@ -1673,11 +1750,12 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, } Value ConversionPatternRewriter::getRemappedValue(Value key) { - SmallVector remappedValues; + SmallVector> remappedValues; if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key, remappedValues))) return nullptr; - return remappedValues.front(); + assert(remappedValues.front().size() == 1 && "1:N conversion not supported"); + return remappedValues.front().front(); } LogicalResult @@ -1685,8 +1763,15 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys, SmallVectorImpl &results) { if (keys.empty()) return success(); - return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, - results); + SmallVector> remapped; + if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, + remapped))) + return failure(); + for (const auto &values : remapped) { + assert(values.size() == 1 && "1:N conversion not supported"); + results.push_back(values.front()); + } + return success(); } void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, @@ -1780,6 +1865,19 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { // ConversionPattern //===----------------------------------------------------------------------===// +SmallVector ConversionPattern::getOneToOneAdaptorOperands( + ArrayRef operands) const { + SmallVector oneToOneOperands; + oneToOneOperands.reserve(operands.size()); + for (ValueRange operand : operands) { + if (operand.size() != 1) + llvm::report_fatal_error("pattern '" + getDebugName() + + "' does not support 1:N conversion"); + oneToOneOperands.push_back(operand.front()); + } + return oneToOneOperands; +} + LogicalResult ConversionPattern::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { @@ -1791,12 +1889,14 @@ ConversionPattern::matchAndRewrite(Operation *op, getTypeConverter()); // Remap the operands of the operation. - SmallVector operands; + SmallVector> remapped; if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter, - op->getOperands(), operands))) { + op->getOperands(), remapped))) { return failure(); } - return matchAndRewrite(op, operands, dialectRewriter); + SmallVector remappedAsRange = + llvm::to_vector_of(remapped); + return matchAndRewrite(op, remappedAsRange, dialectRewriter); } //===----------------------------------------------------------------------===// @@ -2536,45 +2636,52 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter, assert(!op.use_empty() && "expected that dead materializations have already been DCE'd"); Operation::operand_range inputOperands = op.getOperands(); - Type outputType = op.getResultTypes()[0]; // Try to materialize the conversion. if (const TypeConverter *converter = rewrite->getConverter()) { rewriter.setInsertionPoint(op); - Value newMaterialization; + SmallVector newMaterialization; switch (rewrite->getMaterializationKind()) { - case MaterializationKind::Argument: + case MaterializationKind::Argument: { // Try to materialize an argument conversion. - newMaterialization = converter->materializeArgumentConversion( - rewriter, op->getLoc(), outputType, inputOperands); - if (newMaterialization) + assert(op->getNumResults() == 1 && "expected single result"); + Value argMat = converter->materializeArgumentConversion( + rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands); + if (argMat) { + newMaterialization.push_back(argMat); break; + } + } // If an argument materialization failed, fallback to trying a target // materialization. [[fallthrough]]; case MaterializationKind::Target: newMaterialization = converter->materializeTargetConversion( - rewriter, op->getLoc(), outputType, inputOperands, + rewriter, op->getLoc(), op.getResultTypes(), inputOperands, rewrite->getOriginalType()); break; case MaterializationKind::Source: - newMaterialization = converter->materializeSourceConversion( - rewriter, op->getLoc(), outputType, inputOperands); + assert(op->getNumResults() == 1 && "expected single result"); + Value sourceMat = converter->materializeSourceConversion( + rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands); + if (sourceMat) + newMaterialization.push_back(sourceMat); break; } - if (newMaterialization) { - assert(newMaterialization.getType() == outputType && + if (!newMaterialization.empty()) { + assert(TypeRange(newMaterialization) == op.getResultTypes() && "materialization callback produced value of incorrect type"); rewriter.replaceOp(op, newMaterialization); return success(); } } - InFlightDiagnostic diag = - op->emitError() << "failed to legalize unresolved materialization " - "from (" - << inputOperands.getTypes() << ") to (" << outputType - << ") that remained live after conversion"; + InFlightDiagnostic diag = op->emitError() + << "failed to legalize unresolved materialization " + "from (" + << inputOperands.getTypes() << ") to (" + << op.getResultTypes() + << ") that remained live after conversion"; diag.attachNote(op->getUsers().begin()->getLoc()) << "see existing live user here: " << *op->getUsers().begin(); return failure(); diff --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir index b8fad63eb4de6..4e641317ac2f3 100644 --- a/mlir/test/Transforms/decompose-call-graph-types.mlir +++ b/mlir/test/Transforms/decompose-call-graph-types.mlir @@ -9,10 +9,7 @@ // CHECK-LABEL: func @identity( // CHECK-SAME: %[[ARG0:.*]]: i1, // CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { -// CHECK: %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple -// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple) -> i1 -// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple) -> i32 -// CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +// CHECK: return %[[ARG0]], %[[ARG1]] : i1, i32 // CHECK-12N-LABEL: func @identity( // CHECK-12N-SAME: %[[ARG0:.*]]: i1, // CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { @@ -56,18 +53,7 @@ func.func @recursive_decomposition(%arg0: tuple>>) -> tuple (i1, i2) { -// CHECK: %[[V0:.*]] = "test.make_tuple"() : () -> tuple<> -// CHECK: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]]) : (i1) -> tuple -// CHECK: %[[V2:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple -// CHECK: %[[V3:.*]] = "test.make_tuple"(%[[V2]]) : (tuple) -> tuple> -// CHECK: %[[V4:.*]] = "test.make_tuple"(%[[V0]], %[[V1]], %[[V3]]) : (tuple<>, tuple, tuple>) -> tuple, tuple, tuple>> -// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 0 : i32}> : (tuple, tuple, tuple>>) -> tuple<> -// CHECK: %[[V6:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 1 : i32}> : (tuple, tuple, tuple>>) -> tuple -// CHECK: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple) -> i1 -// CHECK: %[[V8:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 2 : i32}> : (tuple, tuple, tuple>>) -> tuple> -// CHECK: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) <{index = 0 : i32}> : (tuple>) -> tuple -// CHECK: %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) <{index = 0 : i32}> : (tuple) -> i2 -// CHECK: return %[[V7]], %[[V10]] : i1, i2 +// CHECK: return %[[ARG0]], %[[ARG1]] : i1, i2 // CHECK-12N-LABEL: func @mixed_recursive_decomposition( // CHECK-12N-SAME: %[[ARG0:.*]]: i1, // CHECK-12N-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { @@ -87,14 +73,8 @@ func.func private @callee(tuple) -> tuple // CHECK-LABEL: func @caller( // CHECK-SAME: %[[ARG0:.*]]: i1, // CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { -// CHECK: %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple -// CHECK: %[[CALL_ARG0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple) -> i1 -// CHECK: %[[CALL_ARG1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple) -> i32 -// CHECK: %[[DECOMPOSED:.*]]:2 = call @callee(%[[CALL_ARG0]], %[[CALL_ARG1]]) : (i1, i32) -> (i1, i32) -// CHECK: %[[CALL_RESULT_RECOMPOSED:.*]] = "test.make_tuple"(%[[DECOMPOSED]]#0, %[[DECOMPOSED]]#1) : (i1, i32) -> tuple -// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 0 : i32}> : (tuple) -> i1 -// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 1 : i32}> : (tuple) -> i32 -// CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +// CHECK: %[[V0:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> (i1, i32) +// CHECK: return %[[V0]]#0, %[[V0]]#1 : i1, i32 // CHECK-12N-LABEL: func @caller( // CHECK-12N-SAME: %[[ARG0:.*]]: i1, // CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { @@ -190,14 +170,8 @@ func.func private @callee(tuple<>, i1, tuple, i3, tuple, i6) -> (tup // CHECK-SAME: %[[I4:.*]]: i4, // CHECK-SAME: %[[I5:.*]]: i5, // CHECK-SAME: %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) { -// CHECK: %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[I4]], %[[I5]]) : (i4, i5) -> tuple -// CHECK: %[[ARG_TUPLE_0:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 0 : i32}> : (tuple) -> i4 -// CHECK: %[[ARG_TUPLE_1:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 1 : i32}> : (tuple) -> i5 -// CHECK: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[ARG_TUPLE_0]], %[[ARG_TUPLE_1]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) -// CHECK: %[[RET_TUPLE:.*]] = "test.make_tuple"(%[[CALL]]#3, %[[CALL]]#4) : (i4, i5) -> tuple -// CHECK: %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 0 : i32}> : (tuple) -> i4 -// CHECK: %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 1 : i32}> : (tuple) -> i5 -// CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[RET_TUPLE_0]], %[[RET_TUPLE_1]], %[[CALL]]#5 : i1, i2, i3, i4, i5, i6 +// CHECK: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) +// CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[CALL]]#3, %[[CALL]]#4, %[[CALL]]#5 : i1, i2, i3, i4, i5, i6 // CHECK-12N-LABEL: func @caller( // CHECK-12N-SAME: %[[I1:.*]]: i1, // CHECK-12N-SAME: %[[I2:.*]]: i2, diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index e05f444afa68f..d98a6a036e6b1 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -472,3 +472,14 @@ func.func @circular_mapping() { %0 = "test.erase_op"() : () -> (i64) "test.drop_operands_and_replace_with_valid"(%0) : (i64) -> () } + +// ----- + +func.func @test_1_to_n_block_signature_conversion() { + "test.duplicate_block_args"() ({ + ^bb0(%arg0: i64): + "test.repetitive_1_to_n_consumer"(%arg0) : (i64) -> () + }) {} : () -> () + "test.return"() : () -> () +} + diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 239d529218026..d24d52f356d88 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1886,6 +1886,11 @@ def LegalOpC : TEST_Op<"legal_op_c">, Arguments<(ins I32)>, Results<(outs I32)>; def LegalOpD : TEST_Op<"legal_op_d">, Arguments<(ins AnyType)>; +def DuplicateBlockArgsOp : TEST_Op<"duplicate_block_args", [SingleBlock]> { + let arguments = (ins UnitAttr:$is_legal); + let regions = (region SizedRegion<1>:$body); +} + // Check that the conversion infrastructure can properly undo the creation of // operations where an operation was created before its parent, in this case, // in the parent's builder. diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index bbd55938718fe..8a0bc597c56be 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -982,9 +982,25 @@ struct TestPassthroughInvalidOp : public ConversionPattern { TestPassthroughInvalidOp(MLIRContext *ctx) : ConversionPattern("test.invalid", 1, ctx) {} LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - rewriter.replaceOpWithNewOp(op, std::nullopt, operands, + SmallVector flattened; + for (auto it : llvm::enumerate(operands)) { + ValueRange range = it.value(); + if (range.size() == 1) { + flattened.push_back(range.front()); + continue; + } + + // This is a 1:N replacement. Insert a test.cast op. (That's what the + // argument materialization used to do.) + flattened.push_back( + rewriter + .create(op->getLoc(), + op->getOperand(it.index()).getType(), range) + .getResult()); + } + rewriter.replaceOpWithNewOp(op, std::nullopt, flattened, std::nullopt); return success(); } @@ -1010,23 +1026,13 @@ struct TestSplitReturnType : public ConversionPattern { TestSplitReturnType(MLIRContext *ctx) : ConversionPattern("test.return", 1, ctx) {} LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // Check for a return of F32. if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) return failure(); - - // Check if the first operation is a cast operation, if it is we use the - // results directly. - auto *defOp = operands[0].getDefiningOp(); - if (auto packerOp = - llvm::dyn_cast_or_null(defOp)) { - rewriter.replaceOpWithNewOp(op, packerOp.getOperands()); - return success(); - } - - // Otherwise, fail to match. - return failure(); + rewriter.replaceOpWithNewOp(op, operands[0]); + return success(); } }; @@ -1181,6 +1187,47 @@ class TestEraseOp : public ConversionPattern { } }; +/// This pattern matches a test.duplicate_block_args op and duplicates all +/// block arguments. +class TestDuplicateBlockArgs + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(DuplicateBlockArgsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op.getIsLegal()) + return failure(); + rewriter.startOpModification(op); + Block *body = &op.getBody().front(); + TypeConverter::SignatureConversion result(body->getNumArguments()); + for (auto it : llvm::enumerate(body->getArgumentTypes())) + result.addInputs(it.index(), {it.value(), it.value()}); + rewriter.applySignatureConversion(body, result, getTypeConverter()); + op.setIsLegal(true); + rewriter.finalizeOpModification(op); + return success(); + } +}; + +/// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid +/// op. The pattern supports 1:N replacements and forwards the replacement +/// values of the single operand as test.valid operands. +class TestRepetitive1ToNConsumer : public ConversionPattern { +public: + TestRepetitive1ToNConsumer(MLIRContext *ctx) + : ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx) {} + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // A single operand is expected. + if (op->getNumOperands() != 1) + return failure(); + rewriter.replaceOpWithNewOp(op, operands.front()); + return success(); + } +}; + } // namespace namespace { @@ -1263,9 +1310,11 @@ struct TestLegalizePatternDriver TestUpdateConsumerType, TestNonRootReplacement, TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore, - TestUndoPropertiesModification, TestEraseOp>(&getContext()); + TestUndoPropertiesModification, TestEraseOp, + TestRepetitive1ToNConsumer>(&getContext()); patterns.add( &getContext(), converter); + patterns.add(converter, &getContext()); mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter); mlir::populateCallOpTypeConversionPattern(patterns, converter); @@ -1317,6 +1366,9 @@ struct TestLegalizePatternDriver target.addDynamicallyLegalOp( [](TestOpInPlaceSelfFold op) { return op.getFolded(); }); + target.addDynamicallyLegalOp( + [](DuplicateBlockArgsOp op) { return op.getIsLegal(); }); + // Handle a partial conversion. if (mode == ConversionMode::Partial) { DenseSet unlegalizedOps;