Skip to content

[mlir][Transforms] Dialect Conversion: No target mat. for 1:N replacement #117513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 86 additions & 44 deletions mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,70 +153,112 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
type.isVarArg());
});

// Add generic source and target materializations to handle cases where
// non-LLVM types persist after an LLVM conversion.
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs, Location loc) {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs, Location loc) {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});

// Helper function that checks if the given value range is a bare pointer.
auto isBarePointer = [](ValueRange values) {
return values.size() == 1 &&
isa<LLVM::LLVMPointerType>(values.front().getType());
};

// Argument materializations convert from the new block argument types
// (multiple SSA values that make up a memref descriptor) back to the
// original block argument type. The dialect conversion framework will then
// insert a target materialization from the original block argument type to
// a legal type.
addArgumentMaterialization([&](OpBuilder &builder,
UnrankedMemRefType resultType,
ValueRange inputs, Location loc) {
// TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
// must be passed explicitly.
auto packUnrankedMemRefDesc =
[&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
Location loc, LLVMTypeConverter &converter) -> Value {
// Note: Bare pointers are not supported for unranked memrefs because a
// memref descriptor cannot be built just from a bare pointer.
if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
return Value();
Value desc =
UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
inputs);
};

// MemRef descriptor elements -> UnrankedMemRefType
auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
UnrankedMemRefType resultType,
ValueRange inputs, Location loc) {
// An argument materialization must return a value of type
// `resultType`, so insert a cast from the memref descriptor type
// (!llvm.struct) to the original memref type.
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
.getResult(0);
});
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
ValueRange inputs, Location loc) {
Value desc;
if (isBarePointer(inputs)) {
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
inputs[0]);
} else if (TypeRange(inputs) ==
getMemRefDescriptorFields(resultType,
/*unpackAggregates=*/true)) {
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
} else {
// The inputs are neither a bare pointer nor an unpacked memref
// descriptor. This materialization function cannot be used.
Value packed =
packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this);
if (!packed)
return Value();
}
return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
.getResult(0);
};

// TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
// must be passed explicitly.
auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType,
ValueRange inputs, Location loc,
LLVMTypeConverter &converter) -> Value {
assert(resultType && "expected non-null result type");
if (isBarePointer(inputs))
return MemRefDescriptor::fromStaticShape(builder, loc, converter,
resultType, inputs[0]);
if (TypeRange(inputs) ==
converter.getMemRefDescriptorFields(resultType,
/*unpackAggregates=*/true))
return MemRefDescriptor::pack(builder, loc, converter, resultType,
inputs);
// The inputs are neither a bare pointer nor an unpacked memref descriptor.
// This materialization function cannot be used.
return Value();
};

// MemRef descriptor elements -> MemRefType
auto rankedMemRefMaterialization = [&](OpBuilder &builder,
MemRefType resultType,
ValueRange inputs, Location loc) {
// An argument materialization must return a value of type `resultType`,
// so insert a cast from the memref descriptor type (!llvm.struct) to the
// original memref type.
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
.getResult(0);
});
// Add generic source and target materializations to handle cases where
// non-LLVM types persist after an LLVM conversion.
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs, Location loc) {
if (inputs.size() != 1)
Value packed =
packRankedMemRefDesc(builder, resultType, inputs, loc, *this);
if (!packed)
return Value();

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
.getResult(0);
});
};

// Argument materializations convert from the new block argument types
// (multiple SSA values that make up a memref descriptor) back to the
// original block argument type.
addArgumentMaterialization(unrakedMemRefMaterialization);
addArgumentMaterialization(rankedMemRefMaterialization);
addSourceMaterialization(unrakedMemRefMaterialization);
addSourceMaterialization(rankedMemRefMaterialization);

// Bare pointer -> Packed MemRef descriptor
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs, Location loc) {
if (inputs.size() != 1)
ValueRange inputs, Location loc,
Type originalType) -> Value {
// The original MemRef type is required to build a MemRef descriptor
// because the sizes/strides of the MemRef cannot be inferred from just the
// bare pointer.
if (!originalType)
return Value();

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
if (resultType != convertType(originalType))
return Value();
if (auto memrefType = dyn_cast<MemRefType>(originalType))
return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
*this);
return Value();
});

// Integer memory spaces map to themselves.
Expand Down
46 changes: 14 additions & 32 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -849,8 +849,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// function will be deleted when full 1:N support has been added.
///
/// This function inserts an argument materialization back to the original
/// type, followed by a target materialization to the legalized type (if
/// applicable).
/// type.
void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
ValueRange replacements, Value originalValue,
const TypeConverter *converter);
Expand Down Expand Up @@ -1376,9 +1375,13 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// used as a replacement.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
insertNTo1Materialization(
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
/*replacements=*/replArgs, /*outputValue=*/origArg, converter);
if (replArgs.size() == 1) {
mapping.map(origArg, replArgs.front());
} else {
insertNTo1Materialization(
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
/*replacements=*/replArgs, /*outputValue=*/origArg, converter);
}
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
}

Expand Down Expand Up @@ -1437,36 +1440,12 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
// Insert argument materialization back to the original type.
Type originalType = originalValue.getType();
UnrealizedConversionCastOp argCastOp;
Value argMat = buildUnresolvedMaterialization(
buildUnresolvedMaterialization(
MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
/*inputs=*/replacements, originalType, /*originalType=*/Type(), converter,
&argCastOp);
/*inputs=*/replacements, originalType,
/*originalType=*/Type(), converter, &argCastOp);
if (argCastOp)
nTo1TempMaterializations.insert(argCastOp);

// Insert target materialization to the legalized type.
Type legalOutputType;
if (converter) {
legalOutputType = converter->convertType(originalType);
} else if (replacements.size() == 1) {
// When there is no type converter, assume that the replacement value
// types are legal. This is reasonable to assume because they were
// specified by the user.
// FIXME: This won't work for 1->N conversions because multiple output
// types are not supported in parts of the dialect conversion. In such a
// case, we currently use the original value type.
legalOutputType = replacements[0].getType();
}
if (legalOutputType && legalOutputType != originalType) {
UnrealizedConversionCastOp targetCastOp;
buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(argMat), loc,
/*valueToMap=*/argMat, /*inputs=*/argMat,
/*outputType=*/legalOutputType, /*originalType=*/originalType,
converter, &targetCastOp);
if (targetCastOp)
nTo1TempMaterializations.insert(targetCastOp);
}
}

Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
Expand Down Expand Up @@ -2864,6 +2843,9 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,

LogicalResult TypeConverter::convertType(Type t,
SmallVectorImpl<Type> &results) const {
assert(this && "expected non-null type converter");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert(this && "expected non-null type converter");

This assert only protects against a case that is already in UB as far as C++ is concerned and would e.g. get optimized out in release builds.

Copy link
Member Author

@matthias-springer matthias-springer Dec 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is not really needed, but I spent an hour debugging patterns that run without type converters (running in debug mode). (Can also be helpful for finding bugs in the dialect conversion framework.) When convertType is called on a null type converter, the stack trace looks really odd and makes it look like something is wrong with the caching logic / mutex synchronization inside of TypeConverter::convertType.

assert(t && "expected non-null type");

{
std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
std::defer_lock);
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Transforms/test-legalizer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,10 @@ func.func @no_remap_nested() {
// CHECK-NEXT: "foo.region"
// expected-remark@+1 {{op 'foo.region' is not legalizable}}
"foo.region"() ({
// CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64):
^bb0(%i0: i64, %unused: i16, %i1: i64):
// CHECK-NEXT: "test.valid"{{.*}} : (i64, i64)
"test.invalid"(%i0, %i1) : (i64, i64) -> ()
// CHECK-NEXT: ^bb0(%{{.*}}: f64, %{{.*}}: i16, %{{.*}}: f64):
^bb0(%i0: f64, %unused: i16, %i1: f64):
// CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
"test.invalid"(%i0, %i1) : (f64, f64) -> ()
}) : () -> ()
// expected-remark@+1 {{op 'func.return' is not legalizable}}
return
Expand Down
47 changes: 24 additions & 23 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -985,8 +985,8 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
};
/// This pattern simply updates the operands of the given operation.
struct TestPassthroughInvalidOp : public ConversionPattern {
TestPassthroughInvalidOp(MLIRContext *ctx)
: ConversionPattern("test.invalid", 1, ctx) {}
TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
: ConversionPattern(converter, "test.invalid", 1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
Expand Down Expand Up @@ -1307,19 +1307,19 @@ struct TestLegalizePatternDriver
TestTypeConverter converter;
mlir::RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns);
patterns.add<
TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
TestSplitReturnType, TestChangeProducerTypeI32ToF32,
TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
TestUpdateConsumerType, TestNonRootReplacement,
TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
TestUndoPropertiesModification, TestEraseOp,
TestRepetitive1ToNConsumer>(&getContext());
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
&getContext(), converter);
patterns
.add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
TestNonRootReplacement, TestBoundedRecursiveRewrite,
TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
TestCreateUnregisteredOp, TestUndoMoveOpBefore,
TestUndoPropertiesModification, TestEraseOp,
TestRepetitive1ToNConsumer>(&getContext());
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
TestPassthroughInvalidOp>(&getContext(), converter);
patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
Expand Down Expand Up @@ -1755,8 +1755,9 @@ struct TestTypeConversionAnotherProducer
};

struct TestReplaceWithLegalOp : public ConversionPattern {
TestReplaceWithLegalOp(MLIRContext *ctx)
: ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
: ConversionPattern(converter, "test.replace_with_legal_op",
/*benefit=*/1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Expand Down Expand Up @@ -1878,12 +1879,12 @@ struct TestTypeConversionDriver

// Initialize the set of rewrite patterns.
RewritePatternSet patterns(&getContext());
patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
TestSignatureConversionUndo,
TestTestSignatureConversionNoConverter>(converter,
&getContext());
patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
&getContext());
patterns
.add<TestTypeConsumerForward, TestTypeConversionProducer,
TestSignatureConversionUndo,
TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
converter, &getContext());
patterns.add<TestTypeConversionAnotherProducer>(&getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);

Expand Down
Loading