diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index 59b0f5c9b09bc..e2ab0ed6f66cc 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -153,70 +153,112 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, type.isVarArg()); }); + // Add generic source and target materializations to handle cases where + // non-LLVM types persist after an LLVM conversion. + addSourceMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) { + return builder.create(loc, resultType, inputs) + .getResult(0); + }); + addTargetMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) { + return builder.create(loc, resultType, inputs) + .getResult(0); + }); + // Helper function that checks if the given value range is a bare pointer. auto isBarePointer = [](ValueRange values) { return values.size() == 1 && isa(values.front().getType()); }; - // Argument materializations convert from the new block argument types - // (multiple SSA values that make up a memref descriptor) back to the - // original block argument type. The dialect conversion framework will then - // insert a target materialization from the original block argument type to - // a legal type. - addArgumentMaterialization([&](OpBuilder &builder, - UnrankedMemRefType resultType, - ValueRange inputs, Location loc) { + // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter + // must be passed explicitly. + auto packUnrankedMemRefDesc = + [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, + Location loc, LLVMTypeConverter &converter) -> Value { // Note: Bare pointers are not supported for unranked memrefs because a // memref descriptor cannot be built just from a bare pointer. - if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields()) + if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields()) return Value(); - Value desc = - UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs); + return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType, + inputs); + }; + + // MemRef descriptor elements -> UnrankedMemRefType + auto unrakedMemRefMaterialization = [&](OpBuilder &builder, + UnrankedMemRefType resultType, + ValueRange inputs, Location loc) { // An argument materialization must return a value of type // `resultType`, so insert a cast from the memref descriptor type // (!llvm.struct) to the original memref type. - return builder.create(loc, resultType, desc) - .getResult(0); - }); - addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType, - ValueRange inputs, Location loc) { - Value desc; - if (isBarePointer(inputs)) { - desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType, - inputs[0]); - } else if (TypeRange(inputs) == - getMemRefDescriptorFields(resultType, - /*unpackAggregates=*/true)) { - desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs); - } else { - // The inputs are neither a bare pointer nor an unpacked memref - // descriptor. This materialization function cannot be used. + Value packed = + packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this); + if (!packed) return Value(); - } + return builder.create(loc, resultType, packed) + .getResult(0); + }; + + // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter + // must be passed explicitly. + auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType, + ValueRange inputs, Location loc, + LLVMTypeConverter &converter) -> Value { + assert(resultType && "expected non-null result type"); + if (isBarePointer(inputs)) + return MemRefDescriptor::fromStaticShape(builder, loc, converter, + resultType, inputs[0]); + if (TypeRange(inputs) == + converter.getMemRefDescriptorFields(resultType, + /*unpackAggregates=*/true)) + return MemRefDescriptor::pack(builder, loc, converter, resultType, + inputs); + // The inputs are neither a bare pointer nor an unpacked memref descriptor. + // This materialization function cannot be used. + return Value(); + }; + + // MemRef descriptor elements -> MemRefType + auto rankedMemRefMaterialization = [&](OpBuilder &builder, + MemRefType resultType, + ValueRange inputs, Location loc) { // An argument materialization must return a value of type `resultType`, // so insert a cast from the memref descriptor type (!llvm.struct) to the // original memref type. - return builder.create(loc, resultType, desc) - .getResult(0); - }); - // Add generic source and target materializations to handle cases where - // non-LLVM types persist after an LLVM conversion. - addSourceMaterialization([&](OpBuilder &builder, Type resultType, - ValueRange inputs, Location loc) { - if (inputs.size() != 1) + Value packed = + packRankedMemRefDesc(builder, resultType, inputs, loc, *this); + if (!packed) return Value(); - - return builder.create(loc, resultType, inputs) + return builder.create(loc, resultType, packed) .getResult(0); - }); + }; + + // Argument materializations convert from the new block argument types + // (multiple SSA values that make up a memref descriptor) back to the + // original block argument type. + addArgumentMaterialization(unrakedMemRefMaterialization); + addArgumentMaterialization(rankedMemRefMaterialization); + addSourceMaterialization(unrakedMemRefMaterialization); + addSourceMaterialization(rankedMemRefMaterialization); + + // Bare pointer -> Packed MemRef descriptor addTargetMaterialization([&](OpBuilder &builder, Type resultType, - ValueRange inputs, Location loc) { - if (inputs.size() != 1) + ValueRange inputs, Location loc, + Type originalType) -> Value { + // The original MemRef type is required to build a MemRef descriptor + // because the sizes/strides of the MemRef cannot be inferred from just the + // bare pointer. + if (!originalType) return Value(); - - return builder.create(loc, resultType, inputs) - .getResult(0); + if (resultType != convertType(originalType)) + return Value(); + if (auto memrefType = dyn_cast(originalType)) + return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this); + if (auto unrankedMemrefType = dyn_cast(originalType)) + return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc, + *this); + return Value(); }); // Integer memory spaces map to themselves. diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 1607740a1ee07..51686646a0a2f 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -849,8 +849,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// function will be deleted when full 1:N support has been added. /// /// This function inserts an argument materialization back to the original - /// type, followed by a target materialization to the legalized type (if - /// applicable). + /// type. void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc, ValueRange replacements, Value originalValue, const TypeConverter *converter); @@ -1376,9 +1375,13 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( // used as a replacement. auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); - insertNTo1Materialization( - OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), - /*replacements=*/replArgs, /*outputValue=*/origArg, converter); + if (replArgs.size() == 1) { + mapping.map(origArg, replArgs.front()); + } else { + insertNTo1Materialization( + OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), + /*replacements=*/replArgs, /*outputValue=*/origArg, converter); + } appendRewrite(block, origArg, converter); } @@ -1437,36 +1440,12 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization( // Insert argument materialization back to the original type. Type originalType = originalValue.getType(); UnrealizedConversionCastOp argCastOp; - Value argMat = buildUnresolvedMaterialization( + buildUnresolvedMaterialization( MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue, - /*inputs=*/replacements, originalType, /*originalType=*/Type(), converter, - &argCastOp); + /*inputs=*/replacements, originalType, + /*originalType=*/Type(), converter, &argCastOp); if (argCastOp) nTo1TempMaterializations.insert(argCastOp); - - // Insert target materialization to the legalized type. - Type legalOutputType; - if (converter) { - legalOutputType = converter->convertType(originalType); - } else if (replacements.size() == 1) { - // When there is no type converter, assume that the replacement value - // types are legal. This is reasonable to assume because they were - // specified by the user. - // FIXME: This won't work for 1->N conversions because multiple output - // types are not supported in parts of the dialect conversion. In such a - // case, we currently use the original value type. - legalOutputType = replacements[0].getType(); - } - if (legalOutputType && legalOutputType != originalType) { - UnrealizedConversionCastOp targetCastOp; - buildUnresolvedMaterialization( - MaterializationKind::Target, computeInsertPoint(argMat), loc, - /*valueToMap=*/argMat, /*inputs=*/argMat, - /*outputType=*/legalOutputType, /*originalType=*/originalType, - converter, &targetCastOp); - if (targetCastOp) - nTo1TempMaterializations.insert(targetCastOp); - } } Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( @@ -2864,6 +2843,9 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, LogicalResult TypeConverter::convertType(Type t, SmallVectorImpl &results) const { + assert(this && "expected non-null type converter"); + assert(t && "expected non-null type"); + { std::shared_lock cacheReadLock(cacheMutex, std::defer_lock); diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index d98a6a036e6b1..2ca5f49637523 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -124,10 +124,10 @@ func.func @no_remap_nested() { // CHECK-NEXT: "foo.region" // expected-remark@+1 {{op 'foo.region' is not legalizable}} "foo.region"() ({ - // CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64): - ^bb0(%i0: i64, %unused: i16, %i1: i64): - // CHECK-NEXT: "test.valid"{{.*}} : (i64, i64) - "test.invalid"(%i0, %i1) : (i64, i64) -> () + // CHECK-NEXT: ^bb0(%{{.*}}: f64, %{{.*}}: i16, %{{.*}}: f64): + ^bb0(%i0: f64, %unused: i16, %i1: f64): + // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64) + "test.invalid"(%i0, %i1) : (f64, f64) -> () }) : () -> () // expected-remark@+1 {{op 'func.return' is not legalizable}} return diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index ce2820b80a945..a470497fdbb56 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -985,8 +985,8 @@ struct TestDropOpSignatureConversion : public ConversionPattern { }; /// This pattern simply updates the operands of the given operation. struct TestPassthroughInvalidOp : public ConversionPattern { - TestPassthroughInvalidOp(MLIRContext *ctx) - : ConversionPattern("test.invalid", 1, ctx) {} + TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter) + : ConversionPattern(converter, "test.invalid", 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { @@ -1307,19 +1307,19 @@ struct TestLegalizePatternDriver TestTypeConverter converter; mlir::RewritePatternSet patterns(&getContext()); populateWithGenerated(patterns); - patterns.add< - TestRegionRewriteBlockMovement, TestDetachedSignatureConversion, - TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock, - TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp, - TestSplitReturnType, TestChangeProducerTypeI32ToF32, - TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid, - TestUpdateConsumerType, TestNonRootReplacement, - TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite, - TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore, - TestUndoPropertiesModification, TestEraseOp, - TestRepetitive1ToNConsumer>(&getContext()); - patterns.add( - &getContext(), converter); + patterns + .add(&getContext()); + patterns.add(&getContext(), converter); patterns.add(converter, &getContext()); mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter); @@ -1755,8 +1755,9 @@ struct TestTypeConversionAnotherProducer }; struct TestReplaceWithLegalOp : public ConversionPattern { - TestReplaceWithLegalOp(MLIRContext *ctx) - : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {} + TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx) + : ConversionPattern(converter, "test.replace_with_legal_op", + /*benefit=*/1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { @@ -1878,12 +1879,12 @@ struct TestTypeConversionDriver // Initialize the set of rewrite patterns. RewritePatternSet patterns(&getContext()); - patterns.add(converter, - &getContext()); - patterns.add( - &getContext()); + patterns + .add( + converter, &getContext()); + patterns.add(&getContext()); mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter);