diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 8f9b21b7ee1e5..11e593cebc09b 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2348,6 +2348,12 @@ struct OperationConverter { legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl); + /// Legalize the types of converted op results. + LogicalResult legalizeConvertedOpResultTypes( + ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &rewriterImpl, + DenseMap> &inverseMapping); + /// Legalize any unresolved type materializations. LogicalResult legalizeUnresolvedMaterializations( ConversionPatternRewriter &rewriter, @@ -2359,14 +2365,6 @@ struct OperationConverter { legalizeErasedResult(Operation *op, OpResult result, ConversionPatternRewriterImpl &rewriterImpl); - /// Legalize an operation result that was replaced with a value of a different - /// type. - LogicalResult legalizeChangedResultType( - Operation *op, OpResult result, Value newValue, - const TypeConverter *replConverter, ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &rewriterImpl, - const DenseMap> &inverseMapping); - /// Dialect conversion configuration. ConversionConfig config; @@ -2459,10 +2457,42 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) { return failure(); DenseMap> inverseMapping = rewriterImpl.mapping.getInverse(); + if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl, + inverseMapping))) + return failure(); if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl, inverseMapping))) return failure(); + return success(); +} + +/// Finds a user of the given value, or of any other value that the given value +/// replaced, that was not replaced in the conversion process. +static Operation *findLiveUserOfReplaced( + Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, + const DenseMap> &inverseMapping) { + SmallVector worklist = {initialValue}; + while (!worklist.empty()) { + Value value = worklist.pop_back_val(); + + // Walk the users of this value to see if there are any live users that + // weren't replaced during conversion. + auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) { + return rewriterImpl.isOpIgnored(user); + }); + if (liveUserIt != value.user_end()) + return *liveUserIt; + auto mapIt = inverseMapping.find(value); + if (mapIt != inverseMapping.end()) + worklist.append(mapIt->second); + } + return nullptr; +} +LogicalResult OperationConverter::legalizeConvertedOpResultTypes( + ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &rewriterImpl, + DenseMap> &inverseMapping) { // Process requested operation replacements. for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) { auto *opReplacement = @@ -2485,14 +2515,21 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) { if (result.getType() == newValue.getType()) continue; + Operation *liveUser = + findLiveUserOfReplaced(result, rewriterImpl, inverseMapping); + if (!liveUser) + continue; + // Legalize this result. - rewriter.setInsertionPoint(op); - if (failed(legalizeChangedResultType( - op, result, newValue, opReplacement->getConverter(), rewriter, - rewriterImpl, inverseMapping))) - return failure(); + Value castValue = rewriterImpl.buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(result), op->getLoc(), + /*inputs=*/newValue, /*outputType=*/result.getType(), + opReplacement->getConverter()); + rewriterImpl.mapping.map(result, castValue); + inverseMapping[castValue].push_back(result); } } + return success(); } @@ -2502,7 +2539,7 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes( // Functor used to check if all users of a value will be dead after // conversion. // TODO: This should probably query the inverse mapping, same as in - // `legalizeChangedResultType`. + // `legalizeConvertedOpResultTypes`. auto findLiveUser = [&](Value val) { auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) { return rewriterImpl.isOpIgnored(user); @@ -2832,67 +2869,6 @@ LogicalResult OperationConverter::legalizeErasedResult( return success(); } -/// Finds a user of the given value, or of any other value that the given value -/// replaced, that was not replaced in the conversion process. -static Operation *findLiveUserOfReplaced( - Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, - const DenseMap> &inverseMapping) { - SmallVector worklist(1, initialValue); - while (!worklist.empty()) { - Value value = worklist.pop_back_val(); - - // Walk the users of this value to see if there are any live users that - // weren't replaced during conversion. - auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) { - return rewriterImpl.isOpIgnored(user); - }); - if (liveUserIt != value.user_end()) - return *liveUserIt; - auto mapIt = inverseMapping.find(value); - if (mapIt != inverseMapping.end()) - worklist.append(mapIt->second); - } - return nullptr; -} - -LogicalResult OperationConverter::legalizeChangedResultType( - Operation *op, OpResult result, Value newValue, - const TypeConverter *replConverter, ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &rewriterImpl, - const DenseMap> &inverseMapping) { - Operation *liveUser = - findLiveUserOfReplaced(result, rewriterImpl, inverseMapping); - if (!liveUser) - return success(); - - // Functor used to emit a conversion error for a failed materialization. - auto emitConversionError = [&] { - InFlightDiagnostic diag = op->emitError() - << "failed to materialize conversion for result #" - << result.getResultNumber() << " of operation '" - << op->getName() - << "' that remained live after conversion"; - diag.attachNote(liveUser->getLoc()) - << "see existing live user here: " << *liveUser; - return failure(); - }; - - // If the replacement has a type converter, attempt to materialize a - // conversion back to the original type. - if (!replConverter) - return emitConversionError(); - - // Materialize a conversion for this live result value. - Type resultType = result.getType(); - Value convertedValue = replConverter->materializeSourceConversion( - rewriter, op->getLoc(), resultType, newValue); - if (!convertedValue) - return emitConversionError(); - - rewriterImpl.mapping.map(result, convertedValue); - return success(); -} - //===----------------------------------------------------------------------===// // Type Conversion //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index dd0ed77470a25..d8570bdaf4247 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -560,8 +560,8 @@ func.func @deinterleave(%a: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) { // CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>) // CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32> // CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32> -// CHECK: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32> -// CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32> +// CHECK-DAG: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32> +// CHECK-DAG: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32> // CHECK: return %[[CAST0]], %[[CAST1]] func.func @deinterleave_scalar(%a: vector<2xf32>) -> (vector<1xf32>, vector<1xf32>) { %0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32> diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir index ff94c1b331d92..a192434c5accf 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir @@ -78,9 +78,8 @@ func.func @static_layout_to_no_layout_cast(%m: memref) -> memref> { %0 = bufferization.to_tensor %m : memref - // expected-error @+1 {{failed to materialize conversion for result #0 of operation 'bufferization.to_memref' that remained live after conversion}} + // expected-error @+1 {{failed to legalize unresolved materialization from ('memref') to 'memref>' that remained live after conversion}} %1 = bufferization.to_memref %0 : memref> - // expected-note @+1 {{see existing live user here}} return %1 : memref> } diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir index d0563fed8e5d9..252b990210a18 100644 --- a/mlir/test/Transforms/test-legalize-type-conversion.mlir +++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir @@ -20,20 +20,16 @@ func.func @test_valid_arg_materialization(%arg0: i64) { // ----- func.func @test_invalid_result_materialization() { - // expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}} + // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}} %result = "test.type_producer"() : () -> f16 - - // expected-note@below {{see existing live user here}} "foo.return"(%result) : (f16) -> () } // ----- func.func @test_invalid_result_materialization() { - // expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}} + // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}} %result = "test.type_producer"() : () -> f16 - - // expected-note@below {{see existing live user here}} "foo.return"(%result) : (f16) -> () } @@ -51,9 +47,8 @@ func.func @test_transitive_use_materialization() { // ----- func.func @test_transitive_use_invalid_materialization() { - // expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}} + // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}} %result = "test.another_type_producer"() : () -> f16 - // expected-note@below {{see existing live user here}} "foo.return"(%result) : (f16) -> () }