Skip to content

[mlir][Transforms] Merge 1:1 and 1:N type converters #113032

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 25, 2024

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Oct 19, 2024

The 1:N type converter derived from the 1:1 type converter and extends it with 1:N target materializations. This commit merges the two type converters and stores 1:N target materializations in the 1:1 type converter. This is in preparation of merging the 1:1 and 1:N dialect conversion infrastructures.

1:1 target materializations (producing a single Value) will remain valid. An additional API is added to the type converter to register 1:N target materializations (producing a SmallVector<Value>). Internally, all target materializations are stored as 1:N materializations.

The 1:N type converter is removed.

Note for LLVM integration: If you are using the OneToNTypeConverter, simply switch all occurrences to TypeConverter.

@llvmbot
Copy link
Member

llvmbot commented Oct 19, 2024

@llvm/pr-subscribers-mlir-sme

@llvm/pr-subscribers-mlir-sparse

Author: Matthias Springer (matthias-springer)

Changes

The 1:N type converter derived from the 1:1 type converter and extends it with 1:N target materializations. This commit merges the two type converters and stores 1:N target materializations in the 1:1 type converter. This is in preparation of merging the 1:1 and 1:N dialect conversion infrastructures.

1:1 target materializations (producing a single Value) will remain valid. An additional API is added to the type converter to register 1:N target materializations (producing a SmallVector&lt;Value&gt;). Internally, all target materializations are stored as 1:N materializations.

The 1:N type converter is removed.

Note for LLVM integration: If you are using the OneToNTypeConverter, simply switch all occurrences to TypeConverter.

Depends on #113031.


Full diff: https://github.com/llvm/llvm-project/pull/113032.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h (+1-1)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+41-13)
  • (modified) mlir/include/mlir/Transforms/OneToNTypeConversion.h (+1-44)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+1-1)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+20-4)
  • (modified) mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp (+14-30)
  • (modified) mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp (+14-4)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 6ccbc40bdd6034..2e9c297f20182a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -150,7 +150,7 @@ std::unique_ptr<Pass> createLowerForeachToSCFPass();
 //===----------------------------------------------------------------------===//
 
 /// Type converter for iter_space and iterator.
-struct SparseIterationTypeConverter : public OneToNTypeConverter {
+struct SparseIterationTypeConverter : public TypeConverter {
   SparseIterationTypeConverter();
 };
 
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5ff36160dd6162..eb7da67c1bb995 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -173,7 +173,9 @@ class TypeConverter {
   /// conversion has finished.
   ///
   /// Note: Target materializations may optionally accept an additional Type
-  /// parameter, which is the original type of the SSA value.
+  /// parameter, which is the original type of the SSA value. Furthermore `T`
+  /// can be a TypeRange; in that case, the function must return a
+  /// SmallVector<Value>.
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
@@ -210,6 +212,9 @@ class TypeConverter {
   /// will be invoked with: outputType = "t3", inputs = "v2",
   // originalType = "t1". Note  that the original type "t1" cannot be recovered
   /// from just "t3" and "v2"; that's why the originalType parameter exists.
+  ///
+  /// Note: During a 1:N conversion, the result types can be a TypeRange. In
+  /// that case the materialization produces a SmallVector<Value>.
   template <typename FnT, typename T = typename llvm::function_traits<
                               std::decay_t<FnT>>::template arg_t<1>>
   void addTargetMaterialization(FnT &&callback) {
@@ -316,6 +321,11 @@ class TypeConverter {
   Value materializeTargetConversion(OpBuilder &builder, Location loc,
                                     Type resultType, ValueRange inputs,
                                     Type originalType = {}) const;
+  SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
+                                                 Location loc,
+                                                 TypeRange resultType,
+                                                 ValueRange inputs,
+                                                 Type originalType = {}) const;
 
   /// Convert an attribute present `attr` from within the type `type` using
   /// the registered conversion functions. If no applicable conversion has been
@@ -341,8 +351,8 @@ class TypeConverter {
   /// The signature of the callback used to materialize a target conversion.
   ///
   /// Arguments: builder, result type, inputs, location, original type
-  using TargetMaterializationCallbackFn =
-      std::function<Value(OpBuilder &, Type, ValueRange, Location, Type)>;
+  using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
+      OpBuilder &, TypeRange, ValueRange, Location, Type)>;
 
   /// The signature of the callback used to convert a type attribute.
   using TypeAttributeConversionCallbackFn =
@@ -409,22 +419,40 @@ class TypeConverter {
   /// callback.
   ///
   /// With callback of form:
-  /// `Value(OpBuilder &, T, ValueRange, Location, Type)`
+  /// - Value(OpBuilder &, T, ValueRange, Location, Type)
+  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
   template <typename T, typename FnT>
   std::enable_if_t<
       std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
       TargetMaterializationCallbackFn>
   wrapTargetMaterialization(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
-               OpBuilder &builder, Type resultType, ValueRange inputs,
-               Location loc, Type originalType) -> Value {
-      if (T derivedType = dyn_cast<T>(resultType))
-        return callback(builder, derivedType, inputs, loc, originalType);
-      return Value();
+               OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+               Location loc, Type originalType) -> SmallVector<Value> {
+      SmallVector<Value> result;
+      if constexpr (std::is_same<T, TypeRange>::value) {
+        // This is a 1:N target materialization. Return the produces values
+        // directly.
+        result = callback(builder, resultTypes, inputs, loc, originalType);
+      } else {
+        // This is a 1:1 target materialization. Invoke it only if the result
+        // type class of the callback matches the requested result type.
+        if (T derivedType = dyn_cast<T>(resultTypes.front())) {
+          // 1:1 materializations produce single values, but we store 1:N
+          // target materialization functions in the type converter. Wrap the
+          // result value in a SmallVector<Value>.
+          std::optional<Value> val =
+              callback(builder, derivedType, inputs, loc, originalType);
+          if (val.has_value() && *val)
+            result.push_back(*val);
+        }
+      }
+      return result;
     };
   }
   /// With callback of form:
-  /// `Value(OpBuilder &, T, ValueRange, Location)`
+  /// - Value(OpBuilder &, T, ValueRange, Location)
+  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
   template <typename T, typename FnT>
   std::enable_if_t<
       std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
@@ -432,9 +460,9 @@ class TypeConverter {
   wrapTargetMaterialization(FnT &&callback) const {
     return wrapTargetMaterialization<T>(
         [callback = std::forward<FnT>(callback)](
-            OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
-            Type originalType) -> Value {
-          return callback(builder, resultType, inputs, loc);
+            OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
+            Type originalType) {
+          return callback(builder, resultTypes, inputs, loc);
         });
   }
 
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..7b4dd65cbff7b2 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -33,49 +33,6 @@
 
 namespace mlir {
 
-/// Extends `TypeConverter` with 1:N target materializations. Such
-/// materializations have to provide the "reverse" of 1:N type conversions,
-/// i.e., they need to materialize N values with target types into one value
-/// with a source type (which isn't possible in the base class currently).
-class OneToNTypeConverter : public TypeConverter {
-public:
-  /// Callback that expresses user-provided materialization logic from the given
-  /// value to N values of the given types. This is useful for expressing target
-  /// materializations for 1:N type conversions, which materialize one value in
-  /// a source type as N values in target types.
-  using OneToNMaterializationCallbackFn =
-      std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
-                                                      Value, Location)>;
-
-  /// Creates the mapping of the given range of original types to target types
-  /// of the conversion and stores that mapping in the given (signature)
-  /// conversion. This function simply calls
-  /// `TypeConverter::convertSignatureArgs` and exists here with a different
-  /// name to reflect the broader semantic.
-  LogicalResult computeTypeMapping(TypeRange types,
-                                   SignatureConversion &result) const {
-    return convertSignatureArgs(types, result);
-  }
-
-  /// Applies one of the user-provided 1:N target materializations. If several
-  /// exists, they are tried out in the reverse order in which they have been
-  /// added until the first one succeeds. If none succeeds, the functions
-  /// returns `std::nullopt`.
-  std::optional<SmallVector<Value>>
-  materializeTargetConversion(OpBuilder &builder, Location loc,
-                              TypeRange resultTypes, Value input) const;
-
-  /// Adds a 1:N target materialization to the converter. Such materializations
-  /// build IR that converts N values with target types into 1 value of the
-  /// source type.
-  void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) {
-    oneToNTargetMaterializations.emplace_back(std::move(callback));
-  }
-
-private:
-  SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
-};
-
 /// Stores a 1:N mapping of types and provides several useful accessors. This
 /// class extends `SignatureConversion`, which already supports 1:N type
 /// mappings but lacks some accessors into the mapping as well as access to the
@@ -295,7 +252,7 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
 /// not fail if some ops or types remain unconverted (i.e., the conversion is
 /// only "partial").
 LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
                              const FrozenRewritePatternSet &patterns);
 
 /// Add a pattern to the given pattern list to convert the signature of a
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 4968c4fc463d04..e908a536e6fb27 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -921,7 +921,7 @@ struct VectorLegalizationPass
     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
   void runOnOperation() override {
     auto *context = &getContext();
-    OneToNTypeConverter converter;
+    TypeConverter converter;
     RewritePatternSet patterns(context);
     converter.addConversion([](Type type) { return type; });
     converter.addConversion(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3cfcaa965f3546..bf969e74e8bfe0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2831,11 +2831,27 @@ Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
                                                  Location loc, Type resultType,
                                                  ValueRange inputs,
                                                  Type originalType) const {
+  SmallVector<Value> result = materializeTargetConversion(
+      builder, loc, TypeRange(resultType), inputs, originalType);
+  if (result.empty())
+    return nullptr;
+  assert(result.size() == 1 && "requested 1:1 materialization, but callback "
+                               "produced 1:N materialization");
+  return result.front();
+}
+
+SmallVector<Value> TypeConverter::materializeTargetConversion(
+    OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
+    Type originalType) const {
   for (const TargetMaterializationCallbackFn &fn :
-       llvm::reverse(targetMaterializations))
-    if (Value result = fn(builder, resultType, inputs, loc, originalType))
-      return result;
-  return nullptr;
+       llvm::reverse(targetMaterializations)) {
+    SmallVector<Value> result =
+        fn(builder, resultTypes, inputs, loc, originalType);
+    if (result.empty())
+      continue;
+    return result;
+  }
+  return {};
 }
 
 std::optional<TypeConverter::SignatureConversion>
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
index 19e29d48623e04..c208716891ef1f 100644
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -17,20 +17,6 @@
 using namespace llvm;
 using namespace mlir;
 
-std::optional<SmallVector<Value>>
-OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
-                                                 Location loc,
-                                                 TypeRange resultTypes,
-                                                 Value input) const {
-  for (const OneToNMaterializationCallbackFn &fn :
-       llvm::reverse(oneToNTargetMaterializations)) {
-    if (std::optional<SmallVector<Value>> result =
-            fn(builder, resultTypes, input, loc))
-      return *result;
-  }
-  return std::nullopt;
-}
-
 TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
   TypeRange convertedTypes = getConvertedTypes();
   if (auto mapping = getInputMapping(originalTypeNo))
@@ -268,20 +254,20 @@ Block *OneToNPatternRewriter::applySignatureConversion(
 LogicalResult
 OneToNConversionPattern::matchAndRewrite(Operation *op,
                                          PatternRewriter &rewriter) const {
-  auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+  auto *typeConverter = getTypeConverter();
 
   // Construct conversion mapping for results.
   Operation::result_type_range originalResultTypes = op->getResultTypes();
   OneToNTypeMapping resultMapping(originalResultTypes);
-  if (failed(typeConverter->computeTypeMapping(originalResultTypes,
-                                               resultMapping)))
+  if (failed(typeConverter->convertSignatureArgs(originalResultTypes,
+                                                 resultMapping)))
     return failure();
 
   // Construct conversion mapping for operands.
   Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
   OneToNTypeMapping operandMapping(originalOperandTypes);
-  if (failed(typeConverter->computeTypeMapping(originalOperandTypes,
-                                               operandMapping)))
+  if (failed(typeConverter->convertSignatureArgs(originalOperandTypes,
+                                                 operandMapping)))
     return failure();
 
   // Cast operands to target types.
@@ -318,7 +304,7 @@ namespace mlir {
 // inserted by this pass are annotated with a string attribute that also
 // documents which kind of the cast (source, argument, or target).
 LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
                              const FrozenRewritePatternSet &patterns) {
 #ifndef NDEBUG
   // Remember existing unrealized casts. This data structure is only used in
@@ -370,15 +356,13 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
       // Target materialization.
       assert(!areOperandTypesLegal && areResultsTypesLegal &&
              operands.size() == 1 && "found unexpected target cast");
-      std::optional<SmallVector<Value>> maybeResults =
-          typeConverter.materializeTargetConversion(
-              rewriter, castOp->getLoc(), resultTypes, operands.front());
-      if (!maybeResults) {
+      materializedResults = typeConverter.materializeTargetConversion(
+          rewriter, castOp->getLoc(), resultTypes, operands.front());
+      if (materializedResults.empty()) {
         emitError(castOp->getLoc())
             << "failed to create target materialization";
         return failure();
       }
-      materializedResults = maybeResults.value();
     } else {
       // Source and argument materializations.
       assert(areOperandTypesLegal && !areResultsTypesLegal &&
@@ -427,18 +411,18 @@ class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern {
                                 const OneToNTypeMapping &resultMapping,
                                 ValueRange convertedOperands) const override {
     auto funcOp = cast<FunctionOpInterface>(op);
-    auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+    auto *typeConverter = getTypeConverter();
 
     // Construct mapping for function arguments.
     OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes());
-    if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(),
-                                                 argumentMapping)))
+    if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(),
+                                                   argumentMapping)))
       return failure();
 
     // Construct mapping for function results.
     OneToNTypeMapping funcResultMapping(funcOp.getResultTypes());
-    if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(),
-                                                 funcResultMapping)))
+    if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(),
+                                                   funcResultMapping)))
       return failure();
 
     // Nothing to do if the op doesn't have any non-identity conversions for its
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
index 5c03ac12d1e58c..b18dfd8bb22cb1 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -147,9 +147,14 @@ populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter,
 ///
 /// This function has been copied (with small adaptions) from
 /// TestDecomposeCallGraphTypes.cpp.
-static std::optional<SmallVector<Value>>
-buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
-                        Location loc) {
+static SmallVector<Value> buildGetTupleElementOps(OpBuilder &builder,
+                                                  TypeRange resultTypes,
+                                                  ValueRange inputs,
+                                                  Location loc) {
+  if (inputs.size() != 1)
+    return {};
+  Value input = inputs.front();
+
   TupleType inputType = dyn_cast<TupleType>(input.getType());
   if (!inputType)
     return {};
@@ -222,7 +227,7 @@ void TestOneToNTypeConversionPass::runOnOperation() {
   auto *context = &getContext();
 
   // Assemble type converter.
-  OneToNTypeConverter typeConverter;
+  TypeConverter typeConverter;
 
   typeConverter.addConversion([](Type type) { return type; });
   typeConverter.addConversion(
@@ -234,6 +239,11 @@ void TestOneToNTypeConversionPass::runOnOperation() {
   typeConverter.addArgumentMaterialization(buildMakeTupleOp);
   typeConverter.addSourceMaterialization(buildMakeTupleOp);
   typeConverter.addTargetMaterialization(buildGetTupleElementOps);
+  // Test the other target materialization variant that takes the original type
+  // as additional argument. This materialization function always fails.
+  typeConverter.addTargetMaterialization(
+      [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+         Location loc, Type originalType) -> SmallVector<Value> { return {}; });
 
   // Assemble patterns.
   RewritePatternSet patterns(context);

@llvmbot
Copy link
Member

llvmbot commented Oct 19, 2024

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

The 1:N type converter derived from the 1:1 type converter and extends it with 1:N target materializations. This commit merges the two type converters and stores 1:N target materializations in the 1:1 type converter. This is in preparation of merging the 1:1 and 1:N dialect conversion infrastructures.

1:1 target materializations (producing a single Value) will remain valid. An additional API is added to the type converter to register 1:N target materializations (producing a SmallVector&lt;Value&gt;). Internally, all target materializations are stored as 1:N materializations.

The 1:N type converter is removed.

Note for LLVM integration: If you are using the OneToNTypeConverter, simply switch all occurrences to TypeConverter.

Depends on #113031.


Full diff: https://github.com/llvm/llvm-project/pull/113032.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h (+1-1)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+41-13)
  • (modified) mlir/include/mlir/Transforms/OneToNTypeConversion.h (+1-44)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+1-1)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+20-4)
  • (modified) mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp (+14-30)
  • (modified) mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp (+14-4)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 6ccbc40bdd6034..2e9c297f20182a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -150,7 +150,7 @@ std::unique_ptr<Pass> createLowerForeachToSCFPass();
 //===----------------------------------------------------------------------===//
 
 /// Type converter for iter_space and iterator.
-struct SparseIterationTypeConverter : public OneToNTypeConverter {
+struct SparseIterationTypeConverter : public TypeConverter {
   SparseIterationTypeConverter();
 };
 
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5ff36160dd6162..eb7da67c1bb995 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -173,7 +173,9 @@ class TypeConverter {
   /// conversion has finished.
   ///
   /// Note: Target materializations may optionally accept an additional Type
-  /// parameter, which is the original type of the SSA value.
+  /// parameter, which is the original type of the SSA value. Furthermore `T`
+  /// can be a TypeRange; in that case, the function must return a
+  /// SmallVector<Value>.
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
@@ -210,6 +212,9 @@ class TypeConverter {
   /// will be invoked with: outputType = "t3", inputs = "v2",
   // originalType = "t1". Note  that the original type "t1" cannot be recovered
   /// from just "t3" and "v2"; that's why the originalType parameter exists.
+  ///
+  /// Note: During a 1:N conversion, the result types can be a TypeRange. In
+  /// that case the materialization produces a SmallVector<Value>.
   template <typename FnT, typename T = typename llvm::function_traits<
                               std::decay_t<FnT>>::template arg_t<1>>
   void addTargetMaterialization(FnT &&callback) {
@@ -316,6 +321,11 @@ class TypeConverter {
   Value materializeTargetConversion(OpBuilder &builder, Location loc,
                                     Type resultType, ValueRange inputs,
                                     Type originalType = {}) const;
+  SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
+                                                 Location loc,
+                                                 TypeRange resultType,
+                                                 ValueRange inputs,
+                                                 Type originalType = {}) const;
 
   /// Convert an attribute present `attr` from within the type `type` using
   /// the registered conversion functions. If no applicable conversion has been
@@ -341,8 +351,8 @@ class TypeConverter {
   /// The signature of the callback used to materialize a target conversion.
   ///
   /// Arguments: builder, result type, inputs, location, original type
-  using TargetMaterializationCallbackFn =
-      std::function<Value(OpBuilder &, Type, ValueRange, Location, Type)>;
+  using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
+      OpBuilder &, TypeRange, ValueRange, Location, Type)>;
 
   /// The signature of the callback used to convert a type attribute.
   using TypeAttributeConversionCallbackFn =
@@ -409,22 +419,40 @@ class TypeConverter {
   /// callback.
   ///
   /// With callback of form:
-  /// `Value(OpBuilder &, T, ValueRange, Location, Type)`
+  /// - Value(OpBuilder &, T, ValueRange, Location, Type)
+  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
   template <typename T, typename FnT>
   std::enable_if_t<
       std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
       TargetMaterializationCallbackFn>
   wrapTargetMaterialization(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
-               OpBuilder &builder, Type resultType, ValueRange inputs,
-               Location loc, Type originalType) -> Value {
-      if (T derivedType = dyn_cast<T>(resultType))
-        return callback(builder, derivedType, inputs, loc, originalType);
-      return Value();
+               OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+               Location loc, Type originalType) -> SmallVector<Value> {
+      SmallVector<Value> result;
+      if constexpr (std::is_same<T, TypeRange>::value) {
+        // This is a 1:N target materialization. Return the produces values
+        // directly.
+        result = callback(builder, resultTypes, inputs, loc, originalType);
+      } else {
+        // This is a 1:1 target materialization. Invoke it only if the result
+        // type class of the callback matches the requested result type.
+        if (T derivedType = dyn_cast<T>(resultTypes.front())) {
+          // 1:1 materializations produce single values, but we store 1:N
+          // target materialization functions in the type converter. Wrap the
+          // result value in a SmallVector<Value>.
+          std::optional<Value> val =
+              callback(builder, derivedType, inputs, loc, originalType);
+          if (val.has_value() && *val)
+            result.push_back(*val);
+        }
+      }
+      return result;
     };
   }
   /// With callback of form:
-  /// `Value(OpBuilder &, T, ValueRange, Location)`
+  /// - Value(OpBuilder &, T, ValueRange, Location)
+  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
   template <typename T, typename FnT>
   std::enable_if_t<
       std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
@@ -432,9 +460,9 @@ class TypeConverter {
   wrapTargetMaterialization(FnT &&callback) const {
     return wrapTargetMaterialization<T>(
         [callback = std::forward<FnT>(callback)](
-            OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
-            Type originalType) -> Value {
-          return callback(builder, resultType, inputs, loc);
+            OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
+            Type originalType) {
+          return callback(builder, resultTypes, inputs, loc);
         });
   }
 
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..7b4dd65cbff7b2 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -33,49 +33,6 @@
 
 namespace mlir {
 
-/// Extends `TypeConverter` with 1:N target materializations. Such
-/// materializations have to provide the "reverse" of 1:N type conversions,
-/// i.e., they need to materialize N values with target types into one value
-/// with a source type (which isn't possible in the base class currently).
-class OneToNTypeConverter : public TypeConverter {
-public:
-  /// Callback that expresses user-provided materialization logic from the given
-  /// value to N values of the given types. This is useful for expressing target
-  /// materializations for 1:N type conversions, which materialize one value in
-  /// a source type as N values in target types.
-  using OneToNMaterializationCallbackFn =
-      std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
-                                                      Value, Location)>;
-
-  /// Creates the mapping of the given range of original types to target types
-  /// of the conversion and stores that mapping in the given (signature)
-  /// conversion. This function simply calls
-  /// `TypeConverter::convertSignatureArgs` and exists here with a different
-  /// name to reflect the broader semantic.
-  LogicalResult computeTypeMapping(TypeRange types,
-                                   SignatureConversion &result) const {
-    return convertSignatureArgs(types, result);
-  }
-
-  /// Applies one of the user-provided 1:N target materializations. If several
-  /// exists, they are tried out in the reverse order in which they have been
-  /// added until the first one succeeds. If none succeeds, the functions
-  /// returns `std::nullopt`.
-  std::optional<SmallVector<Value>>
-  materializeTargetConversion(OpBuilder &builder, Location loc,
-                              TypeRange resultTypes, Value input) const;
-
-  /// Adds a 1:N target materialization to the converter. Such materializations
-  /// build IR that converts N values with target types into 1 value of the
-  /// source type.
-  void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) {
-    oneToNTargetMaterializations.emplace_back(std::move(callback));
-  }
-
-private:
-  SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
-};
-
 /// Stores a 1:N mapping of types and provides several useful accessors. This
 /// class extends `SignatureConversion`, which already supports 1:N type
 /// mappings but lacks some accessors into the mapping as well as access to the
@@ -295,7 +252,7 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
 /// not fail if some ops or types remain unconverted (i.e., the conversion is
 /// only "partial").
 LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
                              const FrozenRewritePatternSet &patterns);
 
 /// Add a pattern to the given pattern list to convert the signature of a
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 4968c4fc463d04..e908a536e6fb27 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -921,7 +921,7 @@ struct VectorLegalizationPass
     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
   void runOnOperation() override {
     auto *context = &getContext();
-    OneToNTypeConverter converter;
+    TypeConverter converter;
     RewritePatternSet patterns(context);
     converter.addConversion([](Type type) { return type; });
     converter.addConversion(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3cfcaa965f3546..bf969e74e8bfe0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2831,11 +2831,27 @@ Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
                                                  Location loc, Type resultType,
                                                  ValueRange inputs,
                                                  Type originalType) const {
+  SmallVector<Value> result = materializeTargetConversion(
+      builder, loc, TypeRange(resultType), inputs, originalType);
+  if (result.empty())
+    return nullptr;
+  assert(result.size() == 1 && "requested 1:1 materialization, but callback "
+                               "produced 1:N materialization");
+  return result.front();
+}
+
+SmallVector<Value> TypeConverter::materializeTargetConversion(
+    OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
+    Type originalType) const {
   for (const TargetMaterializationCallbackFn &fn :
-       llvm::reverse(targetMaterializations))
-    if (Value result = fn(builder, resultType, inputs, loc, originalType))
-      return result;
-  return nullptr;
+       llvm::reverse(targetMaterializations)) {
+    SmallVector<Value> result =
+        fn(builder, resultTypes, inputs, loc, originalType);
+    if (result.empty())
+      continue;
+    return result;
+  }
+  return {};
 }
 
 std::optional<TypeConverter::SignatureConversion>
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
index 19e29d48623e04..c208716891ef1f 100644
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -17,20 +17,6 @@
 using namespace llvm;
 using namespace mlir;
 
-std::optional<SmallVector<Value>>
-OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
-                                                 Location loc,
-                                                 TypeRange resultTypes,
-                                                 Value input) const {
-  for (const OneToNMaterializationCallbackFn &fn :
-       llvm::reverse(oneToNTargetMaterializations)) {
-    if (std::optional<SmallVector<Value>> result =
-            fn(builder, resultTypes, input, loc))
-      return *result;
-  }
-  return std::nullopt;
-}
-
 TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
   TypeRange convertedTypes = getConvertedTypes();
   if (auto mapping = getInputMapping(originalTypeNo))
@@ -268,20 +254,20 @@ Block *OneToNPatternRewriter::applySignatureConversion(
 LogicalResult
 OneToNConversionPattern::matchAndRewrite(Operation *op,
                                          PatternRewriter &rewriter) const {
-  auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+  auto *typeConverter = getTypeConverter();
 
   // Construct conversion mapping for results.
   Operation::result_type_range originalResultTypes = op->getResultTypes();
   OneToNTypeMapping resultMapping(originalResultTypes);
-  if (failed(typeConverter->computeTypeMapping(originalResultTypes,
-                                               resultMapping)))
+  if (failed(typeConverter->convertSignatureArgs(originalResultTypes,
+                                                 resultMapping)))
     return failure();
 
   // Construct conversion mapping for operands.
   Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
   OneToNTypeMapping operandMapping(originalOperandTypes);
-  if (failed(typeConverter->computeTypeMapping(originalOperandTypes,
-                                               operandMapping)))
+  if (failed(typeConverter->convertSignatureArgs(originalOperandTypes,
+                                                 operandMapping)))
     return failure();
 
   // Cast operands to target types.
@@ -318,7 +304,7 @@ namespace mlir {
 // inserted by this pass are annotated with a string attribute that also
 // documents which kind of the cast (source, argument, or target).
 LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
                              const FrozenRewritePatternSet &patterns) {
 #ifndef NDEBUG
   // Remember existing unrealized casts. This data structure is only used in
@@ -370,15 +356,13 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
       // Target materialization.
       assert(!areOperandTypesLegal && areResultsTypesLegal &&
              operands.size() == 1 && "found unexpected target cast");
-      std::optional<SmallVector<Value>> maybeResults =
-          typeConverter.materializeTargetConversion(
-              rewriter, castOp->getLoc(), resultTypes, operands.front());
-      if (!maybeResults) {
+      materializedResults = typeConverter.materializeTargetConversion(
+          rewriter, castOp->getLoc(), resultTypes, operands.front());
+      if (materializedResults.empty()) {
         emitError(castOp->getLoc())
             << "failed to create target materialization";
         return failure();
       }
-      materializedResults = maybeResults.value();
     } else {
       // Source and argument materializations.
       assert(areOperandTypesLegal && !areResultsTypesLegal &&
@@ -427,18 +411,18 @@ class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern {
                                 const OneToNTypeMapping &resultMapping,
                                 ValueRange convertedOperands) const override {
     auto funcOp = cast<FunctionOpInterface>(op);
-    auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+    auto *typeConverter = getTypeConverter();
 
     // Construct mapping for function arguments.
     OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes());
-    if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(),
-                                                 argumentMapping)))
+    if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(),
+                                                   argumentMapping)))
       return failure();
 
     // Construct mapping for function results.
     OneToNTypeMapping funcResultMapping(funcOp.getResultTypes());
-    if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(),
-                                                 funcResultMapping)))
+    if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(),
+                                                   funcResultMapping)))
       return failure();
 
     // Nothing to do if the op doesn't have any non-identity conversions for its
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
index 5c03ac12d1e58c..b18dfd8bb22cb1 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -147,9 +147,14 @@ populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter,
 ///
 /// This function has been copied (with small adaptions) from
 /// TestDecomposeCallGraphTypes.cpp.
-static std::optional<SmallVector<Value>>
-buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
-                        Location loc) {
+static SmallVector<Value> buildGetTupleElementOps(OpBuilder &builder,
+                                                  TypeRange resultTypes,
+                                                  ValueRange inputs,
+                                                  Location loc) {
+  if (inputs.size() != 1)
+    return {};
+  Value input = inputs.front();
+
   TupleType inputType = dyn_cast<TupleType>(input.getType());
   if (!inputType)
     return {};
@@ -222,7 +227,7 @@ void TestOneToNTypeConversionPass::runOnOperation() {
   auto *context = &getContext();
 
   // Assemble type converter.
-  OneToNTypeConverter typeConverter;
+  TypeConverter typeConverter;
 
   typeConverter.addConversion([](Type type) { return type; });
   typeConverter.addConversion(
@@ -234,6 +239,11 @@ void TestOneToNTypeConversionPass::runOnOperation() {
   typeConverter.addArgumentMaterialization(buildMakeTupleOp);
   typeConverter.addSourceMaterialization(buildMakeTupleOp);
   typeConverter.addTargetMaterialization(buildGetTupleElementOps);
+  // Test the other target materialization variant that takes the original type
+  // as additional argument. This materialization function always fails.
+  typeConverter.addTargetMaterialization(
+      [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+         Location loc, Type originalType) -> SmallVector<Value> { return {}; });
 
   // Assemble patterns.
   RewritePatternSet patterns(context);

@llvmbot
Copy link
Member

llvmbot commented Oct 19, 2024

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

The 1:N type converter derived from the 1:1 type converter and extends it with 1:N target materializations. This commit merges the two type converters and stores 1:N target materializations in the 1:1 type converter. This is in preparation of merging the 1:1 and 1:N dialect conversion infrastructures.

1:1 target materializations (producing a single Value) will remain valid. An additional API is added to the type converter to register 1:N target materializations (producing a SmallVector&lt;Value&gt;). Internally, all target materializations are stored as 1:N materializations.

The 1:N type converter is removed.

Note for LLVM integration: If you are using the OneToNTypeConverter, simply switch all occurrences to TypeConverter.

Depends on #113031.


Full diff: https://github.com/llvm/llvm-project/pull/113032.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h (+1-1)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+41-13)
  • (modified) mlir/include/mlir/Transforms/OneToNTypeConversion.h (+1-44)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+1-1)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+20-4)
  • (modified) mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp (+14-30)
  • (modified) mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp (+14-4)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 6ccbc40bdd6034..2e9c297f20182a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -150,7 +150,7 @@ std::unique_ptr<Pass> createLowerForeachToSCFPass();
 //===----------------------------------------------------------------------===//
 
 /// Type converter for iter_space and iterator.
-struct SparseIterationTypeConverter : public OneToNTypeConverter {
+struct SparseIterationTypeConverter : public TypeConverter {
   SparseIterationTypeConverter();
 };
 
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5ff36160dd6162..eb7da67c1bb995 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -173,7 +173,9 @@ class TypeConverter {
   /// conversion has finished.
   ///
   /// Note: Target materializations may optionally accept an additional Type
-  /// parameter, which is the original type of the SSA value.
+  /// parameter, which is the original type of the SSA value. Furthermore `T`
+  /// can be a TypeRange; in that case, the function must return a
+  /// SmallVector<Value>.
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
@@ -210,6 +212,9 @@ class TypeConverter {
   /// will be invoked with: outputType = "t3", inputs = "v2",
   // originalType = "t1". Note  that the original type "t1" cannot be recovered
   /// from just "t3" and "v2"; that's why the originalType parameter exists.
+  ///
+  /// Note: During a 1:N conversion, the result types can be a TypeRange. In
+  /// that case the materialization produces a SmallVector<Value>.
   template <typename FnT, typename T = typename llvm::function_traits<
                               std::decay_t<FnT>>::template arg_t<1>>
   void addTargetMaterialization(FnT &&callback) {
@@ -316,6 +321,11 @@ class TypeConverter {
   Value materializeTargetConversion(OpBuilder &builder, Location loc,
                                     Type resultType, ValueRange inputs,
                                     Type originalType = {}) const;
+  SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
+                                                 Location loc,
+                                                 TypeRange resultType,
+                                                 ValueRange inputs,
+                                                 Type originalType = {}) const;
 
   /// Convert an attribute present `attr` from within the type `type` using
   /// the registered conversion functions. If no applicable conversion has been
@@ -341,8 +351,8 @@ class TypeConverter {
   /// The signature of the callback used to materialize a target conversion.
   ///
   /// Arguments: builder, result type, inputs, location, original type
-  using TargetMaterializationCallbackFn =
-      std::function<Value(OpBuilder &, Type, ValueRange, Location, Type)>;
+  using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
+      OpBuilder &, TypeRange, ValueRange, Location, Type)>;
 
   /// The signature of the callback used to convert a type attribute.
   using TypeAttributeConversionCallbackFn =
@@ -409,22 +419,40 @@ class TypeConverter {
   /// callback.
   ///
   /// With callback of form:
-  /// `Value(OpBuilder &, T, ValueRange, Location, Type)`
+  /// - Value(OpBuilder &, T, ValueRange, Location, Type)
+  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
   template <typename T, typename FnT>
   std::enable_if_t<
       std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
       TargetMaterializationCallbackFn>
   wrapTargetMaterialization(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
-               OpBuilder &builder, Type resultType, ValueRange inputs,
-               Location loc, Type originalType) -> Value {
-      if (T derivedType = dyn_cast<T>(resultType))
-        return callback(builder, derivedType, inputs, loc, originalType);
-      return Value();
+               OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+               Location loc, Type originalType) -> SmallVector<Value> {
+      SmallVector<Value> result;
+      if constexpr (std::is_same<T, TypeRange>::value) {
+        // This is a 1:N target materialization. Return the produces values
+        // directly.
+        result = callback(builder, resultTypes, inputs, loc, originalType);
+      } else {
+        // This is a 1:1 target materialization. Invoke it only if the result
+        // type class of the callback matches the requested result type.
+        if (T derivedType = dyn_cast<T>(resultTypes.front())) {
+          // 1:1 materializations produce single values, but we store 1:N
+          // target materialization functions in the type converter. Wrap the
+          // result value in a SmallVector<Value>.
+          std::optional<Value> val =
+              callback(builder, derivedType, inputs, loc, originalType);
+          if (val.has_value() && *val)
+            result.push_back(*val);
+        }
+      }
+      return result;
     };
   }
   /// With callback of form:
-  /// `Value(OpBuilder &, T, ValueRange, Location)`
+  /// - Value(OpBuilder &, T, ValueRange, Location)
+  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
   template <typename T, typename FnT>
   std::enable_if_t<
       std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
@@ -432,9 +460,9 @@ class TypeConverter {
   wrapTargetMaterialization(FnT &&callback) const {
     return wrapTargetMaterialization<T>(
         [callback = std::forward<FnT>(callback)](
-            OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
-            Type originalType) -> Value {
-          return callback(builder, resultType, inputs, loc);
+            OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
+            Type originalType) {
+          return callback(builder, resultTypes, inputs, loc);
         });
   }
 
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..7b4dd65cbff7b2 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -33,49 +33,6 @@
 
 namespace mlir {
 
-/// Extends `TypeConverter` with 1:N target materializations. Such
-/// materializations have to provide the "reverse" of 1:N type conversions,
-/// i.e., they need to materialize N values with target types into one value
-/// with a source type (which isn't possible in the base class currently).
-class OneToNTypeConverter : public TypeConverter {
-public:
-  /// Callback that expresses user-provided materialization logic from the given
-  /// value to N values of the given types. This is useful for expressing target
-  /// materializations for 1:N type conversions, which materialize one value in
-  /// a source type as N values in target types.
-  using OneToNMaterializationCallbackFn =
-      std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
-                                                      Value, Location)>;
-
-  /// Creates the mapping of the given range of original types to target types
-  /// of the conversion and stores that mapping in the given (signature)
-  /// conversion. This function simply calls
-  /// `TypeConverter::convertSignatureArgs` and exists here with a different
-  /// name to reflect the broader semantic.
-  LogicalResult computeTypeMapping(TypeRange types,
-                                   SignatureConversion &result) const {
-    return convertSignatureArgs(types, result);
-  }
-
-  /// Applies one of the user-provided 1:N target materializations. If several
-  /// exists, they are tried out in the reverse order in which they have been
-  /// added until the first one succeeds. If none succeeds, the functions
-  /// returns `std::nullopt`.
-  std::optional<SmallVector<Value>>
-  materializeTargetConversion(OpBuilder &builder, Location loc,
-                              TypeRange resultTypes, Value input) const;
-
-  /// Adds a 1:N target materialization to the converter. Such materializations
-  /// build IR that converts N values with target types into 1 value of the
-  /// source type.
-  void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) {
-    oneToNTargetMaterializations.emplace_back(std::move(callback));
-  }
-
-private:
-  SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
-};
-
 /// Stores a 1:N mapping of types and provides several useful accessors. This
 /// class extends `SignatureConversion`, which already supports 1:N type
 /// mappings but lacks some accessors into the mapping as well as access to the
@@ -295,7 +252,7 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
 /// not fail if some ops or types remain unconverted (i.e., the conversion is
 /// only "partial").
 LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
                              const FrozenRewritePatternSet &patterns);
 
 /// Add a pattern to the given pattern list to convert the signature of a
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 4968c4fc463d04..e908a536e6fb27 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -921,7 +921,7 @@ struct VectorLegalizationPass
     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
   void runOnOperation() override {
     auto *context = &getContext();
-    OneToNTypeConverter converter;
+    TypeConverter converter;
     RewritePatternSet patterns(context);
     converter.addConversion([](Type type) { return type; });
     converter.addConversion(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3cfcaa965f3546..bf969e74e8bfe0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2831,11 +2831,27 @@ Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
                                                  Location loc, Type resultType,
                                                  ValueRange inputs,
                                                  Type originalType) const {
+  SmallVector<Value> result = materializeTargetConversion(
+      builder, loc, TypeRange(resultType), inputs, originalType);
+  if (result.empty())
+    return nullptr;
+  assert(result.size() == 1 && "requested 1:1 materialization, but callback "
+                               "produced 1:N materialization");
+  return result.front();
+}
+
+SmallVector<Value> TypeConverter::materializeTargetConversion(
+    OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
+    Type originalType) const {
   for (const TargetMaterializationCallbackFn &fn :
-       llvm::reverse(targetMaterializations))
-    if (Value result = fn(builder, resultType, inputs, loc, originalType))
-      return result;
-  return nullptr;
+       llvm::reverse(targetMaterializations)) {
+    SmallVector<Value> result =
+        fn(builder, resultTypes, inputs, loc, originalType);
+    if (result.empty())
+      continue;
+    return result;
+  }
+  return {};
 }
 
 std::optional<TypeConverter::SignatureConversion>
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
index 19e29d48623e04..c208716891ef1f 100644
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -17,20 +17,6 @@
 using namespace llvm;
 using namespace mlir;
 
-std::optional<SmallVector<Value>>
-OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
-                                                 Location loc,
-                                                 TypeRange resultTypes,
-                                                 Value input) const {
-  for (const OneToNMaterializationCallbackFn &fn :
-       llvm::reverse(oneToNTargetMaterializations)) {
-    if (std::optional<SmallVector<Value>> result =
-            fn(builder, resultTypes, input, loc))
-      return *result;
-  }
-  return std::nullopt;
-}
-
 TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
   TypeRange convertedTypes = getConvertedTypes();
   if (auto mapping = getInputMapping(originalTypeNo))
@@ -268,20 +254,20 @@ Block *OneToNPatternRewriter::applySignatureConversion(
 LogicalResult
 OneToNConversionPattern::matchAndRewrite(Operation *op,
                                          PatternRewriter &rewriter) const {
-  auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+  auto *typeConverter = getTypeConverter();
 
   // Construct conversion mapping for results.
   Operation::result_type_range originalResultTypes = op->getResultTypes();
   OneToNTypeMapping resultMapping(originalResultTypes);
-  if (failed(typeConverter->computeTypeMapping(originalResultTypes,
-                                               resultMapping)))
+  if (failed(typeConverter->convertSignatureArgs(originalResultTypes,
+                                                 resultMapping)))
     return failure();
 
   // Construct conversion mapping for operands.
   Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
   OneToNTypeMapping operandMapping(originalOperandTypes);
-  if (failed(typeConverter->computeTypeMapping(originalOperandTypes,
-                                               operandMapping)))
+  if (failed(typeConverter->convertSignatureArgs(originalOperandTypes,
+                                                 operandMapping)))
     return failure();
 
   // Cast operands to target types.
@@ -318,7 +304,7 @@ namespace mlir {
 // inserted by this pass are annotated with a string attribute that also
 // documents which kind of the cast (source, argument, or target).
 LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
                              const FrozenRewritePatternSet &patterns) {
 #ifndef NDEBUG
   // Remember existing unrealized casts. This data structure is only used in
@@ -370,15 +356,13 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
       // Target materialization.
       assert(!areOperandTypesLegal && areResultsTypesLegal &&
              operands.size() == 1 && "found unexpected target cast");
-      std::optional<SmallVector<Value>> maybeResults =
-          typeConverter.materializeTargetConversion(
-              rewriter, castOp->getLoc(), resultTypes, operands.front());
-      if (!maybeResults) {
+      materializedResults = typeConverter.materializeTargetConversion(
+          rewriter, castOp->getLoc(), resultTypes, operands.front());
+      if (materializedResults.empty()) {
         emitError(castOp->getLoc())
             << "failed to create target materialization";
         return failure();
       }
-      materializedResults = maybeResults.value();
     } else {
       // Source and argument materializations.
       assert(areOperandTypesLegal && !areResultsTypesLegal &&
@@ -427,18 +411,18 @@ class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern {
                                 const OneToNTypeMapping &resultMapping,
                                 ValueRange convertedOperands) const override {
     auto funcOp = cast<FunctionOpInterface>(op);
-    auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+    auto *typeConverter = getTypeConverter();
 
     // Construct mapping for function arguments.
     OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes());
-    if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(),
-                                                 argumentMapping)))
+    if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(),
+                                                   argumentMapping)))
       return failure();
 
     // Construct mapping for function results.
     OneToNTypeMapping funcResultMapping(funcOp.getResultTypes());
-    if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(),
-                                                 funcResultMapping)))
+    if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(),
+                                                   funcResultMapping)))
       return failure();
 
     // Nothing to do if the op doesn't have any non-identity conversions for its
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
index 5c03ac12d1e58c..b18dfd8bb22cb1 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -147,9 +147,14 @@ populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter,
 ///
 /// This function has been copied (with small adaptions) from
 /// TestDecomposeCallGraphTypes.cpp.
-static std::optional<SmallVector<Value>>
-buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
-                        Location loc) {
+static SmallVector<Value> buildGetTupleElementOps(OpBuilder &builder,
+                                                  TypeRange resultTypes,
+                                                  ValueRange inputs,
+                                                  Location loc) {
+  if (inputs.size() != 1)
+    return {};
+  Value input = inputs.front();
+
   TupleType inputType = dyn_cast<TupleType>(input.getType());
   if (!inputType)
     return {};
@@ -222,7 +227,7 @@ void TestOneToNTypeConversionPass::runOnOperation() {
   auto *context = &getContext();
 
   // Assemble type converter.
-  OneToNTypeConverter typeConverter;
+  TypeConverter typeConverter;
 
   typeConverter.addConversion([](Type type) { return type; });
   typeConverter.addConversion(
@@ -234,6 +239,11 @@ void TestOneToNTypeConversionPass::runOnOperation() {
   typeConverter.addArgumentMaterialization(buildMakeTupleOp);
   typeConverter.addSourceMaterialization(buildMakeTupleOp);
   typeConverter.addTargetMaterialization(buildGetTupleElementOps);
+  // Test the other target materialization variant that takes the original type
+  // as additional argument. This materialization function always fails.
+  typeConverter.addTargetMaterialization(
+      [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+         Location loc, Type originalType) -> SmallVector<Value> { return {}; });
 
   // Assemble patterns.
   RewritePatternSet patterns(context);

@matthias-springer matthias-springer force-pushed the users/matthias-springer/merge_type_converters branch from 8fc8147 to 7ac8aa2 Compare October 19, 2024 10:47
@matthias-springer matthias-springer force-pushed the users/matthias-springer/mat_remove_optional branch from 2fe3e2d to 59a9bbb Compare October 23, 2024 14:16
Base automatically changed from users/matthias-springer/mat_remove_optional to main October 23, 2024 14:29
@matthias-springer matthias-springer force-pushed the users/matthias-springer/merge_type_converters branch 2 times, most recently from 216d522 to 44db7a9 Compare October 23, 2024 16:55
@matthias-springer matthias-springer force-pushed the users/matthias-springer/merge_type_converters branch from eaff670 to 43edbab Compare October 25, 2024 17:52
Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@matthias-springer matthias-springer merged commit 8c4bc1e into main Oct 25, 2024
6 of 7 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/merge_type_converters branch October 25, 2024 18:44
@llvm-ci
Copy link
Collaborator

llvm-ci commented Oct 25, 2024

LLVM Buildbot has detected a new failure on builder mlir-nvidia-gcc7 running on mlir-nvidia while building mlir at step 5 "build-check-mlir-build-only".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/5467

Here is the relevant piece of the build log for the reference
Step 5 (build-check-mlir-build-only) failure: build (failure)
...
364.284 [638/16/3937] Linking CXX executable tools/mlir/unittests/Dialect/SparseTensor/MLIRSparseTensorTests
364.293 [637/16/3938] Building CXX object tools/mlir/unittests/Interfaces/CMakeFiles/MLIRInterfacesTests.dir/InferIntRangeInterfaceTest.cpp.o
364.314 [636/16/3939] Building CXX object tools/mlir/unittests/Interfaces/CMakeFiles/MLIRInterfacesTests.dir/InferTypeOpInterfaceTest.cpp.o
364.321 [635/16/3940] Building CXX object tools/mlir/unittests/Pass/CMakeFiles/MLIRPassTests.dir/PassPipelineParserTest.cpp.o
364.350 [634/16/3941] Building CXX object tools/mlir/unittests/Pass/CMakeFiles/MLIRPassTests.dir/AnalysisManagerTest.cpp.o
364.378 [633/16/3942] Building CXX object tools/mlir/unittests/Pass/CMakeFiles/MLIRPassTests.dir/PassManagerTest.cpp.o
364.588 [632/16/3943] Linking CXX executable tools/mlir/unittests/Interfaces/MLIRInterfacesTests
364.589 [631/16/3944] Linking CXX executable tools/mlir/unittests/Pass/MLIRPassTests
364.618 [630/16/3945] Building CXX object tools/mlir/unittests/Rewrite/CMakeFiles/MLIRRewriteTests.dir/PatternBenefit.cpp.o
365.215 [629/16/3946] Building CXX object tools/mlir/lib/Transforms/Utils/CMakeFiles/obj.MLIRTransformUtils.dir/DialectConversion.cpp.o
FAILED: tools/mlir/lib/Transforms/Utils/CMakeFiles/obj.MLIRTransformUtils.dir/DialectConversion.cpp.o 
CCACHE_CPP2=yes CCACHE_HASHDIR=yes /usr/bin/ccache /usr/bin/g++-7 -DGTEST_HAS_RTTI=0 -D_DEBUG -D_GLIBCXX_ASSERTIONS -D_GNU_SOURCE -D__STDC_CONSTANT_MACROS -D__STDC_FORMAT_MACROS -D__STDC_LIMIT_MACROS -I/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/tools/mlir/lib/Transforms/Utils -I/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/lib/Transforms/Utils -I/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/include -I/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/llvm/include -I/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include -I/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/tools/mlir/include -fPIC -fno-semantic-interposition -fvisibility-inlines-hidden -Werror=date-time -fno-lifetime-dse -Wall -Wextra -Wno-unused-parameter -Wwrite-strings -Wcast-qual -Wno-missing-field-initializers -pedantic -Wno-long-long -Wimplicit-fallthrough -Wno-uninitialized -Wno-nonnull -Wno-noexcept-type -Wdelete-non-virtual-dtor -Wno-comment -Wno-misleading-indentation -fdiagnostics-color -ffunction-sections -fdata-sections -Wundef -Wno-unused-but-set-parameter -Wno-deprecated-copy -O3 -DNDEBUG  -fno-exceptions -funwind-tables -fno-rtti -UNDEBUG -std=c++1z -MD -MT tools/mlir/lib/Transforms/Utils/CMakeFiles/obj.MLIRTransformUtils.dir/DialectConversion.cpp.o -MF tools/mlir/lib/Transforms/Utils/CMakeFiles/obj.MLIRTransformUtils.dir/DialectConversion.cpp.o.d -o tools/mlir/lib/Transforms/Utils/CMakeFiles/obj.MLIRTransformUtils.dir/DialectConversion.cpp.o -c /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/lib/Transforms/Utils/DialectConversion.cpp
In file included from /usr/include/c++/7/cassert:44:0,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/llvm/include/llvm/Support/Error.h:26,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/llvm/include/llvm/Support/JSON.h:54,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/llvm/include/llvm/Support/ScopedPrinter.h:19,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/lib/Transforms/Utils/DialectConversion.cpp:24:
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/lib/Transforms/Utils/DialectConversion.cpp: In member function ‘llvm::SmallVector<mlir::Value> mlir::TypeConverter::materializeTargetConversion(mlir::OpBuilder&, mlir::Location, mlir::TypeRange, mlir::ValueRange, mlir::Type) const’:
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/lib/Transforms/Utils/DialectConversion.cpp:2851:28: error: call of overloaded ‘TypeRange(llvm::SmallVector<mlir::Value>&)’ is ambiguous
     assert(TypeRange(result) == resultTypes &&
                            ^
In file included from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/Support/TypeID.h:20:0,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/MLIRContext.h:13,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/DialectRegistry.h:16,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/Dialect.h:16,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/OpDefinition.h:22,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/Builders.h:12,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/PatternMatch.h:12,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h:12,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/Transforms/DialectConversion.h:17,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/lib/Transforms/Utils/DialectConversion.cpp:9:
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/llvm/include/llvm/ADT/STLExtras.h:1278:3: note: candidate: llvm::detail::indexed_accessor_range_base<DerivedT, BaseT, T, PointerT, ReferenceT>::indexed_accessor_range_base(const llvm::iterator_range<llvm::detail::indexed_accessor_range_base<DerivedT, BaseT, T, PointerT, ReferenceT>::iterator>&) [with DerivedT = mlir::TypeRange; BaseT = llvm::PointerUnion<const mlir::Value*, const mlir::Type*, mlir::OpOperand*, mlir::detail::OpResultImpl*>; T = mlir::Type; PointerT = mlir::Type; ReferenceT = mlir::Type]
   indexed_accessor_range_base(const iterator_range<iterator> &range)
   ^~~~~~~~~~~~~~~~~~~~~~~~~~~
In file included from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/OperationSupport.h:23:0,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/Dialect.h:17,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/OpDefinition.h:22,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/Builders.h:12,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/PatternMatch.h:12,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h:12,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/Transforms/DialectConversion.h:17,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/lib/Transforms/Utils/DialectConversion.cpp:9:
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/TypeRange.h:38:21: note:   inherited here
   using RangeBaseT::RangeBaseT;
                     ^~~~~~~~~~
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/TypeRange.h:42:12: note: candidate: mlir::TypeRange::TypeRange(mlir::ValueRange)
   explicit TypeRange(ValueRange values);
            ^~~~~~~~~
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/lib/Transforms/Utils/DialectConversion.cpp: At global scope:

Copy link
Contributor

@ingomueller-net ingomueller-net left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Nice progress!

matthias-springer added a commit that referenced this pull request Oct 30, 2024
The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround around missing 1:N support in the dialect conversion. Since #113032,tThe dialect conversion infrastructure supports 1:N type conversions and 1:N target materializations. The `ValueDecomposer` class is no longer needed. (However, target materializations must still be inserted manually, until we fully merge the 1:1 and 1:N drivers.)
matthias-springer added a commit that referenced this pull request Oct 30, 2024
The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround around missing 1:N support in the dialect conversion. Since #113032,tThe dialect conversion infrastructure supports 1:N type conversions and 1:N target materializations. The `ValueDecomposer` class is no longer needed. (However, target materializations must still be inserted manually, until we fully merge the 1:1 and 1:N drivers.)
matthias-springer added a commit that referenced this pull request Oct 30, 2024
…114192)

The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround
around missing 1:N support in the dialect conversion. Since #113032, the
dialect conversion infrastructure supports 1:N type conversions and 1:N
target materializations. The `ValueDecomposer` class is no longer
needed. (However, target materializations must still be inserted
manually, until we fully merge the 1:1 and 1:N drivers.)

Note for LLVM integration: Register 1:N target materializations on the
type converter instead of "decompose value conversions" on the
`ValueDecomposer`.
smallp-o-p pushed a commit to smallp-o-p/llvm-project that referenced this pull request Nov 3, 2024
…lvm#114192)

The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround
around missing 1:N support in the dialect conversion. Since llvm#113032, the
dialect conversion infrastructure supports 1:N type conversions and 1:N
target materializations. The `ValueDecomposer` class is no longer
needed. (However, target materializations must still be inserted
manually, until we fully merge the 1:1 and 1:N drivers.)

Note for LLVM integration: Register 1:N target materializations on the
type converter instead of "decompose value conversions" on the
`ValueDecomposer`.
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
The 1:N type converter derived from the 1:1 type converter and extends
it with 1:N target materializations. This commit merges the two type
converters and stores 1:N target materializations in the 1:1 type
converter. This is in preparation of merging the 1:1 and 1:N dialect
conversion infrastructures.

1:1 target materializations (producing a single `Value`) will remain
valid. An additional API is added to the type converter to register 1:N
target materializations (producing a `SmallVector<Value>`). Internally,
all target materializations are stored as 1:N materializations.

The 1:N type converter is removed.

Note for LLVM integration: If you are using the `OneToNTypeConverter`,
simply switch all occurrences to `TypeConverter`.

---------

Co-authored-by: Markus Böck <[email protected]>
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
…lvm#114192)

The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround
around missing 1:N support in the dialect conversion. Since llvm#113032, the
dialect conversion infrastructure supports 1:N type conversions and 1:N
target materializations. The `ValueDecomposer` class is no longer
needed. (However, target materializations must still be inserted
manually, until we fully merge the 1:1 and 1:N drivers.)

Note for LLVM integration: Register 1:N target materializations on the
type converter instead of "decompose value conversions" on the
`ValueDecomposer`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:sme mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants