Skip to content

Commit 2d9d128

Browse files
[mlir][Transforms] Dialect conversion: Build unresolved materialization for replaced ops
When inserting an argument/source/target materialization, the dialect conversion framework first inserts a "dummy" `unrealized_conversion_cast` op (during the rewrite process) and then (in the "finialize" phase) replaces these cast ops with the IR generated by the type converter callback. This is the case for all materializations, except when ops are being replaced with values that have a different type. In that case, the dialect conversion currently directly emits a source materialization. This commit changes the implementation, such that a temporary `unrealized_conversion_cast` is also inserted in this case. This commit simplifies the code base: all materializations now happen in `legalizeUnresolvedMaterialization`. This commit makes it possible to decouple source/target/argument materializations from the dialect conversion (to reduce the complexity of the code base). Such materializations can then also be optional. This will be implemented in a follow-up commit.
1 parent 82afd9d commit 2d9d128

File tree

4 files changed

+57
-87
lines changed

4 files changed

+57
-87
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 51 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -2348,6 +2348,12 @@ struct OperationConverter {
23482348
legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
23492349
ConversionPatternRewriterImpl &rewriterImpl);
23502350

2351+
/// Legalize the types of converted op results.
2352+
LogicalResult legalizeConvertedOpResultTypes(
2353+
ConversionPatternRewriter &rewriter,
2354+
ConversionPatternRewriterImpl &rewriterImpl,
2355+
DenseMap<Value, SmallVector<Value>> &inverseMapping);
2356+
23512357
/// Legalize any unresolved type materializations.
23522358
LogicalResult legalizeUnresolvedMaterializations(
23532359
ConversionPatternRewriter &rewriter,
@@ -2359,14 +2365,6 @@ struct OperationConverter {
23592365
legalizeErasedResult(Operation *op, OpResult result,
23602366
ConversionPatternRewriterImpl &rewriterImpl);
23612367

2362-
/// Legalize an operation result that was replaced with a value of a different
2363-
/// type.
2364-
LogicalResult legalizeChangedResultType(
2365-
Operation *op, OpResult result, Value newValue,
2366-
const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2367-
ConversionPatternRewriterImpl &rewriterImpl,
2368-
const DenseMap<Value, SmallVector<Value>> &inverseMapping);
2369-
23702368
/// Dialect conversion configuration.
23712369
ConversionConfig config;
23722370

@@ -2459,10 +2457,42 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
24592457
return failure();
24602458
DenseMap<Value, SmallVector<Value>> inverseMapping =
24612459
rewriterImpl.mapping.getInverse();
2460+
if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
2461+
inverseMapping)))
2462+
return failure();
24622463
if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
24632464
inverseMapping)))
24642465
return failure();
2466+
return success();
2467+
}
24652468

2469+
/// Finds a user of the given value, or of any other value that the given value
2470+
/// replaced, that was not replaced in the conversion process.
2471+
static Operation *findLiveUserOfReplaced(
2472+
Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
2473+
const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2474+
SmallVector<Value> worklist(1, initialValue);
2475+
while (!worklist.empty()) {
2476+
Value value = worklist.pop_back_val();
2477+
2478+
// Walk the users of this value to see if there are any live users that
2479+
// weren't replaced during conversion.
2480+
auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
2481+
return rewriterImpl.isOpIgnored(user);
2482+
});
2483+
if (liveUserIt != value.user_end())
2484+
return *liveUserIt;
2485+
auto mapIt = inverseMapping.find(value);
2486+
if (mapIt != inverseMapping.end())
2487+
worklist.append(mapIt->second);
2488+
}
2489+
return nullptr;
2490+
}
2491+
2492+
LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
2493+
ConversionPatternRewriter &rewriter,
2494+
ConversionPatternRewriterImpl &rewriterImpl,
2495+
DenseMap<Value, SmallVector<Value>> &inverseMapping) {
24662496
// Process requested operation replacements.
24672497
for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
24682498
auto *opReplacement =
@@ -2485,14 +2515,21 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
24852515
if (result.getType() == newValue.getType())
24862516
continue;
24872517

2518+
Operation *liveUser =
2519+
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
2520+
if (!liveUser)
2521+
continue;
2522+
24882523
// Legalize this result.
2489-
rewriter.setInsertionPoint(op);
2490-
if (failed(legalizeChangedResultType(
2491-
op, result, newValue, opReplacement->getConverter(), rewriter,
2492-
rewriterImpl, inverseMapping)))
2493-
return failure();
2524+
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
2525+
MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
2526+
/*inputs=*/newValue, /*outputType=*/result.getType(),
2527+
opReplacement->getConverter());
2528+
rewriterImpl.mapping.map(result, castValue);
2529+
inverseMapping[castValue].push_back(result);
24942530
}
24952531
}
2532+
24962533
return success();
24972534
}
24982535

@@ -2502,7 +2539,7 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
25022539
// Functor used to check if all users of a value will be dead after
25032540
// conversion.
25042541
// TODO: This should probably query the inverse mapping, same as in
2505-
// `legalizeChangedResultType`.
2542+
// `legalizeConvertedOpResultTypes`.
25062543
auto findLiveUser = [&](Value val) {
25072544
auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
25082545
return rewriterImpl.isOpIgnored(user);
@@ -2832,67 +2869,6 @@ LogicalResult OperationConverter::legalizeErasedResult(
28322869
return success();
28332870
}
28342871

2835-
/// Finds a user of the given value, or of any other value that the given value
2836-
/// replaced, that was not replaced in the conversion process.
2837-
static Operation *findLiveUserOfReplaced(
2838-
Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
2839-
const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2840-
SmallVector<Value> worklist(1, initialValue);
2841-
while (!worklist.empty()) {
2842-
Value value = worklist.pop_back_val();
2843-
2844-
// Walk the users of this value to see if there are any live users that
2845-
// weren't replaced during conversion.
2846-
auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
2847-
return rewriterImpl.isOpIgnored(user);
2848-
});
2849-
if (liveUserIt != value.user_end())
2850-
return *liveUserIt;
2851-
auto mapIt = inverseMapping.find(value);
2852-
if (mapIt != inverseMapping.end())
2853-
worklist.append(mapIt->second);
2854-
}
2855-
return nullptr;
2856-
}
2857-
2858-
LogicalResult OperationConverter::legalizeChangedResultType(
2859-
Operation *op, OpResult result, Value newValue,
2860-
const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2861-
ConversionPatternRewriterImpl &rewriterImpl,
2862-
const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2863-
Operation *liveUser =
2864-
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
2865-
if (!liveUser)
2866-
return success();
2867-
2868-
// Functor used to emit a conversion error for a failed materialization.
2869-
auto emitConversionError = [&] {
2870-
InFlightDiagnostic diag = op->emitError()
2871-
<< "failed to materialize conversion for result #"
2872-
<< result.getResultNumber() << " of operation '"
2873-
<< op->getName()
2874-
<< "' that remained live after conversion";
2875-
diag.attachNote(liveUser->getLoc())
2876-
<< "see existing live user here: " << *liveUser;
2877-
return failure();
2878-
};
2879-
2880-
// If the replacement has a type converter, attempt to materialize a
2881-
// conversion back to the original type.
2882-
if (!replConverter)
2883-
return emitConversionError();
2884-
2885-
// Materialize a conversion for this live result value.
2886-
Type resultType = result.getType();
2887-
Value convertedValue = replConverter->materializeSourceConversion(
2888-
rewriter, op->getLoc(), resultType, newValue);
2889-
if (!convertedValue)
2890-
return emitConversionError();
2891-
2892-
rewriterImpl.mapping.map(result, convertedValue);
2893-
return success();
2894-
}
2895-
28962872
//===----------------------------------------------------------------------===//
28972873
// Type Conversion
28982874
//===----------------------------------------------------------------------===//

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -560,8 +560,8 @@ func.func @deinterleave(%a: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
560560
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
561561
// CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
562562
// CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
563-
// CHECK: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
564-
// CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
563+
// CHECK-DAG: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
564+
// CHECK-DAG: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
565565
// CHECK: return %[[CAST0]], %[[CAST1]]
566566
func.func @deinterleave_scalar(%a: vector<2xf32>) -> (vector<1xf32>, vector<1xf32>) {
567567
%0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>

mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,8 @@ func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset
7878
// memref.cast.
7979
func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, strided<[1], offset: ?>> {
8080
%0 = bufferization.to_tensor %m : memref<?xf32>
81-
// expected-error @+1 {{failed to materialize conversion for result #0 of operation 'bufferization.to_memref' that remained live after conversion}}
81+
// expected-error @+1 {{failed to legalize unresolved materialization from ('memref<?xf32>') to 'memref<?xf32, strided<[1], offset: ?>>' that remained live after conversion}}
8282
%1 = bufferization.to_memref %0 : memref<?xf32, strided<[1], offset: ?>>
83-
// expected-note @+1 {{see existing live user here}}
8483
return %1 : memref<?xf32, strided<[1], offset: ?>>
8584
}
8685

mlir/test/Transforms/test-legalize-type-conversion.mlir

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,16 @@ func.func @test_valid_arg_materialization(%arg0: i64) {
2020
// -----
2121

2222
func.func @test_invalid_result_materialization() {
23-
// expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}}
23+
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
2424
%result = "test.type_producer"() : () -> f16
25-
26-
// expected-note@below {{see existing live user here}}
2725
"foo.return"(%result) : (f16) -> ()
2826
}
2927

3028
// -----
3129

3230
func.func @test_invalid_result_materialization() {
33-
// expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}}
31+
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
3432
%result = "test.type_producer"() : () -> f16
35-
36-
// expected-note@below {{see existing live user here}}
3733
"foo.return"(%result) : (f16) -> ()
3834
}
3935

@@ -51,9 +47,8 @@ func.func @test_transitive_use_materialization() {
5147
// -----
5248

5349
func.func @test_transitive_use_invalid_materialization() {
54-
// expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}}
50+
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
5551
%result = "test.another_type_producer"() : () -> f16
56-
// expected-note@below {{see existing live user here}}
5752
"foo.return"(%result) : (f16) -> ()
5853
}
5954

0 commit comments

Comments
 (0)