Skip to content

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Oct 19, 2024

The 1:N type converter derived from the 1:1 type converter and extends it with 1:N target materializations. This commit merges the two type converters and stores 1:N target materializations in the 1:1 type converter. This is in preparation of merging the 1:1 and 1:N dialect conversion infrastructures.

1:1 target materializations (producing a single Value) will remain valid. An additional API is added to the type converter to register 1:N target materializations (producing a SmallVector<Value>). Internally, all target materializations are stored as 1:N materializations.

The 1:N type converter is removed.

Note for LLVM integration: If you are using the OneToNTypeConverter, simply switch all occurrences to TypeConverter.

@llvmbot
Copy link
Member

llvmbot commented Oct 19, 2024

@llvm/pr-subscribers-mlir-sme

@llvm/pr-subscribers-mlir-sparse

Author: Matthias Springer (matthias-springer)

Changes

The 1:N type converter derived from the 1:1 type converter and extends it with 1:N target materializations. This commit merges the two type converters and stores 1:N target materializations in the 1:1 type converter. This is in preparation of merging the 1:1 and 1:N dialect conversion infrastructures.

1:1 target materializations (producing a single Value) will remain valid. An additional API is added to the type converter to register 1:N target materializations (producing a SmallVector&lt;Value&gt;). Internally, all target materializations are stored as 1:N materializations.

The 1:N type converter is removed.

Note for LLVM integration: If you are using the OneToNTypeConverter, simply switch all occurrences to TypeConverter.

Depends on #113031.


Full diff: https://github.com/llvm/llvm-project/pull/113032.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h (+1-1)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+41-13)
  • (modified) mlir/include/mlir/Transforms/OneToNTypeConversion.h (+1-44)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+1-1)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+20-4)
  • (modified) mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp (+14-30)
  • (modified) mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp (+14-4)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 6ccbc40bdd6034..2e9c297f20182a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -150,7 +150,7 @@ std::unique_ptr<Pass> createLowerForeachToSCFPass();
 //===----------------------------------------------------------------------===//
 
 /// Type converter for iter_space and iterator.
-struct SparseIterationTypeConverter : public OneToNTypeConverter {
+struct SparseIterationTypeConverter : public TypeConverter {
   SparseIterationTypeConverter();
 };
 
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5ff36160dd6162..eb7da67c1bb995 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -173,7 +173,9 @@ class TypeConverter {
   /// conversion has finished.
   ///
   /// Note: Target materializations may optionally accept an additional Type
-  /// parameter, which is the original type of the SSA value.
+  /// parameter, which is the original type of the SSA value. Furthermore `T`
+  /// can be a TypeRange; in that case, the function must return a
+  /// SmallVector<Value>.
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
@@ -210,6 +212,9 @@ class TypeConverter {
   /// will be invoked with: outputType = "t3", inputs = "v2",
   // originalType = "t1". Note  that the original type "t1" cannot be recovered
   /// from just "t3" and "v2"; that's why the originalType parameter exists.
+  ///
+  /// Note: During a 1:N conversion, the result types can be a TypeRange. In
+  /// that case the materialization produces a SmallVector<Value>.
   template <typename FnT, typename T = typename llvm::function_traits<
                               std::decay_t<FnT>>::template arg_t<1>>
   void addTargetMaterialization(FnT &&callback) {
@@ -316,6 +321,11 @@ class TypeConverter {
   Value materializeTargetConversion(OpBuilder &builder, Location loc,
                                     Type resultType, ValueRange inputs,
                                     Type originalType = {}) const;
+  SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
+                                                 Location loc,
+                                                 TypeRange resultType,
+                                                 ValueRange inputs,
+                                                 Type originalType = {}) const;
 
   /// Convert an attribute present `attr` from within the type `type` using
   /// the registered conversion functions. If no applicable conversion has been
@@ -341,8 +351,8 @@ class TypeConverter {
   /// The signature of the callback used to materialize a target conversion.
   ///
   /// Arguments: builder, result type, inputs, location, original type
-  using TargetMaterializationCallbackFn =
-      std::function<Value(OpBuilder &, Type, ValueRange, Location, Type)>;
+  using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
+      OpBuilder &, TypeRange, ValueRange, Location, Type)>;
 
   /// The signature of the callback used to convert a type attribute.
   using TypeAttributeConversionCallbackFn =
@@ -409,22 +419,40 @@ class TypeConverter {
   /// callback.
   ///
   /// With callback of form:
-  /// `Value(OpBuilder &, T, ValueRange, Location, Type)`
+  /// - Value(OpBuilder &, T, ValueRange, Location, Type)
+  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
   template <typename T, typename FnT>
   std::enable_if_t<
       std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
       TargetMaterializationCallbackFn>
   wrapTargetMaterialization(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
-               OpBuilder &builder, Type resultType, ValueRange inputs,
-               Location loc, Type originalType) -> Value {
-      if (T derivedType = dyn_cast<T>(resultType))
-        return callback(builder, derivedType, inputs, loc, originalType);
-      return Value();
+               OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+               Location loc, Type originalType) -> SmallVector<Value> {
+      SmallVector<Value> result;
+      if constexpr (std::is_same<T, TypeRange>::value) {
+        // This is a 1:N target materialization. Return the produces values
+        // directly.
+        result = callback(builder, resultTypes, inputs, loc, originalType);
+      } else {
+        // This is a 1:1 target materialization. Invoke it only if the result
+        // type class of the callback matches the requested result type.
+        if (T derivedType = dyn_cast<T>(resultTypes.front())) {
+          // 1:1 materializations produce single values, but we store 1:N
+          // target materialization functions in the type converter. Wrap the
+          // result value in a SmallVector<Value>.
+          std::optional<Value> val =
+              callback(builder, derivedType, inputs, loc, originalType);
+          if (val.has_value() && *val)
+            result.push_back(*val);
+        }
+      }
+      return result;
     };
   }
   /// With callback of form:
-  /// `Value(OpBuilder &, T, ValueRange, Location)`
+  /// - Value(OpBuilder &, T, ValueRange, Location)
+  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
   template <typename T, typename FnT>
   std::enable_if_t<
       std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
@@ -432,9 +460,9 @@ class TypeConverter {
   wrapTargetMaterialization(FnT &&callback) const {
     return wrapTargetMaterialization<T>(
         [callback = std::forward<FnT>(callback)](
-            OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
-            Type originalType) -> Value {
-          return callback(builder, resultType, inputs, loc);
+            OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
+            Type originalType) {
+          return callback(builder, resultTypes, inputs, loc);
         });
   }
 
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..7b4dd65cbff7b2 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -33,49 +33,6 @@
 
 namespace mlir {
 
-/// Extends `TypeConverter` with 1:N target materializations. Such
-/// materializations have to provide the "reverse" of 1:N type conversions,
-/// i.e., they need to materialize N values with target types into one value
-/// with a source type (which isn't possible in the base class currently).
-class OneToNTypeConverter : public TypeConverter {
-public:
-  /// Callback that expresses user-provided materialization logic from the given
-  /// value to N values of the given types. This is useful for expressing target
-  /// materializations for 1:N type conversions, which materialize one value in
-  /// a source type as N values in target types.
-  using OneToNMaterializationCallbackFn =
-      std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
-                                                      Value, Location)>;
-
-  /// Creates the mapping of the given range of original types to target types
-  /// of the conversion and stores that mapping in the given (signature)
-  /// conversion. This function simply calls
-  /// `TypeConverter::convertSignatureArgs` and exists here with a different
-  /// name to reflect the broader semantic.
-  LogicalResult computeTypeMapping(TypeRange types,
-                                   SignatureConversion &result) const {
-    return convertSignatureArgs(types, result);
-  }
-
-  /// Applies one of the user-provided 1:N target materializations. If several
-  /// exists, they are tried out in the reverse order in which they have been
-  /// added until the first one succeeds. If none succeeds, the functions
-  /// returns `std::nullopt`.
-  std::optional<SmallVector<Value>>
-  materializeTargetConversion(OpBuilder &builder, Location loc,
-                              TypeRange resultTypes, Value input) const;
-
-  /// Adds a 1:N target materialization to the converter. Such materializations
-  /// build IR that converts N values with target types into 1 value of the
-  /// source type.
-  void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) {
-    oneToNTargetMaterializations.emplace_back(std::move(callback));
-  }
-
-private:
-  SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
-};
-
 /// Stores a 1:N mapping of types and provides several useful accessors. This
 /// class extends `SignatureConversion`, which already supports 1:N type
 /// mappings but lacks some accessors into the mapping as well as access to the
@@ -295,7 +252,7 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
 /// not fail if some ops or types remain unconverted (i.e., the conversion is
 /// only "partial").
 LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
                              const FrozenRewritePatternSet &patterns);
 
 /// Add a pattern to the given pattern list to convert the signature of a
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 4968c4fc463d04..e908a536e6fb27 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -921,7 +921,7 @@ struct VectorLegalizationPass
     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
   void runOnOperation() override {
     auto *context = &getContext();
-    OneToNTypeConverter converter;
+    TypeConverter converter;
     RewritePatternSet patterns(context);
     converter.addConversion([](Type type) { return type; });
     converter.addConversion(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3cfcaa965f3546..bf969e74e8bfe0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2831,11 +2831,27 @@ Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
                                                  Location loc, Type resultType,
                                                  ValueRange inputs,
                                                  Type originalType) const {
+  SmallVector<Value> result = materializeTargetConversion(
+      builder, loc, TypeRange(resultType), inputs, originalType);
+  if (result.empty())
+    return nullptr;
+  assert(result.size() == 1 && "requested 1:1 materialization, but callback "
+                               "produced 1:N materialization");
+  return result.front();
+}
+
+SmallVector<Value> TypeConverter::materializeTargetConversion(
+    OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
+    Type originalType) const {
   for (const TargetMaterializationCallbackFn &fn :
-       llvm::reverse(targetMaterializations))
-    if (Value result = fn(builder, resultType, inputs, loc, originalType))
-      return result;
-  return nullptr;
+       llvm::reverse(targetMaterializations)) {
+    SmallVector<Value> result =
+        fn(builder, resultTypes, inputs, loc, originalType);
+    if (result.empty())
+      continue;
+    return result;
+  }
+  return {};
 }
 
 std::optional<TypeConverter::SignatureConversion>
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
index 19e29d48623e04..c208716891ef1f 100644
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -17,20 +17,6 @@
 using namespace llvm;
 using namespace mlir;
 
-std::optional<SmallVector<Value>>
-OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
-                                                 Location loc,
-                                                 TypeRange resultTypes,
-                                                 Value input) const {
-  for (const OneToNMaterializationCallbackFn &fn :
-       llvm::reverse(oneToNTargetMaterializations)) {
-    if (std::optional<SmallVector<Value>> result =
-            fn(builder, resultTypes, input, loc))
-      return *result;
-  }
-  return std::nullopt;
-}
-
 TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
   TypeRange convertedTypes = getConvertedTypes();
   if (auto mapping = getInputMapping(originalTypeNo))
@@ -268,20 +254,20 @@ Block *OneToNPatternRewriter::applySignatureConversion(
 LogicalResult
 OneToNConversionPattern::matchAndRewrite(Operation *op,
                                          PatternRewriter &rewriter) const {
-  auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+  auto *typeConverter = getTypeConverter();
 
   // Construct conversion mapping for results.
   Operation::result_type_range originalResultTypes = op->getResultTypes();
   OneToNTypeMapping resultMapping(originalResultTypes);
-  if (failed(typeConverter->computeTypeMapping(originalResultTypes,
-                                               resultMapping)))
+  if (failed(typeConverter->convertSignatureArgs(originalResultTypes,
+                                                 resultMapping)))
     return failure();
 
   // Construct conversion mapping for operands.
   Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
   OneToNTypeMapping operandMapping(originalOperandTypes);
-  if (failed(typeConverter->computeTypeMapping(originalOperandTypes,
-                                               operandMapping)))
+  if (failed(typeConverter->convertSignatureArgs(originalOperandTypes,
+                                                 operandMapping)))
     return failure();
 
   // Cast operands to target types.
@@ -318,7 +304,7 @@ namespace mlir {
 // inserted by this pass are annotated with a string attribute that also
 // documents which kind of the cast (source, argument, or target).
 LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
                              const FrozenRewritePatternSet &patterns) {
 #ifndef NDEBUG
   // Remember existing unrealized casts. This data structure is only used in
@@ -370,15 +356,13 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
       // Target materialization.
       assert(!areOperandTypesLegal && areResultsTypesLegal &&
              operands.size() == 1 && "found unexpected target cast");
-      std::optional<SmallVector<Value>> maybeResults =
-          typeConverter.materializeTargetConversion(
-              rewriter, castOp->getLoc(), resultTypes, operands.front());
-      if (!maybeResults) {
+      materializedResults = typeConverter.materializeTargetConversion(
+          rewriter, castOp->getLoc(), resultTypes, operands.front());
+      if (materializedResults.empty()) {
         emitError(castOp->getLoc())
             << "failed to create target materialization";
         return failure();
       }
-      materializedResults = maybeResults.value();
     } else {
       // Source and argument materializations.
       assert(areOperandTypesLegal && !areResultsTypesLegal &&
@@ -427,18 +411,18 @@ class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern {
                                 const OneToNTypeMapping &resultMapping,
                                 ValueRange convertedOperands) const override {
     auto funcOp = cast<FunctionOpInterface>(op);
-    auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+    auto *typeConverter = getTypeConverter();
 
     // Construct mapping for function arguments.
     OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes());
-    if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(),
-                                                 argumentMapping)))
+    if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(),
+                                                   argumentMapping)))
       return failure();
 
     // Construct mapping for function results.
     OneToNTypeMapping funcResultMapping(funcOp.getResultTypes());
-    if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(),
-                                                 funcResultMapping)))
+    if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(),
+                                                   funcResultMapping)))
       return failure();
 
     // Nothing to do if the op doesn't have any non-identity conversions for its
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
index 5c03ac12d1e58c..b18dfd8bb22cb1 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -147,9 +147,14 @@ populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter,
 ///
 /// This function has been copied (with small adaptions) from
 /// TestDecomposeCallGraphTypes.cpp.
-static std::optional<SmallVector<Value>>
-buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
-                        Location loc) {
+static SmallVector<Value> buildGetTupleElementOps(OpBuilder &builder,
+                                                  TypeRange resultTypes,
+                                                  ValueRange inputs,
+                                                  Location loc) {
+  if (inputs.size() != 1)
+    return {};
+  Value input = inputs.front();
+
   TupleType inputType = dyn_cast<TupleType>(input.getType());
   if (!inputType)
     return {};
@@ -222,7 +227,7 @@ void TestOneToNTypeConversionPass::runOnOperation() {
   auto *context = &getContext();
 
   // Assemble type converter.
-  OneToNTypeConverter typeConverter;
+  TypeConverter typeConverter;
 
   typeConverter.addConversion([](Type type) { return type; });
   typeConverter.addConversion(
@@ -234,6 +239,11 @@ void TestOneToNTypeConversionPass::runOnOperation() {
   typeConverter.addArgumentMaterialization(buildMakeTupleOp);
   typeConverter.addSourceMaterialization(buildMakeTupleOp);
   typeConverter.addTargetMaterialization(buildGetTupleElementOps);
+  // Test the other target materialization variant that takes the original type
+  // as additional argument. This materialization function always fails.
+  typeConverter.addTargetMaterialization(
+      [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+         Location loc, Type originalType) -> SmallVector<Value> { return {}; });
 
   // Assemble patterns.
   RewritePatternSet patterns(context);

@llvmbot
Copy link
Member

llvmbot commented Oct 19, 2024

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

The 1:N type converter derived from the 1:1 type converter and extends it with 1:N target materializations. This commit merges the two type converters and stores 1:N target materializations in the 1:1 type converter. This is in preparation of merging the 1:1 and 1:N dialect conversion infrastructures.

1:1 target materializations (producing a single Value) will remain valid. An additional API is added to the type converter to register 1:N target materializations (producing a SmallVector&lt;Value&gt;). Internally, all target materializations are stored as 1:N materializations.

The 1:N type converter is removed.

Note for LLVM integration: If you are using the OneToNTypeConverter, simply switch all occurrences to TypeConverter.

Depends on #113031.


Full diff: https://github.com/llvm/llvm-project/pull/113032.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h (+1-1)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+41-13)
  • (modified) mlir/include/mlir/Transforms/OneToNTypeConversion.h (+1-44)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+1-1)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+20-4)
  • (modified) mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp (+14-30)
  • (modified) mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp (+14-4)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 6ccbc40bdd6034..2e9c297f20182a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -150,7 +150,7 @@ std::unique_ptr<Pass> createLowerForeachToSCFPass();
 //===----------------------------------------------------------------------===//
 
 /// Type converter for iter_space and iterator.
-struct SparseIterationTypeConverter : public OneToNTypeConverter {
+struct SparseIterationTypeConverter : public TypeConverter {
   SparseIterationTypeConverter();
 };
 
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5ff36160dd6162..eb7da67c1bb995 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -173,7 +173,9 @@ class TypeConverter {
   /// conversion has finished.
   ///
   /// Note: Target materializations may optionally accept an additional Type
-  /// parameter, which is the original type of the SSA value.
+  /// parameter, which is the original type of the SSA value. Furthermore `T`
+  /// can be a TypeRange; in that case, the function must return a
+  /// SmallVector<Value>.
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
@@ -210,6 +212,9 @@ class TypeConverter {
   /// will be invoked with: outputType = "t3", inputs = "v2",
   // originalType = "t1". Note  that the original type "t1" cannot be recovered
   /// from just "t3" and "v2"; that's why the originalType parameter exists.
+  ///
+  /// Note: During a 1:N conversion, the result types can be a TypeRange. In
+  /// that case the materialization produces a SmallVector<Value>.
   template <typename FnT, typename T = typename llvm::function_traits<
                               std::decay_t<FnT>>::template arg_t<1>>
   void addTargetMaterialization(FnT &&callback) {
@@ -316,6 +321,11 @@ class TypeConverter {
   Value materializeTargetConversion(OpBuilder &builder, Location loc,
                                     Type resultType, ValueRange inputs,
                                     Type originalType = {}) const;
+  SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
+                                                 Location loc,
+                                                 TypeRange resultType,
+                                                 ValueRange inputs,
+                                                 Type originalType = {}) const;
 
   /// Convert an attribute present `attr` from within the type `type` using
   /// the registered conversion functions. If no applicable conversion has been
@@ -341,8 +351,8 @@ class TypeConverter {
   /// The signature of the callback used to materialize a target conversion.
   ///
   /// Arguments: builder, result type, inputs, location, original type
-  using TargetMaterializationCallbackFn =
-      std::function<Value(OpBuilder &, Type, ValueRange, Location, Type)>;
+  using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
+      OpBuilder &, TypeRange, ValueRange, Location, Type)>;
 
   /// The signature of the callback used to convert a type attribute.
   using TypeAttributeConversionCallbackFn =
@@ -409,22 +419,40 @@ class TypeConverter {
   /// callback.
   ///
   /// With callback of form:
-  /// `Value(OpBuilder &, T, ValueRange, Location, Type)`
+  /// - Value(OpBuilder &, T, ValueRange, Location, Type)
+  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
   template <typename T, typename FnT>
   std::enable_if_t<
       std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
       TargetMaterializationCallbackFn>
   wrapTargetMaterialization(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
-               OpBuilder &builder, Type resultType, ValueRange inputs,
-               Location loc, Type originalType) -> Value {
-      if (T derivedType = dyn_cast<T>(resultType))
-        return callback(builder, derivedType, inputs, loc, originalType);
-      return Value();
+               OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+               Location loc, Type originalType) -> SmallVector<Value> {
+      SmallVector<Value> result;
+      if constexpr (std::is_same<T, TypeRange>::value) {
+        // This is a 1:N target materialization. Return the produces values
+        // directly.
+        result = callback(builder, resultTypes, inputs, loc, originalType);
+      } else {
+        // This is a 1:1 target materialization. Invoke it only if the result
+        // type class of the callback matches the requested result type.
+        if (T derivedType = dyn_cast<T>(resultTypes.front())) {
+          // 1:1 materializations produce single values, but we store 1:N
+          // target materialization functions in the type converter. Wrap the
+          // result value in a SmallVector<Value>.
+          std::optional<Value> val =
+              callback(builder, derivedType, inputs, loc, originalType);
+          if (val.has_value() && *val)
+            result.push_back(*val);
+        }
+      }
+      return result;
     };
   }
   /// With callback of form:
-  /// `Value(OpBuilder &, T, ValueRange, Location)`
+  /// - Value(OpBuilder &, T, ValueRange, Location)
+  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
   template <typename T, typename FnT>
   std::enable_if_t<
       std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
@@ -432,9 +460,9 @@ class TypeConverter {
   wrapTargetMaterialization(FnT &&callback) const {
     return wrapTargetMaterialization<T>(
         [callback = std::forward<FnT>(callback)](
-            OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
-            Type originalType) -> Value {
-          return callback(builder, resultType, inputs, loc);
+            OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
+            Type originalType) {
+          return callback(builder, resultTypes, inputs, loc);
         });
   }
 
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..7b4dd65cbff7b2 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -33,49 +33,6 @@
 
 namespace mlir {
 
-/// Extends `TypeConverter` with 1:N target materializations. Such
-/// materializations have to provide the "reverse" of 1:N type conversions,
-/// i.e., they need to materialize N values with target types into one value
-/// with a source type (which isn't possible in the base class currently).
-class OneToNTypeConverter : public TypeConverter {
-public:
-  /// Callback that expresses user-provided materialization logic from the given
-  /// value to N values of the given types. This is useful for expressing target
-  /// materializations for 1:N type conversions, which materialize one value in
-  /// a source type as N values in target types.
-  using OneToNMaterializationCallbackFn =
-      std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
-                                                      Value, Location)>;
-
-  /// Creates the mapping of the given range of original types to target types
-  /// of the conversion and stores that mapping in the given (signature)
-  /// conversion. This function simply calls
-  /// `TypeConverter::convertSignatureArgs` and exists here with a different
-  /// name to reflect the broader semantic.
-  LogicalResult computeTypeMapping(TypeRange types,
-                                   SignatureConversion &result) const {
-    return convertSignatureArgs(types, result);
-  }
-
-  /// Applies one of the user-provided 1:N target materializations. If several
-  /// exists, they are tried out in the reverse order in which they have been
-  /// added until the first one succeeds. If none succeeds, the functions
-  /// returns `std::nullopt`.
-  std::optional<SmallVector<Value>>
-  materializeTargetConversion(OpBuilder &builder, Location loc,
-                              TypeRange resultTypes, Value input) const;
-
-  /// Adds a 1:N target materialization to the converter. Such materializations
-  /// build IR that converts N values with target types into 1 value of the
-  /// source type.
-  void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) {
-    oneToNTargetMaterializations.emplace_back(std::move(callback));
-  }
-
-private:
-  SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
-};
-
 /// Stores a 1:N mapping of types and provides several useful accessors. This
 /// class extends `SignatureConversion`, which already supports 1:N type
 /// mappings but lacks some accessors into the mapping as well as access to the
@@ -295,7 +252,7 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
 /// not fail if some ops or types remain unconverted (i.e., the conversion is
 /// only "partial").
 LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
                              const FrozenRewritePatternSet &patterns);
 
 /// Add a pattern to the given pattern list to convert the signature of a
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 4968c4fc463d04..e908a536e6fb27 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -921,7 +921,7 @@ struct VectorLegalizationPass
     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
   void runOnOperation() override {
     auto *context = &getContext();
-    OneToNTypeConverter converter;
+    TypeConverter converter;
     RewritePatternSet patterns(context);
     converter.addConversion([](Type type) { return type; });
     converter.addConversion(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3cfcaa965f3546..bf969e74e8bfe0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2831,11 +2831,27 @@ Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
                                                  Location loc, Type resultType,
                                                  ValueRange inputs,
                                                  Type originalType) const {
+  SmallVector<Value> result = materializeTargetConversion(
+      builder, loc, TypeRange(resultType), inputs, originalType);
+  if (result.empty())
+    return nullptr;
+  assert(result.size() == 1 && "requested 1:1 materialization, but callback "
+                               "produced 1:N materialization");
+  return result.front();
+}
+
+SmallVector<Value> TypeConverter::materializeTargetConversion(
+    OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
+    Type originalType) const {
   for (const TargetMaterializationCallbackFn &fn :
-       llvm::reverse(targetMaterializations))
-    if (Value result = fn(builder, resultType, inputs, loc, originalType))
-      return result;
-  return nullptr;
+       llvm::reverse(targetMaterializations)) {
+    SmallVector<Value> result =
+        fn(builder, resultTypes, inputs, loc, originalType);
+    if (result.empty())
+      continue;
+    return result;
+  }
+  return {};
 }
 
 std::optional<TypeConverter::SignatureConversion>
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
index 19e29d48623e04..c208716891ef1f 100644
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -17,20 +17,6 @@
 using namespace llvm;
 using namespace mlir;
 
-std::optional<SmallVector<Value>>
-OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
-                                                 Location loc,
-                                                 TypeRange resultTypes,
-                                                 Value input) const {
-  for (const OneToNMaterializationCallbackFn &fn :
-       llvm::reverse(oneToNTargetMaterializations)) {
-    if (std::optional<SmallVector<Value>> result =
-            fn(builder, resultTypes, input, loc))
-      return *result;
-  }
-  return std::nullopt;
-}
-
 TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
   TypeRange convertedTypes = getConvertedTypes();
   if (auto mapping = getInputMapping(originalTypeNo))
@@ -268,20 +254,20 @@ Block *OneToNPatternRewriter::applySignatureConversion(
 LogicalResult
 OneToNConversionPattern::matchAndRewrite(Operation *op,
                                          PatternRewriter &rewriter) const {
-  auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+  auto *typeConverter = getTypeConverter();
 
   // Construct conversion mapping for results.
   Operation::result_type_range originalResultTypes = op->getResultTypes();
   OneToNTypeMapping resultMapping(originalResultTypes);
-  if (failed(typeConverter->computeTypeMapping(originalResultTypes,
-                                               resultMapping)))
+  if (failed(typeConverter->convertSignatureArgs(originalResultTypes,
+                                                 resultMapping)))
     return failure();
 
   // Construct conversion mapping for operands.
   Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
   OneToNTypeMapping operandMapping(originalOperandTypes);
-  if (failed(typeConverter->computeTypeMapping(originalOperandTypes,
-                                               operandMapping)))
+  if (failed(typeConverter->convertSignatureArgs(originalOperandTypes,
+                                                 operandMapping)))
     return failure();
 
   // Cast operands to target types.
@@ -318,7 +304,7 @@ namespace mlir {
 // inserted by this pass are annotated with a string attribute that also
 // documents which kind of the cast (source, argument, or target).
 LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
                              const FrozenRewritePatternSet &patterns) {
 #ifndef NDEBUG
   // Remember existing unrealized casts. This data structure is only used in
@@ -370,15 +356,13 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
       // Target materialization.
       assert(!areOperandTypesLegal && areResultsTypesLegal &&
              operands.size() == 1 && "found unexpected target cast");
-      std::optional<SmallVector<Value>> maybeResults =
-          typeConverter.materializeTargetConversion(
-              rewriter, castOp->getLoc(), resultTypes, operands.front());
-      if (!maybeResults) {
+      materializedResults = typeConverter.materializeTargetConversion(
+          rewriter, castOp->getLoc(), resultTypes, operands.front());
+      if (materializedResults.empty()) {
         emitError(castOp->getLoc())
             << "failed to create target materialization";
         return failure();
       }
-      materializedResults = maybeResults.value();
     } else {
       // Source and argument materializations.
       assert(areOperandTypesLegal && !areResultsTypesLegal &&
@@ -427,18 +411,18 @@ class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern {
                                 const OneToNTypeMapping &resultMapping,
                                 ValueRange convertedOperands) const override {
     auto funcOp = cast<FunctionOpInterface>(op);
-    auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+    auto *typeConverter = getTypeConverter();
 
     // Construct mapping for function arguments.
     OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes());
-    if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(),
-                                                 argumentMapping)))
+    if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(),
+                                                   argumentMapping)))
       return failure();
 
     // Construct mapping for function results.
     OneToNTypeMapping funcResultMapping(funcOp.getResultTypes());
-    if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(),
-                                                 funcResultMapping)))
+    if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(),
+                                                   funcResultMapping)))
       return failure();
 
     // Nothing to do if the op doesn't have any non-identity conversions for its
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
index 5c03ac12d1e58c..b18dfd8bb22cb1 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -147,9 +147,14 @@ populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter,
 ///
 /// This function has been copied (with small adaptions) from
 /// TestDecomposeCallGraphTypes.cpp.
-static std::optional<SmallVector<Value>>
-buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
-                        Location loc) {
+static SmallVector<Value> buildGetTupleElementOps(OpBuilder &builder,
+                                                  TypeRange resultTypes,
+                                                  ValueRange inputs,
+                                                  Location loc) {
+  if (inputs.size() != 1)
+    return {};
+  Value input = inputs.front();
+
   TupleType inputType = dyn_cast<TupleType>(input.getType());
   if (!inputType)
     return {};
@@ -222,7 +227,7 @@ void TestOneToNTypeConversionPass::runOnOperation() {
   auto *context = &getContext();
 
   // Assemble type converter.
-  OneToNTypeConverter typeConverter;
+  TypeConverter typeConverter;
 
   typeConverter.addConversion([](Type type) { return type; });
   typeConverter.addConversion(
@@ -234,6 +239,11 @@ void TestOneToNTypeConversionPass::runOnOperation() {
   typeConverter.addArgumentMaterialization(buildMakeTupleOp);
   typeConverter.addSourceMaterialization(buildMakeTupleOp);
   typeConverter.addTargetMaterialization(buildGetTupleElementOps);
+  // Test the other target materialization variant that takes the original type
+  // as additional argument. This materialization function always fails.
+  typeConverter.addTargetMaterialization(
+      [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+         Location loc, Type originalType) -> SmallVector<Value> { return {}; });
 
   // Assemble patterns.
   RewritePatternSet patterns(context);

@llvmbot
Copy link
Member

llvmbot commented Oct 19, 2024

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

The 1:N type converter derived from the 1:1 type converter and extends it with 1:N target materializations. This commit merges the two type converters and stores 1:N target materializations in the 1:1 type converter. This is in preparation of merging the 1:1 and 1:N dialect conversion infrastructures.

1:1 target materializations (producing a single Value) will remain valid. An additional API is added to the type converter to register 1:N target materializations (producing a SmallVector&lt;Value&gt;). Internally, all target materializations are stored as 1:N materializations.

The 1:N type converter is removed.

Note for LLVM integration: If you are using the OneToNTypeConverter, simply switch all occurrences to TypeConverter.

Depends on #113031.


Full diff: https://github.com/llvm/llvm-project/pull/113032.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h (+1-1)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+41-13)
  • (modified) mlir/include/mlir/Transforms/OneToNTypeConversion.h (+1-44)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+1-1)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+20-4)
  • (modified) mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp (+14-30)
  • (modified) mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp (+14-4)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 6ccbc40bdd6034..2e9c297f20182a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -150,7 +150,7 @@ std::unique_ptr<Pass> createLowerForeachToSCFPass();
 //===----------------------------------------------------------------------===//
 
 /// Type converter for iter_space and iterator.
-struct SparseIterationTypeConverter : public OneToNTypeConverter {
+struct SparseIterationTypeConverter : public TypeConverter {
   SparseIterationTypeConverter();
 };
 
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5ff36160dd6162..eb7da67c1bb995 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -173,7 +173,9 @@ class TypeConverter {
   /// conversion has finished.
   ///
   /// Note: Target materializations may optionally accept an additional Type
-  /// parameter, which is the original type of the SSA value.
+  /// parameter, which is the original type of the SSA value. Furthermore `T`
+  /// can be a TypeRange; in that case, the function must return a
+  /// SmallVector<Value>.
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
@@ -210,6 +212,9 @@ class TypeConverter {
   /// will be invoked with: outputType = "t3", inputs = "v2",
   // originalType = "t1". Note  that the original type "t1" cannot be recovered
   /// from just "t3" and "v2"; that's why the originalType parameter exists.
+  ///
+  /// Note: During a 1:N conversion, the result types can be a TypeRange. In
+  /// that case the materialization produces a SmallVector<Value>.
   template <typename FnT, typename T = typename llvm::function_traits<
                               std::decay_t<FnT>>::template arg_t<1>>
   void addTargetMaterialization(FnT &&callback) {
@@ -316,6 +321,11 @@ class TypeConverter {
   Value materializeTargetConversion(OpBuilder &builder, Location loc,
                                     Type resultType, ValueRange inputs,
                                     Type originalType = {}) const;
+  SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
+                                                 Location loc,
+                                                 TypeRange resultType,
+                                                 ValueRange inputs,
+                                                 Type originalType = {}) const;
 
   /// Convert an attribute present `attr` from within the type `type` using
   /// the registered conversion functions. If no applicable conversion has been
@@ -341,8 +351,8 @@ class TypeConverter {
   /// The signature of the callback used to materialize a target conversion.
   ///
   /// Arguments: builder, result type, inputs, location, original type
-  using TargetMaterializationCallbackFn =
-      std::function<Value(OpBuilder &, Type, ValueRange, Location, Type)>;
+  using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
+      OpBuilder &, TypeRange, ValueRange, Location, Type)>;
 
   /// The signature of the callback used to convert a type attribute.
   using TypeAttributeConversionCallbackFn =
@@ -409,22 +419,40 @@ class TypeConverter {
   /// callback.
   ///
   /// With callback of form:
-  /// `Value(OpBuilder &, T, ValueRange, Location, Type)`
+  /// - Value(OpBuilder &, T, ValueRange, Location, Type)
+  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
   template <typename T, typename FnT>
   std::enable_if_t<
       std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
       TargetMaterializationCallbackFn>
   wrapTargetMaterialization(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
-               OpBuilder &builder, Type resultType, ValueRange inputs,
-               Location loc, Type originalType) -> Value {
-      if (T derivedType = dyn_cast<T>(resultType))
-        return callback(builder, derivedType, inputs, loc, originalType);
-      return Value();
+               OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+               Location loc, Type originalType) -> SmallVector<Value> {
+      SmallVector<Value> result;
+      if constexpr (std::is_same<T, TypeRange>::value) {
+        // This is a 1:N target materialization. Return the produces values
+        // directly.
+        result = callback(builder, resultTypes, inputs, loc, originalType);
+      } else {
+        // This is a 1:1 target materialization. Invoke it only if the result
+        // type class of the callback matches the requested result type.
+        if (T derivedType = dyn_cast<T>(resultTypes.front())) {
+          // 1:1 materializations produce single values, but we store 1:N
+          // target materialization functions in the type converter. Wrap the
+          // result value in a SmallVector<Value>.
+          std::optional<Value> val =
+              callback(builder, derivedType, inputs, loc, originalType);
+          if (val.has_value() && *val)
+            result.push_back(*val);
+        }
+      }
+      return result;
     };
   }
   /// With callback of form:
-  /// `Value(OpBuilder &, T, ValueRange, Location)`
+  /// - Value(OpBuilder &, T, ValueRange, Location)
+  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
   template <typename T, typename FnT>
   std::enable_if_t<
       std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
@@ -432,9 +460,9 @@ class TypeConverter {
   wrapTargetMaterialization(FnT &&callback) const {
     return wrapTargetMaterialization<T>(
         [callback = std::forward<FnT>(callback)](
-            OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
-            Type originalType) -> Value {
-          return callback(builder, resultType, inputs, loc);
+            OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
+            Type originalType) {
+          return callback(builder, resultTypes, inputs, loc);
         });
   }
 
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..7b4dd65cbff7b2 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -33,49 +33,6 @@
 
 namespace mlir {
 
-/// Extends `TypeConverter` with 1:N target materializations. Such
-/// materializations have to provide the "reverse" of 1:N type conversions,
-/// i.e., they need to materialize N values with target types into one value
-/// with a source type (which isn't possible in the base class currently).
-class OneToNTypeConverter : public TypeConverter {
-public:
-  /// Callback that expresses user-provided materialization logic from the given
-  /// value to N values of the given types. This is useful for expressing target
-  /// materializations for 1:N type conversions, which materialize one value in
-  /// a source type as N values in target types.
-  using OneToNMaterializationCallbackFn =
-      std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
-                                                      Value, Location)>;
-
-  /// Creates the mapping of the given range of original types to target types
-  /// of the conversion and stores that mapping in the given (signature)
-  /// conversion. This function simply calls
-  /// `TypeConverter::convertSignatureArgs` and exists here with a different
-  /// name to reflect the broader semantic.
-  LogicalResult computeTypeMapping(TypeRange types,
-                                   SignatureConversion &result) const {
-    return convertSignatureArgs(types, result);
-  }
-
-  /// Applies one of the user-provided 1:N target materializations. If several
-  /// exists, they are tried out in the reverse order in which they have been
-  /// added until the first one succeeds. If none succeeds, the functions
-  /// returns `std::nullopt`.
-  std::optional<SmallVector<Value>>
-  materializeTargetConversion(OpBuilder &builder, Location loc,
-                              TypeRange resultTypes, Value input) const;
-
-  /// Adds a 1:N target materialization to the converter. Such materializations
-  /// build IR that converts N values with target types into 1 value of the
-  /// source type.
-  void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) {
-    oneToNTargetMaterializations.emplace_back(std::move(callback));
-  }
-
-private:
-  SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
-};
-
 /// Stores a 1:N mapping of types and provides several useful accessors. This
 /// class extends `SignatureConversion`, which already supports 1:N type
 /// mappings but lacks some accessors into the mapping as well as access to the
@@ -295,7 +252,7 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
 /// not fail if some ops or types remain unconverted (i.e., the conversion is
 /// only "partial").
 LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
                              const FrozenRewritePatternSet &patterns);
 
 /// Add a pattern to the given pattern list to convert the signature of a
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 4968c4fc463d04..e908a536e6fb27 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -921,7 +921,7 @@ struct VectorLegalizationPass
     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
   void runOnOperation() override {
     auto *context = &getContext();
-    OneToNTypeConverter converter;
+    TypeConverter converter;
     RewritePatternSet patterns(context);
     converter.addConversion([](Type type) { return type; });
     converter.addConversion(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3cfcaa965f3546..bf969e74e8bfe0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2831,11 +2831,27 @@ Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
                                                  Location loc, Type resultType,
                                                  ValueRange inputs,
                                                  Type originalType) const {
+  SmallVector<Value> result = materializeTargetConversion(
+      builder, loc, TypeRange(resultType), inputs, originalType);
+  if (result.empty())
+    return nullptr;
+  assert(result.size() == 1 && "requested 1:1 materialization, but callback "
+                               "produced 1:N materialization");
+  return result.front();
+}
+
+SmallVector<Value> TypeConverter::materializeTargetConversion(
+    OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
+    Type originalType) const {
   for (const TargetMaterializationCallbackFn &fn :
-       llvm::reverse(targetMaterializations))
-    if (Value result = fn(builder, resultType, inputs, loc, originalType))
-      return result;
-  return nullptr;
+       llvm::reverse(targetMaterializations)) {
+    SmallVector<Value> result =
+        fn(builder, resultTypes, inputs, loc, originalType);
+    if (result.empty())
+      continue;
+    return result;
+  }
+  return {};
 }
 
 std::optional<TypeConverter::SignatureConversion>
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
index 19e29d48623e04..c208716891ef1f 100644
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -17,20 +17,6 @@
 using namespace llvm;
 using namespace mlir;
 
-std::optional<SmallVector<Value>>
-OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
-                                                 Location loc,
-                                                 TypeRange resultTypes,
-                                                 Value input) const {
-  for (const OneToNMaterializationCallbackFn &fn :
-       llvm::reverse(oneToNTargetMaterializations)) {
-    if (std::optional<SmallVector<Value>> result =
-            fn(builder, resultTypes, input, loc))
-      return *result;
-  }
-  return std::nullopt;
-}
-
 TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
   TypeRange convertedTypes = getConvertedTypes();
   if (auto mapping = getInputMapping(originalTypeNo))
@@ -268,20 +254,20 @@ Block *OneToNPatternRewriter::applySignatureConversion(
 LogicalResult
 OneToNConversionPattern::matchAndRewrite(Operation *op,
                                          PatternRewriter &rewriter) const {
-  auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+  auto *typeConverter = getTypeConverter();
 
   // Construct conversion mapping for results.
   Operation::result_type_range originalResultTypes = op->getResultTypes();
   OneToNTypeMapping resultMapping(originalResultTypes);
-  if (failed(typeConverter->computeTypeMapping(originalResultTypes,
-                                               resultMapping)))
+  if (failed(typeConverter->convertSignatureArgs(originalResultTypes,
+                                                 resultMapping)))
     return failure();
 
   // Construct conversion mapping for operands.
   Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
   OneToNTypeMapping operandMapping(originalOperandTypes);
-  if (failed(typeConverter->computeTypeMapping(originalOperandTypes,
-                                               operandMapping)))
+  if (failed(typeConverter->convertSignatureArgs(originalOperandTypes,
+                                                 operandMapping)))
     return failure();
 
   // Cast operands to target types.
@@ -318,7 +304,7 @@ namespace mlir {
 // inserted by this pass are annotated with a string attribute that also
 // documents which kind of the cast (source, argument, or target).
 LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
                              const FrozenRewritePatternSet &patterns) {
 #ifndef NDEBUG
   // Remember existing unrealized casts. This data structure is only used in
@@ -370,15 +356,13 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
       // Target materialization.
       assert(!areOperandTypesLegal && areResultsTypesLegal &&
              operands.size() == 1 && "found unexpected target cast");
-      std::optional<SmallVector<Value>> maybeResults =
-          typeConverter.materializeTargetConversion(
-              rewriter, castOp->getLoc(), resultTypes, operands.front());
-      if (!maybeResults) {
+      materializedResults = typeConverter.materializeTargetConversion(
+          rewriter, castOp->getLoc(), resultTypes, operands.front());
+      if (materializedResults.empty()) {
         emitError(castOp->getLoc())
             << "failed to create target materialization";
         return failure();
       }
-      materializedResults = maybeResults.value();
     } else {
       // Source and argument materializations.
       assert(areOperandTypesLegal && !areResultsTypesLegal &&
@@ -427,18 +411,18 @@ class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern {
                                 const OneToNTypeMapping &resultMapping,
                                 ValueRange convertedOperands) const override {
     auto funcOp = cast<FunctionOpInterface>(op);
-    auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+    auto *typeConverter = getTypeConverter();
 
     // Construct mapping for function arguments.
     OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes());
-    if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(),
-                                                 argumentMapping)))
+    if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(),
+                                                   argumentMapping)))
       return failure();
 
     // Construct mapping for function results.
     OneToNTypeMapping funcResultMapping(funcOp.getResultTypes());
-    if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(),
-                                                 funcResultMapping)))
+    if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(),
+                                                   funcResultMapping)))
       return failure();
 
     // Nothing to do if the op doesn't have any non-identity conversions for its
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
index 5c03ac12d1e58c..b18dfd8bb22cb1 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -147,9 +147,14 @@ populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter,
 ///
 /// This function has been copied (with small adaptions) from
 /// TestDecomposeCallGraphTypes.cpp.
-static std::optional<SmallVector<Value>>
-buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
-                        Location loc) {
+static SmallVector<Value> buildGetTupleElementOps(OpBuilder &builder,
+                                                  TypeRange resultTypes,
+                                                  ValueRange inputs,
+                                                  Location loc) {
+  if (inputs.size() != 1)
+    return {};
+  Value input = inputs.front();
+
   TupleType inputType = dyn_cast<TupleType>(input.getType());
   if (!inputType)
     return {};
@@ -222,7 +227,7 @@ void TestOneToNTypeConversionPass::runOnOperation() {
   auto *context = &getContext();
 
   // Assemble type converter.
-  OneToNTypeConverter typeConverter;
+  TypeConverter typeConverter;
 
   typeConverter.addConversion([](Type type) { return type; });
   typeConverter.addConversion(
@@ -234,6 +239,11 @@ void TestOneToNTypeConversionPass::runOnOperation() {
   typeConverter.addArgumentMaterialization(buildMakeTupleOp);
   typeConverter.addSourceMaterialization(buildMakeTupleOp);
   typeConverter.addTargetMaterialization(buildGetTupleElementOps);
+  // Test the other target materialization variant that takes the original type
+  // as additional argument. This materialization function always fails.
+  typeConverter.addTargetMaterialization(
+      [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+         Location loc, Type originalType) -> SmallVector<Value> { return {}; });
 
   // Assemble patterns.
   RewritePatternSet patterns(context);

@matthias-springer matthias-springer force-pushed the users/matthias-springer/merge_type_converters branch from 8fc8147 to 7ac8aa2 Compare October 19, 2024 10:47
@matthias-springer matthias-springer force-pushed the users/matthias-springer/mat_remove_optional branch from 2fe3e2d to 59a9bbb Compare October 23, 2024 14:16
Base automatically changed from users/matthias-springer/mat_remove_optional to main October 23, 2024 14:29
@matthias-springer matthias-springer force-pushed the users/matthias-springer/merge_type_converters branch 2 times, most recently from 216d522 to 44db7a9 Compare October 23, 2024 16:55
@matthias-springer matthias-springer force-pushed the users/matthias-springer/merge_type_converters branch from eaff670 to 43edbab Compare October 25, 2024 17:52
Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@matthias-springer matthias-springer merged commit 8c4bc1e into main Oct 25, 2024
6 of 7 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/merge_type_converters branch October 25, 2024 18:44
@llvm-ci
Copy link
Collaborator

llvm-ci commented Oct 25, 2024

LLVM Buildbot has detected a new failure on builder mlir-nvidia-gcc7 running on mlir-nvidia while building mlir at step 5 "build-check-mlir-build-only".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/5467

Here is the relevant piece of the build log for the reference
Step 5 (build-check-mlir-build-only) failure: build (failure)
...
364.284 [638/16/3937] Linking CXX executable tools/mlir/unittests/Dialect/SparseTensor/MLIRSparseTensorTests
364.293 [637/16/3938] Building CXX object tools/mlir/unittests/Interfaces/CMakeFiles/MLIRInterfacesTests.dir/InferIntRangeInterfaceTest.cpp.o
364.314 [636/16/3939] Building CXX object tools/mlir/unittests/Interfaces/CMakeFiles/MLIRInterfacesTests.dir/InferTypeOpInterfaceTest.cpp.o
364.321 [635/16/3940] Building CXX object tools/mlir/unittests/Pass/CMakeFiles/MLIRPassTests.dir/PassPipelineParserTest.cpp.o
364.350 [634/16/3941] Building CXX object tools/mlir/unittests/Pass/CMakeFiles/MLIRPassTests.dir/AnalysisManagerTest.cpp.o
364.378 [633/16/3942] Building CXX object tools/mlir/unittests/Pass/CMakeFiles/MLIRPassTests.dir/PassManagerTest.cpp.o
364.588 [632/16/3943] Linking CXX executable tools/mlir/unittests/Interfaces/MLIRInterfacesTests
364.589 [631/16/3944] Linking CXX executable tools/mlir/unittests/Pass/MLIRPassTests
364.618 [630/16/3945] Building CXX object tools/mlir/unittests/Rewrite/CMakeFiles/MLIRRewriteTests.dir/PatternBenefit.cpp.o
365.215 [629/16/3946] Building CXX object tools/mlir/lib/Transforms/Utils/CMakeFiles/obj.MLIRTransformUtils.dir/DialectConversion.cpp.o
FAILED: tools/mlir/lib/Transforms/Utils/CMakeFiles/obj.MLIRTransformUtils.dir/DialectConversion.cpp.o 
CCACHE_CPP2=yes CCACHE_HASHDIR=yes /usr/bin/ccache /usr/bin/g++-7 -DGTEST_HAS_RTTI=0 -D_DEBUG -D_GLIBCXX_ASSERTIONS -D_GNU_SOURCE -D__STDC_CONSTANT_MACROS -D__STDC_FORMAT_MACROS -D__STDC_LIMIT_MACROS -I/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/tools/mlir/lib/Transforms/Utils -I/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/lib/Transforms/Utils -I/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/include -I/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/llvm/include -I/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include -I/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/tools/mlir/include -fPIC -fno-semantic-interposition -fvisibility-inlines-hidden -Werror=date-time -fno-lifetime-dse -Wall -Wextra -Wno-unused-parameter -Wwrite-strings -Wcast-qual -Wno-missing-field-initializers -pedantic -Wno-long-long -Wimplicit-fallthrough -Wno-uninitialized -Wno-nonnull -Wno-noexcept-type -Wdelete-non-virtual-dtor -Wno-comment -Wno-misleading-indentation -fdiagnostics-color -ffunction-sections -fdata-sections -Wundef -Wno-unused-but-set-parameter -Wno-deprecated-copy -O3 -DNDEBUG  -fno-exceptions -funwind-tables -fno-rtti -UNDEBUG -std=c++1z -MD -MT tools/mlir/lib/Transforms/Utils/CMakeFiles/obj.MLIRTransformUtils.dir/DialectConversion.cpp.o -MF tools/mlir/lib/Transforms/Utils/CMakeFiles/obj.MLIRTransformUtils.dir/DialectConversion.cpp.o.d -o tools/mlir/lib/Transforms/Utils/CMakeFiles/obj.MLIRTransformUtils.dir/DialectConversion.cpp.o -c /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/lib/Transforms/Utils/DialectConversion.cpp
In file included from /usr/include/c++/7/cassert:44:0,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/llvm/include/llvm/Support/Error.h:26,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/llvm/include/llvm/Support/JSON.h:54,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/llvm/include/llvm/Support/ScopedPrinter.h:19,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/lib/Transforms/Utils/DialectConversion.cpp:24:
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/lib/Transforms/Utils/DialectConversion.cpp: In member function ‘llvm::SmallVector<mlir::Value> mlir::TypeConverter::materializeTargetConversion(mlir::OpBuilder&, mlir::Location, mlir::TypeRange, mlir::ValueRange, mlir::Type) const’:
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/lib/Transforms/Utils/DialectConversion.cpp:2851:28: error: call of overloaded ‘TypeRange(llvm::SmallVector<mlir::Value>&)’ is ambiguous
     assert(TypeRange(result) == resultTypes &&
                            ^
In file included from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/Support/TypeID.h:20:0,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/MLIRContext.h:13,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/DialectRegistry.h:16,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/Dialect.h:16,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/OpDefinition.h:22,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/Builders.h:12,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/PatternMatch.h:12,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h:12,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/Transforms/DialectConversion.h:17,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/lib/Transforms/Utils/DialectConversion.cpp:9:
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/llvm/include/llvm/ADT/STLExtras.h:1278:3: note: candidate: llvm::detail::indexed_accessor_range_base<DerivedT, BaseT, T, PointerT, ReferenceT>::indexed_accessor_range_base(const llvm::iterator_range<llvm::detail::indexed_accessor_range_base<DerivedT, BaseT, T, PointerT, ReferenceT>::iterator>&) [with DerivedT = mlir::TypeRange; BaseT = llvm::PointerUnion<const mlir::Value*, const mlir::Type*, mlir::OpOperand*, mlir::detail::OpResultImpl*>; T = mlir::Type; PointerT = mlir::Type; ReferenceT = mlir::Type]
   indexed_accessor_range_base(const iterator_range<iterator> &range)
   ^~~~~~~~~~~~~~~~~~~~~~~~~~~
In file included from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/OperationSupport.h:23:0,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/Dialect.h:17,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/OpDefinition.h:22,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/Builders.h:12,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/PatternMatch.h:12,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h:12,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/Transforms/DialectConversion.h:17,
                 from /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/lib/Transforms/Utils/DialectConversion.cpp:9:
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/TypeRange.h:38:21: note:   inherited here
   using RangeBaseT::RangeBaseT;
                     ^~~~~~~~~~
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/include/mlir/IR/TypeRange.h:42:12: note: candidate: mlir::TypeRange::TypeRange(mlir::ValueRange)
   explicit TypeRange(ValueRange values);
            ^~~~~~~~~
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/lib/Transforms/Utils/DialectConversion.cpp: At global scope:

Copy link
Contributor

@ingomueller-net ingomueller-net left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Nice progress!

matthias-springer added a commit that referenced this pull request Oct 30, 2024
The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround around missing 1:N support in the dialect conversion. Since #113032,tThe dialect conversion infrastructure supports 1:N type conversions and 1:N target materializations. The `ValueDecomposer` class is no longer needed. (However, target materializations must still be inserted manually, until we fully merge the 1:1 and 1:N drivers.)
matthias-springer added a commit that referenced this pull request Oct 30, 2024
The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround around missing 1:N support in the dialect conversion. Since #113032,tThe dialect conversion infrastructure supports 1:N type conversions and 1:N target materializations. The `ValueDecomposer` class is no longer needed. (However, target materializations must still be inserted manually, until we fully merge the 1:1 and 1:N drivers.)
matthias-springer added a commit that referenced this pull request Oct 30, 2024
…114192)

The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround
around missing 1:N support in the dialect conversion. Since #113032, the
dialect conversion infrastructure supports 1:N type conversions and 1:N
target materializations. The `ValueDecomposer` class is no longer
needed. (However, target materializations must still be inserted
manually, until we fully merge the 1:1 and 1:N drivers.)

Note for LLVM integration: Register 1:N target materializations on the
type converter instead of "decompose value conversions" on the
`ValueDecomposer`.
smallp-o-p pushed a commit to smallp-o-p/llvm-project that referenced this pull request Nov 3, 2024
…lvm#114192)

The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround
around missing 1:N support in the dialect conversion. Since llvm#113032, the
dialect conversion infrastructure supports 1:N type conversions and 1:N
target materializations. The `ValueDecomposer` class is no longer
needed. (However, target materializations must still be inserted
manually, until we fully merge the 1:1 and 1:N drivers.)

Note for LLVM integration: Register 1:N target materializations on the
type converter instead of "decompose value conversions" on the
`ValueDecomposer`.
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
The 1:N type converter derived from the 1:1 type converter and extends
it with 1:N target materializations. This commit merges the two type
converters and stores 1:N target materializations in the 1:1 type
converter. This is in preparation of merging the 1:1 and 1:N dialect
conversion infrastructures.

1:1 target materializations (producing a single `Value`) will remain
valid. An additional API is added to the type converter to register 1:N
target materializations (producing a `SmallVector<Value>`). Internally,
all target materializations are stored as 1:N materializations.

The 1:N type converter is removed.

Note for LLVM integration: If you are using the `OneToNTypeConverter`,
simply switch all occurrences to `TypeConverter`.

---------

Co-authored-by: Markus Böck <[email protected]>
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
…lvm#114192)

The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround
around missing 1:N support in the dialect conversion. Since llvm#113032, the
dialect conversion infrastructure supports 1:N type conversions and 1:N
target materializations. The `ValueDecomposer` class is no longer
needed. (However, target materializations must still be inserted
manually, until we fully merge the 1:1 and 1:N drivers.)

Note for LLVM integration: Register 1:N target materializations on the
type converter instead of "decompose value conversions" on the
`ValueDecomposer`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:sme mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants