Skip to content

Commit d2d3eb9

Browse files
[mlir][Transforms] Dialect Conversion: Do not build target mat. during 1:N replacement
fix test experiement
1 parent 345ca6a commit d2d3eb9

File tree

4 files changed

+83
-78
lines changed

4 files changed

+83
-78
lines changed

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -153,20 +153,29 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153153
type.isVarArg());
154154
});
155155

156+
// Add generic source and target materializations to handle cases where
157+
// non-LLVM types persist after an LLVM conversion.
158+
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
159+
ValueRange inputs, Location loc) {
160+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
161+
.getResult(0);
162+
});
163+
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
164+
ValueRange inputs, Location loc) {
165+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
166+
.getResult(0);
167+
});
168+
156169
// Helper function that checks if the given value range is a bare pointer.
157170
auto isBarePointer = [](ValueRange values) {
158171
return values.size() == 1 &&
159172
isa<LLVM::LLVMPointerType>(values.front().getType());
160173
};
161174

162-
// Argument materializations convert from the new block argument types
163-
// (multiple SSA values that make up a memref descriptor) back to the
164-
// original block argument type. The dialect conversion framework will then
165-
// insert a target materialization from the original block argument type to
166-
// a legal type.
167-
addArgumentMaterialization([&](OpBuilder &builder,
168-
UnrankedMemRefType resultType,
169-
ValueRange inputs, Location loc) {
175+
// MemRef descriptor elements -> UnrankedMemRefType
176+
auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
177+
UnrankedMemRefType resultType,
178+
ValueRange inputs, Location loc) {
170179
// Note: Bare pointers are not supported for unranked memrefs because a
171180
// memref descriptor cannot be built just from a bare pointer.
172181
if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
@@ -178,9 +187,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
178187
// (!llvm.struct) to the original memref type.
179188
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
180189
.getResult(0);
181-
});
182-
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
183-
ValueRange inputs, Location loc) {
190+
};
191+
192+
// MemRef descriptor elements -> MemRefType
193+
auto rankedMemRefMaterialization = [&](OpBuilder &builder,
194+
MemRefType resultType,
195+
ValueRange inputs, Location loc) {
184196
Value desc;
185197
if (isBarePointer(inputs)) {
186198
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
@@ -199,24 +211,34 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
199211
// original memref type.
200212
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
201213
.getResult(0);
202-
});
203-
// Add generic source and target materializations to handle cases where
204-
// non-LLVM types persist after an LLVM conversion.
205-
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
206-
ValueRange inputs, Location loc) {
207-
if (inputs.size() != 1)
208-
return Value();
214+
};
209215

210-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
211-
.getResult(0);
212-
});
216+
// Argument materializations convert from the new block argument types
217+
// (multiple SSA values that make up a memref descriptor) back to the
218+
// original block argument type.
219+
addArgumentMaterialization(unrakedMemRefMaterialization);
220+
addArgumentMaterialization(rankedMemRefMaterialization);
221+
addSourceMaterialization(unrakedMemRefMaterialization);
222+
addSourceMaterialization(rankedMemRefMaterialization);
223+
224+
// Bare pointer -> Packed MemRef descriptor
213225
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
214-
ValueRange inputs, Location loc) {
215-
if (inputs.size() != 1)
226+
ValueRange inputs, Location loc,
227+
Type originalType) -> Value {
228+
// The original MemRef type is required to build a MemRef descriptor
229+
// because the sizes/strides of the MemRef cannot be inferred from just the
230+
// bare pointer.
231+
if (!originalType)
216232
return Value();
217-
218-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
219-
.getResult(0);
233+
auto memrefType = dyn_cast<MemRefType>(originalType);
234+
if (!memrefType)
235+
return Value();
236+
if (resultType != convertType(memrefType))
237+
return Value();
238+
if (!isBarePointer(inputs))
239+
return Value();
240+
return MemRefDescriptor::fromStaticShape(builder, loc, *this, memrefType,
241+
inputs[0]);
220242
});
221243

222244
// Integer memory spaces map to themselves.

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -830,8 +830,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
830830
/// function will be deleted when full 1:N support has been added.
831831
///
832832
/// This function inserts an argument materialization back to the original
833-
/// type, followed by a target materialization to the legalized type (if
834-
/// applicable).
833+
/// type.
835834
void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
836835
ValueRange replacements, Value originalValue,
837836
const TypeConverter *converter);
@@ -1319,9 +1318,13 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13191318
// used as a replacement.
13201319
auto replArgs =
13211320
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1322-
insertNTo1Materialization(
1323-
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1324-
/*replacements=*/replArgs, /*outputValue=*/origArg, converter);
1321+
if (replArgs.size() == 1) {
1322+
mapping.map(origArg, replArgs.front());
1323+
} else {
1324+
insertNTo1Materialization(
1325+
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1326+
/*replacements=*/replArgs, /*outputValue=*/origArg, converter);
1327+
}
13251328
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
13261329
}
13271330

@@ -1372,27 +1375,6 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
13721375
/*inputs=*/replacements, originalType,
13731376
/*originalType=*/Type(), converter);
13741377
mapping.map(originalValue, argMat);
1375-
1376-
// Insert target materialization to the legalized type.
1377-
Type legalOutputType;
1378-
if (converter) {
1379-
legalOutputType = converter->convertType(originalType);
1380-
} else if (replacements.size() == 1) {
1381-
// When there is no type converter, assume that the replacement value
1382-
// types are legal. This is reasonable to assume because they were
1383-
// specified by the user.
1384-
// FIXME: This won't work for 1->N conversions because multiple output
1385-
// types are not supported in parts of the dialect conversion. In such a
1386-
// case, we currently use the original value type.
1387-
legalOutputType = replacements[0].getType();
1388-
}
1389-
if (legalOutputType && legalOutputType != originalType) {
1390-
Value targetMat = buildUnresolvedMaterialization(
1391-
MaterializationKind::Target, computeInsertPoint(argMat), loc,
1392-
/*inputs=*/argMat, /*outputType=*/legalOutputType,
1393-
/*originalType=*/originalType, converter);
1394-
mapping.map(argMat, targetMat);
1395-
}
13961378
}
13971379

13981380
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,10 @@ func.func @no_remap_nested() {
124124
// CHECK-NEXT: "foo.region"
125125
// expected-remark@+1 {{op 'foo.region' is not legalizable}}
126126
"foo.region"() ({
127-
// CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64):
128-
^bb0(%i0: i64, %unused: i16, %i1: i64):
129-
// CHECK-NEXT: "test.valid"{{.*}} : (i64, i64)
130-
"test.invalid"(%i0, %i1) : (i64, i64) -> ()
127+
// CHECK-NEXT: ^bb0(%{{.*}}: f64, %{{.*}}: i16, %{{.*}}: f64):
128+
^bb0(%i0: f64, %unused: i16, %i1: f64):
129+
// CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
130+
"test.invalid"(%i0, %i1) : (f64, f64) -> ()
131131
}) : () -> ()
132132
// expected-remark@+1 {{op 'func.return' is not legalizable}}
133133
return

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -979,8 +979,8 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
979979
};
980980
/// This pattern simply updates the operands of the given operation.
981981
struct TestPassthroughInvalidOp : public ConversionPattern {
982-
TestPassthroughInvalidOp(MLIRContext *ctx)
983-
: ConversionPattern("test.invalid", 1, ctx) {}
982+
TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
983+
: ConversionPattern(converter, "test.invalid", 1, ctx) {}
984984
LogicalResult
985985
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
986986
ConversionPatternRewriter &rewriter) const final {
@@ -1254,18 +1254,18 @@ struct TestLegalizePatternDriver
12541254
TestTypeConverter converter;
12551255
mlir::RewritePatternSet patterns(&getContext());
12561256
populateWithGenerated(patterns);
1257-
patterns.add<
1258-
TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1259-
TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1260-
TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
1261-
TestSplitReturnType, TestChangeProducerTypeI32ToF32,
1262-
TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
1263-
TestUpdateConsumerType, TestNonRootReplacement,
1264-
TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
1265-
TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1266-
TestUndoPropertiesModification, TestEraseOp>(&getContext());
1267-
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
1268-
&getContext(), converter);
1257+
patterns
1258+
.add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1259+
TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1260+
TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
1261+
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
1262+
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
1263+
TestNonRootReplacement, TestBoundedRecursiveRewrite,
1264+
TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
1265+
TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1266+
TestUndoPropertiesModification, TestEraseOp>(&getContext());
1267+
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
1268+
TestPassthroughInvalidOp>(&getContext(), converter);
12691269
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
12701270
converter);
12711271
mlir::populateCallOpTypeConversionPattern(patterns, converter);
@@ -1697,8 +1697,9 @@ struct TestTypeConversionAnotherProducer
16971697
};
16981698

16991699
struct TestReplaceWithLegalOp : public ConversionPattern {
1700-
TestReplaceWithLegalOp(MLIRContext *ctx)
1701-
: ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
1700+
TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
1701+
: ConversionPattern(converter, "test.replace_with_legal_op",
1702+
/*benefit=*/1, ctx) {}
17021703
LogicalResult
17031704
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
17041705
ConversionPatternRewriter &rewriter) const final {
@@ -1820,12 +1821,12 @@ struct TestTypeConversionDriver
18201821

18211822
// Initialize the set of rewrite patterns.
18221823
RewritePatternSet patterns(&getContext());
1823-
patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
1824-
TestSignatureConversionUndo,
1825-
TestTestSignatureConversionNoConverter>(converter,
1826-
&getContext());
1827-
patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
1828-
&getContext());
1824+
patterns
1825+
.add<TestTypeConsumerForward, TestTypeConversionProducer,
1826+
TestSignatureConversionUndo,
1827+
TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
1828+
converter, &getContext());
1829+
patterns.add<TestTypeConversionAnotherProducer>(&getContext());
18291830
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
18301831
converter);
18311832

0 commit comments

Comments
 (0)