Skip to content

Commit d0a407c

Browse files
raikonenfnuGroverkss
authored andcommitted
Revert "[mlir][Transforms] Dialect conversion: Build unresolved materialization for replaced ops (llvm#101514)"
This reverts commit 2d50029.
1 parent 065d2d9 commit d0a407c

File tree

4 files changed

+87
-57
lines changed

4 files changed

+87
-57
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 75 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2348,12 +2348,6 @@ struct OperationConverter {
23482348
legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
23492349
ConversionPatternRewriterImpl &rewriterImpl);
23502350

2351-
/// Legalize the types of converted op results.
2352-
LogicalResult legalizeConvertedOpResultTypes(
2353-
ConversionPatternRewriter &rewriter,
2354-
ConversionPatternRewriterImpl &rewriterImpl,
2355-
DenseMap<Value, SmallVector<Value>> &inverseMapping);
2356-
23572351
/// Legalize any unresolved type materializations.
23582352
LogicalResult legalizeUnresolvedMaterializations(
23592353
ConversionPatternRewriter &rewriter,
@@ -2365,6 +2359,14 @@ struct OperationConverter {
23652359
legalizeErasedResult(Operation *op, OpResult result,
23662360
ConversionPatternRewriterImpl &rewriterImpl);
23672361

2362+
/// Legalize an operation result that was replaced with a value of a different
2363+
/// type.
2364+
LogicalResult legalizeChangedResultType(
2365+
Operation *op, OpResult result, Value newValue,
2366+
const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2367+
ConversionPatternRewriterImpl &rewriterImpl,
2368+
const DenseMap<Value, SmallVector<Value>> &inverseMapping);
2369+
23682370
/// Dialect conversion configuration.
23692371
ConversionConfig config;
23702372

@@ -2457,42 +2459,10 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
24572459
return failure();
24582460
DenseMap<Value, SmallVector<Value>> inverseMapping =
24592461
rewriterImpl.mapping.getInverse();
2460-
if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
2461-
inverseMapping)))
2462-
return failure();
24632462
if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
24642463
inverseMapping)))
24652464
return failure();
2466-
return success();
2467-
}
2468-
2469-
/// Finds a user of the given value, or of any other value that the given value
2470-
/// replaced, that was not replaced in the conversion process.
2471-
static Operation *findLiveUserOfReplaced(
2472-
Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
2473-
const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2474-
SmallVector<Value> worklist = {initialValue};
2475-
while (!worklist.empty()) {
2476-
Value value = worklist.pop_back_val();
2477-
2478-
// Walk the users of this value to see if there are any live users that
2479-
// weren't replaced during conversion.
2480-
auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
2481-
return rewriterImpl.isOpIgnored(user);
2482-
});
2483-
if (liveUserIt != value.user_end())
2484-
return *liveUserIt;
2485-
auto mapIt = inverseMapping.find(value);
2486-
if (mapIt != inverseMapping.end())
2487-
worklist.append(mapIt->second);
2488-
}
2489-
return nullptr;
2490-
}
24912465

2492-
LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
2493-
ConversionPatternRewriter &rewriter,
2494-
ConversionPatternRewriterImpl &rewriterImpl,
2495-
DenseMap<Value, SmallVector<Value>> &inverseMapping) {
24962466
// Process requested operation replacements.
24972467
for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
24982468
auto *opReplacement =
@@ -2515,21 +2485,14 @@ LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
25152485
if (result.getType() == newValue.getType())
25162486
continue;
25172487

2518-
Operation *liveUser =
2519-
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
2520-
if (!liveUser)
2521-
continue;
2522-
25232488
// Legalize this result.
2524-
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
2525-
MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
2526-
/*inputs=*/newValue, /*outputType=*/result.getType(),
2527-
opReplacement->getConverter());
2528-
rewriterImpl.mapping.map(result, castValue);
2529-
inverseMapping[castValue].push_back(result);
2489+
rewriter.setInsertionPoint(op);
2490+
if (failed(legalizeChangedResultType(
2491+
op, result, newValue, opReplacement->getConverter(), rewriter,
2492+
rewriterImpl, inverseMapping)))
2493+
return failure();
25302494
}
25312495
}
2532-
25332496
return success();
25342497
}
25352498

@@ -2539,7 +2502,7 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
25392502
// Functor used to check if all users of a value will be dead after
25402503
// conversion.
25412504
// TODO: This should probably query the inverse mapping, same as in
2542-
// `legalizeConvertedOpResultTypes`.
2505+
// `legalizeChangedResultType`.
25432506
auto findLiveUser = [&](Value val) {
25442507
auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
25452508
return rewriterImpl.isOpIgnored(user);
@@ -2869,6 +2832,67 @@ LogicalResult OperationConverter::legalizeErasedResult(
28692832
return success();
28702833
}
28712834

2835+
/// Finds a user of the given value, or of any other value that the given value
2836+
/// replaced, that was not replaced in the conversion process.
2837+
static Operation *findLiveUserOfReplaced(
2838+
Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
2839+
const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2840+
SmallVector<Value> worklist(1, initialValue);
2841+
while (!worklist.empty()) {
2842+
Value value = worklist.pop_back_val();
2843+
2844+
// Walk the users of this value to see if there are any live users that
2845+
// weren't replaced during conversion.
2846+
auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
2847+
return rewriterImpl.isOpIgnored(user);
2848+
});
2849+
if (liveUserIt != value.user_end())
2850+
return *liveUserIt;
2851+
auto mapIt = inverseMapping.find(value);
2852+
if (mapIt != inverseMapping.end())
2853+
worklist.append(mapIt->second);
2854+
}
2855+
return nullptr;
2856+
}
2857+
2858+
LogicalResult OperationConverter::legalizeChangedResultType(
2859+
Operation *op, OpResult result, Value newValue,
2860+
const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2861+
ConversionPatternRewriterImpl &rewriterImpl,
2862+
const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2863+
Operation *liveUser =
2864+
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
2865+
if (!liveUser)
2866+
return success();
2867+
2868+
// Functor used to emit a conversion error for a failed materialization.
2869+
auto emitConversionError = [&] {
2870+
InFlightDiagnostic diag = op->emitError()
2871+
<< "failed to materialize conversion for result #"
2872+
<< result.getResultNumber() << " of operation '"
2873+
<< op->getName()
2874+
<< "' that remained live after conversion";
2875+
diag.attachNote(liveUser->getLoc())
2876+
<< "see existing live user here: " << *liveUser;
2877+
return failure();
2878+
};
2879+
2880+
// If the replacement has a type converter, attempt to materialize a
2881+
// conversion back to the original type.
2882+
if (!replConverter)
2883+
return emitConversionError();
2884+
2885+
// Materialize a conversion for this live result value.
2886+
Type resultType = result.getType();
2887+
Value convertedValue = replConverter->materializeSourceConversion(
2888+
rewriter, op->getLoc(), resultType, newValue);
2889+
if (!convertedValue)
2890+
return emitConversionError();
2891+
2892+
rewriterImpl.mapping.map(result, convertedValue);
2893+
return success();
2894+
}
2895+
28722896
//===----------------------------------------------------------------------===//
28732897
// Type Conversion
28742898
//===----------------------------------------------------------------------===//

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -560,8 +560,8 @@ func.func @deinterleave(%a: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
560560
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
561561
// CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
562562
// CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
563-
// CHECK-DAG: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
564-
// CHECK-DAG: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
563+
// CHECK: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
564+
// CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
565565
// CHECK: return %[[CAST0]], %[[CAST1]]
566566
func.func @deinterleave_scalar(%a: vector<2xf32>) -> (vector<1xf32>, vector<1xf32>) {
567567
%0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>

mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,9 @@ func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset
7878
// memref.cast.
7979
func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, strided<[1], offset: ?>> {
8080
%0 = bufferization.to_tensor %m : memref<?xf32>
81-
// expected-error @+1 {{failed to legalize unresolved materialization from ('memref<?xf32>') to 'memref<?xf32, strided<[1], offset: ?>>' that remained live after conversion}}
81+
// expected-error @+1 {{failed to materialize conversion for result #0 of operation 'bufferization.to_memref' that remained live after conversion}}
8282
%1 = bufferization.to_memref %0 : memref<?xf32, strided<[1], offset: ?>>
83+
// expected-note @+1 {{see existing live user here}}
8384
return %1 : memref<?xf32, strided<[1], offset: ?>>
8485
}
8586

mlir/test/Transforms/test-legalize-type-conversion.mlir

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,20 @@ func.func @test_valid_arg_materialization(%arg0: i64) {
2020
// -----
2121

2222
func.func @test_invalid_result_materialization() {
23-
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
23+
// expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}}
2424
%result = "test.type_producer"() : () -> f16
25+
26+
// expected-note@below {{see existing live user here}}
2527
"foo.return"(%result) : (f16) -> ()
2628
}
2729

2830
// -----
2931

3032
func.func @test_invalid_result_materialization() {
31-
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
33+
// expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}}
3234
%result = "test.type_producer"() : () -> f16
35+
36+
// expected-note@below {{see existing live user here}}
3337
"foo.return"(%result) : (f16) -> ()
3438
}
3539

@@ -47,8 +51,9 @@ func.func @test_transitive_use_materialization() {
4751
// -----
4852

4953
func.func @test_transitive_use_invalid_materialization() {
50-
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
54+
// expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}}
5155
%result = "test.another_type_producer"() : () -> f16
56+
// expected-note@below {{see existing live user here}}
5257
"foo.return"(%result) : (f16) -> ()
5358
}
5459

0 commit comments

Comments
 (0)