Skip to content

[mlir][Transforms] Add 1:N matchAndRewrite overload #116470

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 30, 2024

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Nov 16, 2024

This commit adds a new matchAndRewrite overload to ConversionPattern to support 1:N replacements. This is the first of two main PRs that merge the 1:1 and 1:N dialect conversion drivers.

The existing matchAndRewrite function supports only 1:1 replacements, as can be seen from the ArrayRef<Value> parameter.

LogicalResult ConversionPattern::matchAndRewrite(
    Operation *op, ArrayRef<Value> operands /*adaptor values*/,
    ConversionPatternRewriter &rewriter) const;

This commit adds a matchAndRewrite overload that is called by the dialect conversion driver. By default, this new overload dispatches to the original 1:1 matchAndRewrite implementation. Existing ConversionPatterns do not need to be changed as long as there are no 1:N type conversions or value replacements.

LogicalResult ConversionPattern::matchAndRewrite(
    Operation *op, ArrayRef<ValueRange> operands /*adaptor values*/,
    ConversionPatternRewriter &rewriter) const {
  // Note: getOneToOneAdaptorOperands produces a fatal error if at least one
  // ValueRange has 0 or more than 1 value.
  return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
}

The ConversionValueMapping, which keeps track of value replacements and materializations, still does not support 1:N replacements. We still rely on argument materializations to convert N replacement values back into a single value. The ConversionValueMapping will be generalized to 1:N mappings in the second main PR.

Before handing the adaptor values to a ConversionPattern, all argument materializations are "unpacked". The ConversionPattern receives N replacement values and does not see any argument materializations. This implementation strategy allows us to use the 1:N infrastructure/API in ConversionPatterns even though some functionality is still missing in the driver. This strategy was chosen to keep the sizes of the PRs smaller and to make it easier for downstream users to adapt to API changes.

This commit also updates the the "decompose call graphs" transformation and the "sparse tensor codegen" transformation to use the new 1:N ConversionPattern API.

Note for LLVM conversion: If you are using a type converter with 1:N type conversion rules or if your patterns are performing 1:N replacements (via replaceOpWithMultiple or applySignatureConversion), conversion pattern applications will start failing (fatal LLVM error) with this error message: pattern 'name' does not support 1:N conversion. The name of the failing pattern is shown in the error message. These patterns must be updated to the new 1:N matchAndRewrite API.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/1n_pattern branch 4 times, most recently from fe38d4b to 153310a Compare November 16, 2024 08:44
@matthias-springer matthias-springer marked this pull request as ready for review November 16, 2024 08:52
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:llvm mlir:sparse Sparse compiler in MLIR mlir mlir:scf mlir:func labels Nov 16, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 16, 2024

@llvm/pr-subscribers-mlir-scf
@llvm/pr-subscribers-mlir-func
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

This commit adds a new matchAndRewrite overload to ConversionPattern to support 1:N replacements. This is the first of two main PRs that merge the 1:1 and 1:N dialect conversion drivers.

The existing matchAndRewrite function supports only 1:1 replacements, as can be seen from the ArrayRef&lt;Value&gt; parameter.

LogicalResult ConversionPattern::matchAndRewrite(
    Operation *op, ArrayRef&lt;Value&gt; operands /*adaptor values*/,
    ConversionPatternRewriter &amp;rewriter) const;

This commit adds a matchAndRewrite overload that is called by the dialect conversion driver. By default, this new overload dispatches to the original 1:1 matchAndRewrite implementation.

LogicalResult ConversionPattern::matchAndRewrite(
    Operation *op, ArrayRef&lt;ValueRange&gt; operands /*adaptor values*/,
    ConversionPatternRewriter &amp;rewriter) const {
  // Note: getOneToOneAdaptorOperands produces a fatal error if at least one
  // ValueRange has 0 or more than 1 value.
  return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
}

The ConversionValueMapping, which keeps track of value replacements and materializations, still does not support 1:N replacements. We still rely on argument materializations to convert N replacement values back into a single value. The ConversionValueMapping will be generalized to 1:N mappings in the second main PR.

Before handing the adaptor values to a ConversionPattern, all argument materializations are "unpacked". The ConversionPattern receives N replacement values and does not see any argument materializations. This implementation strategy allows us to use the 1:N infrastructure/API in ConversionPatterns even though some functionality is still missing in the driver. This strategy was chosen to keep the sizes of the PRs smaller and to make it easier for downstream users to adapt to API changes.

This commit also updates the the "decompose call graphs" transformation and the "sparse tensor codegen" transformation to use the new 1:N ConversionPattern API.

Note for LLVM conversion: If you are using a type converter with 1:N type conversion rules or if your patterns are performing 1:N replacements (via replaceOpWithMultiple or applySignatureConversion), conversion pattern applications will start failing (fatal LLVM error) with this error message: pattern 'name' does not support 1:N conversion. The name of the failing pattern is shown in the error message. These patterns must be updated to the new 1:N matchAndRewrite API.


Patch is 70.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116470.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/Pattern.h (+31-4)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+63)
  • (modified) mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp (+6-50)
  • (modified) mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp (+3-2)
  • (modified) mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp (+40-66)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+54-60)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h (+5-11)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+173-78)
  • (modified) mlir/test/Transforms/decompose-call-graph-types.mlir (+6-32)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index f3bf5b66398e09..86ea87b55af1cd 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -143,6 +143,8 @@ template <typename SourceOp>
 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
 public:
   using OpAdaptor = typename SourceOp::Adaptor;
+  using OneToNOpAdaptor =
+      typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
 
   explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
                                   PatternBenefit benefit = 1)
@@ -153,8 +155,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
   /// Wrappers around the RewritePattern methods that pass the derived op type.
   void rewrite(Operation *op, ArrayRef<Value> operands,
                ConversionPatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
-            rewriter);
+    auto sourceOp = cast<SourceOp>(op);
+    rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
+  }
+  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
+               ConversionPatternRewriter &rewriter) const final {
+    auto sourceOp = cast<SourceOp>(op);
+    rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
   }
   LogicalResult match(Operation *op) const final {
     return match(cast<SourceOp>(op));
@@ -162,8 +169,15 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    return matchAndRewrite(cast<SourceOp>(op),
-                           OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
+    auto sourceOp = cast<SourceOp>(op);
+    return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
+  }
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto sourceOp = cast<SourceOp>(op);
+    return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
+                           rewriter);
   }
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
@@ -175,6 +189,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
                        ConversionPatternRewriter &rewriter) const {
     llvm_unreachable("must override rewrite or matchAndRewrite");
   }
+  virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const {
+    SmallVector<Value> oneToOneOperands =
+        getOneToOneAdaptorOperands(adaptor.getOperands());
+    rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
@@ -183,6 +203,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
     rewrite(op, adaptor, rewriter);
     return success();
   }
+  virtual LogicalResult
+  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const {
+    SmallVector<Value> oneToOneOperands =
+        getOneToOneAdaptorOperands(adaptor.getOperands());
+    return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+  }
 
 private:
   using ConvertToLLVMPattern::match;
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index de47765006f81e..e4eeb39b9c0741 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -537,6 +537,10 @@ class ConversionPattern : public RewritePattern {
                        ConversionPatternRewriter &rewriter) const {
     llvm_unreachable("unimplemented rewrite");
   }
+  virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
+                       ConversionPatternRewriter &rewriter) const {
+    rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+  }
 
   /// Hook for derived classes to implement combined matching and rewriting.
   virtual LogicalResult
@@ -547,6 +551,11 @@ class ConversionPattern : public RewritePattern {
     rewrite(op, operands, rewriter);
     return success();
   }
+  virtual LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const {
+    return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+  }
 
   /// Attempt to match and rewrite the IR root at the specified operation.
   LogicalResult matchAndRewrite(Operation *op,
@@ -574,6 +583,15 @@ class ConversionPattern : public RewritePattern {
       : RewritePattern(std::forward<Args>(args)...),
         typeConverter(&typeConverter) {}
 
+  /// Given an array of value ranges, which are the inputs to a 1:N adaptor,
+  /// try to extract the single value of each range to construct a the inputs
+  /// for a 1:1 adaptor.
+  ///
+  /// This function produces a fatal error if at least one range has 0 or
+  /// more than 1 value: "pattern 'name' does not support 1:N conversion"
+  SmallVector<Value>
+  getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;
+
 protected:
   /// An optional type converter for use by this pattern.
   const TypeConverter *typeConverter = nullptr;
@@ -589,6 +607,8 @@ template <typename SourceOp>
 class OpConversionPattern : public ConversionPattern {
 public:
   using OpAdaptor = typename SourceOp::Adaptor;
+  using OneToNOpAdaptor =
+      typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
 
   OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
       : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -607,12 +627,24 @@ class OpConversionPattern : public ConversionPattern {
     auto sourceOp = cast<SourceOp>(op);
     rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
   }
+  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
+               ConversionPatternRewriter &rewriter) const final {
+    auto sourceOp = cast<SourceOp>(op);
+    rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
+  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     auto sourceOp = cast<SourceOp>(op);
     return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
   }
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto sourceOp = cast<SourceOp>(op);
+    return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
+                           rewriter);
+  }
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
   /// overridden by the derived pattern class.
@@ -623,6 +655,12 @@ class OpConversionPattern : public ConversionPattern {
                        ConversionPatternRewriter &rewriter) const {
     llvm_unreachable("must override matchAndRewrite or a rewrite method");
   }
+  virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const {
+    SmallVector<Value> oneToOneOperands =
+        getOneToOneAdaptorOperands(adaptor.getOperands());
+    rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
@@ -631,6 +669,13 @@ class OpConversionPattern : public ConversionPattern {
     rewrite(op, adaptor, rewriter);
     return success();
   }
+  virtual LogicalResult
+  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const {
+    SmallVector<Value> oneToOneOperands =
+        getOneToOneAdaptorOperands(adaptor.getOperands());
+    return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+  }
 
 private:
   using ConversionPattern::matchAndRewrite;
@@ -656,11 +701,20 @@ class OpInterfaceConversionPattern : public ConversionPattern {
                ConversionPatternRewriter &rewriter) const final {
     rewrite(cast<SourceOp>(op), operands, rewriter);
   }
+  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
+               ConversionPatternRewriter &rewriter) const final {
+    rewrite(cast<SourceOp>(op), operands, rewriter);
+  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
   }
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
+  }
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
   /// overridden by the derived pattern class.
@@ -668,6 +722,10 @@ class OpInterfaceConversionPattern : public ConversionPattern {
                        ConversionPatternRewriter &rewriter) const {
     llvm_unreachable("must override matchAndRewrite or a rewrite method");
   }
+  virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
+                       ConversionPatternRewriter &rewriter) const {
+    rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const {
@@ -676,6 +734,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
     rewrite(op, operands, rewriter);
     return success();
   }
+  virtual LogicalResult
+  matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const {
+    return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+  }
 
 private:
   using ConversionPattern::matchAndRewrite;
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index a08764326a80b6..03be00328bda33 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
@@ -13,40 +13,6 @@
 using namespace mlir;
 using namespace mlir::func;
 
-//===----------------------------------------------------------------------===//
-// Helper functions
-//===----------------------------------------------------------------------===//
-
-/// If the given value can be decomposed with the type converter, decompose it.
-/// Otherwise, return the given value.
-// TODO: Value decomposition should happen automatically through a 1:N adaptor.
-// This function will disappear when the 1:1 and 1:N drivers are merged.
-static SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc,
-                                         Value value,
-                                         const TypeConverter *converter) {
-  // Try to convert the given value's type. If that fails, just return the
-  // given value.
-  SmallVector<Type> convertedTypes;
-  if (failed(converter->convertType(value.getType(), convertedTypes)))
-    return {value};
-  if (convertedTypes.empty())
-    return {};
-
-  // If the given value's type is already legal, just return the given value.
-  TypeRange convertedTypeRange(convertedTypes);
-  if (convertedTypeRange == TypeRange(value.getType()))
-    return {value};
-
-  // Try to materialize a target conversion. If the materialization did not
-  // produce values of the requested type, the materialization failed. Just
-  // return the given value in that case.
-  SmallVector<Value> result = converter->materializeTargetConversion(
-      builder, loc, convertedTypeRange, value);
-  if (result.empty())
-    return {value};
-  return result;
-}
-
 //===----------------------------------------------------------------------===//
 // DecomposeCallGraphTypesForFuncArgs
 //===----------------------------------------------------------------------===//
@@ -102,16 +68,11 @@ struct DecomposeCallGraphTypesForReturnOp
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
+  matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     SmallVector<Value, 2> newOperands;
-    for (Value operand : adaptor.getOperands()) {
-      // TODO: We can directly take the values from the adaptor once this is a
-      // 1:N conversion pattern.
-      llvm::append_range(newOperands,
-                         decomposeValue(rewriter, operand.getLoc(), operand,
-                                        getTypeConverter()));
-    }
+    for (ValueRange operand : adaptor.getOperands())
+      llvm::append_range(newOperands, operand);
     rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
     return success();
   }
@@ -128,18 +89,13 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(CallOp op, OpAdaptor adaptor,
+  matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
 
     // Create the operands list of the new `CallOp`.
     SmallVector<Value, 2> newOperands;
-    for (Value operand : adaptor.getOperands()) {
-      // TODO: We can directly take the values from the adaptor once this is a
-      // 1:N conversion pattern.
-      llvm::append_range(newOperands,
-                         decomposeValue(rewriter, operand.getLoc(), operand,
-                                        getTypeConverter()));
-    }
+    for (ValueRange operand : adaptor.getOperands())
+      llvm::append_range(newOperands, operand);
 
     // Create the new result types for the new `CallOp` and track the number of
     // replacement types for each original op result.
diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
index eb444d665ff260..d81f822f7d4b51 100644
--- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
@@ -21,7 +21,7 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
 
   /// Hook for derived classes to implement combined matching and rewriting.
   LogicalResult
-  matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
+  matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Convert the original function results.
     SmallVector<Type, 1> convertedResults;
@@ -37,7 +37,8 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
     // Substitute with the new result types from the corresponding FuncType
     // conversion.
     rewriter.replaceOpWithNewOp<CallOp>(
-        callOp, callOp.getCallee(), convertedResults, adaptor.getOperands());
+        callOp, callOp.getCallee(), convertedResults,
+        getOneToOneAdaptorOperands(adaptor.getOperands()));
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 93a78056db1944..c0589044c26ecb 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -16,20 +16,18 @@ using namespace mlir::scf;
 
 namespace {
 
-// Unpacks the single unrealized_conversion_cast using the list of inputs
-// e.g., return [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d)
-static void unpackUnrealizedConversionCast(Value v,
-                                           SmallVectorImpl<Value> &unpacked) {
-  if (auto cast =
-          dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp())) {
-    if (cast.getInputs().size() != 1) {
-      // 1 : N type conversion.
-      unpacked.append(cast.getInputs().begin(), cast.getInputs().end());
-      return;
-    }
-  }
-  // 1 : 1 type conversion.
-  unpacked.push_back(v);
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+  SmallVector<Value> result;
+  for (const auto &vals : values)
+    llvm::append_range(result, vals);
+  return result;
+}
+
+/// Assert that the given value range contains a single value and return it.
+static Value getSingleValue(ValueRange values) {
+  assert(values.size() == 1 && "expected single value");
+  return values.front();
 }
 
 // CRTP
@@ -40,19 +38,21 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
 public:
   using OpConversionPattern<SourceOp>::typeConverter;
   using OpConversionPattern<SourceOp>::OpConversionPattern;
-  using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor;
+  using OneToNOpAdaptor =
+      typename OpConversionPattern<SourceOp>::OneToNOpAdaptor;
 
   //
   // Derived classes should provide the following method which performs the
   // actual conversion. It should return std::nullopt upon conversion failure
   // and return the converted operation upon success.
   //
-  // std::optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor,
-  //                                    ConversionPatternRewriter &rewriter,
-  //                                    TypeRange dstTypes) const;
+  // std::optional<SourceOp> convertSourceOp(
+  //     SourceOp op, OneToNOpAdaptor adaptor,
+  //     ConversionPatternRewriter &rewriter,
+  //     TypeRange dstTypes) const;
 
   LogicalResult
-  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     SmallVector<Type> dstTypes;
     SmallVector<unsigned> offsets;
@@ -73,28 +73,15 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
       return rewriter.notifyMatchFailure(op, "could not convert operation");
 
     // Packs the return value.
-    SmallVector<Value> packedRets;
+    SmallVector<ValueRange> packedRets;
     for (unsigned i = 1, e = offsets.size(); i < e; i++) {
       unsigned start = offsets[i - 1], end = offsets[i];
       unsigned len = end - start;
       ValueRange mappedValue = newOp->getResults().slice(start, len);
-      if (len != 1) {
-        // 1 : N type conversion.
-        Type origType = op.getResultTypes()[i - 1];
-        Value mat = typeConverter->materializeSourceConversion(
-            rewriter, op.getLoc(), origType, mappedValue);
-        if (!mat) {
-          return rewriter.notifyMatchFailure(
-              op, "Failed to materialize 1:N type conversion");
-        }
-        packedRets.push_back(mat);
-      } else {
-        // 1 : 1 type conversion.
-        packedRets.push_back(mappedValue.front());
-      }
+      packedRets.push_back(mappedValue);
     }
 
-    rewriter.replaceOp(op, packedRets);
+    rewriter.replaceOpWithMultiple(op, packedRets);
     return success();
   }
 };
@@ -105,7 +92,7 @@ class ConvertForOpTypes
   using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
 
   // The callback required by CRTP.
-  std::optional<ForOp> convertSourceOp(ForOp op, OpAdaptor adaptor,
+  std::optional<ForOp> convertSourceOp(ForOp op, OneToNOpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        TypeRange dstTypes) const {
     // Create a empty new op and inline the regions from the old op.
@@ -129,16 +116,13 @@ class ConvertForOpTypes
     if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter)))
       return std::nullopt;
 
-    // Unpacked the iteration arguments.
-    SmallVector<Value> flatArgs;
-    for (Value arg : adaptor.getInitArgs())
-      unpackUnrealizedConversionCast(arg, flatArgs);
-
     // ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 16, 2024

@llvm/pr-subscribers-mlir-sparse

Author: Matthias Springer (matthias-springer)

Changes

This commit adds a new matchAndRewrite overload to ConversionPattern to support 1:N replacements. This is the first of two main PRs that merge the 1:1 and 1:N dialect conversion drivers.

The existing matchAndRewrite function supports only 1:1 replacements, as can be seen from the ArrayRef&lt;Value&gt; parameter.

LogicalResult ConversionPattern::matchAndRewrite(
    Operation *op, ArrayRef&lt;Value&gt; operands /*adaptor values*/,
    ConversionPatternRewriter &amp;rewriter) const;

This commit adds a matchAndRewrite overload that is called by the dialect conversion driver. By default, this new overload dispatches to the original 1:1 matchAndRewrite implementation.

LogicalResult ConversionPattern::matchAndRewrite(
    Operation *op, ArrayRef&lt;ValueRange&gt; operands /*adaptor values*/,
    ConversionPatternRewriter &amp;rewriter) const {
  // Note: getOneToOneAdaptorOperands produces a fatal error if at least one
  // ValueRange has 0 or more than 1 value.
  return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
}

The ConversionValueMapping, which keeps track of value replacements and materializations, still does not support 1:N replacements. We still rely on argument materializations to convert N replacement values back into a single value. The ConversionValueMapping will be generalized to 1:N mappings in the second main PR.

Before handing the adaptor values to a ConversionPattern, all argument materializations are "unpacked". The ConversionPattern receives N replacement values and does not see any argument materializations. This implementation strategy allows us to use the 1:N infrastructure/API in ConversionPatterns even though some functionality is still missing in the driver. This strategy was chosen to keep the sizes of the PRs smaller and to make it easier for downstream users to adapt to API changes.

This commit also updates the the "decompose call graphs" transformation and the "sparse tensor codegen" transformation to use the new 1:N ConversionPattern API.

Note for LLVM conversion: If you are using a type converter with 1:N type conversion rules or if your patterns are performing 1:N replacements (via replaceOpWithMultiple or applySignatureConversion), conversion pattern applications will start failing (fatal LLVM error) with this error message: pattern 'name' does not support 1:N conversion. The name of the failing pattern is shown in the error message. These patterns must be updated to the new 1:N matchAndRewrite API.


Patch is 70.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116470.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/Pattern.h (+31-4)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+63)
  • (modified) mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp (+6-50)
  • (modified) mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp (+3-2)
  • (modified) mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp (+40-66)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+54-60)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h (+5-11)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+173-78)
  • (modified) mlir/test/Transforms/decompose-call-graph-types.mlir (+6-32)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index f3bf5b66398e09..86ea87b55af1cd 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -143,6 +143,8 @@ template <typename SourceOp>
 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
 public:
   using OpAdaptor = typename SourceOp::Adaptor;
+  using OneToNOpAdaptor =
+      typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
 
   explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
                                   PatternBenefit benefit = 1)
@@ -153,8 +155,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
   /// Wrappers around the RewritePattern methods that pass the derived op type.
   void rewrite(Operation *op, ArrayRef<Value> operands,
                ConversionPatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
-            rewriter);
+    auto sourceOp = cast<SourceOp>(op);
+    rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
+  }
+  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
+               ConversionPatternRewriter &rewriter) const final {
+    auto sourceOp = cast<SourceOp>(op);
+    rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
   }
   LogicalResult match(Operation *op) const final {
     return match(cast<SourceOp>(op));
@@ -162,8 +169,15 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    return matchAndRewrite(cast<SourceOp>(op),
-                           OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
+    auto sourceOp = cast<SourceOp>(op);
+    return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
+  }
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto sourceOp = cast<SourceOp>(op);
+    return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
+                           rewriter);
   }
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
@@ -175,6 +189,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
                        ConversionPatternRewriter &rewriter) const {
     llvm_unreachable("must override rewrite or matchAndRewrite");
   }
+  virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const {
+    SmallVector<Value> oneToOneOperands =
+        getOneToOneAdaptorOperands(adaptor.getOperands());
+    rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
@@ -183,6 +203,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
     rewrite(op, adaptor, rewriter);
     return success();
   }
+  virtual LogicalResult
+  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const {
+    SmallVector<Value> oneToOneOperands =
+        getOneToOneAdaptorOperands(adaptor.getOperands());
+    return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+  }
 
 private:
   using ConvertToLLVMPattern::match;
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index de47765006f81e..e4eeb39b9c0741 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -537,6 +537,10 @@ class ConversionPattern : public RewritePattern {
                        ConversionPatternRewriter &rewriter) const {
     llvm_unreachable("unimplemented rewrite");
   }
+  virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
+                       ConversionPatternRewriter &rewriter) const {
+    rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+  }
 
   /// Hook for derived classes to implement combined matching and rewriting.
   virtual LogicalResult
@@ -547,6 +551,11 @@ class ConversionPattern : public RewritePattern {
     rewrite(op, operands, rewriter);
     return success();
   }
+  virtual LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const {
+    return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+  }
 
   /// Attempt to match and rewrite the IR root at the specified operation.
   LogicalResult matchAndRewrite(Operation *op,
@@ -574,6 +583,15 @@ class ConversionPattern : public RewritePattern {
       : RewritePattern(std::forward<Args>(args)...),
         typeConverter(&typeConverter) {}
 
+  /// Given an array of value ranges, which are the inputs to a 1:N adaptor,
+  /// try to extract the single value of each range to construct a the inputs
+  /// for a 1:1 adaptor.
+  ///
+  /// This function produces a fatal error if at least one range has 0 or
+  /// more than 1 value: "pattern 'name' does not support 1:N conversion"
+  SmallVector<Value>
+  getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;
+
 protected:
   /// An optional type converter for use by this pattern.
   const TypeConverter *typeConverter = nullptr;
@@ -589,6 +607,8 @@ template <typename SourceOp>
 class OpConversionPattern : public ConversionPattern {
 public:
   using OpAdaptor = typename SourceOp::Adaptor;
+  using OneToNOpAdaptor =
+      typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
 
   OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
       : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -607,12 +627,24 @@ class OpConversionPattern : public ConversionPattern {
     auto sourceOp = cast<SourceOp>(op);
     rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
   }
+  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
+               ConversionPatternRewriter &rewriter) const final {
+    auto sourceOp = cast<SourceOp>(op);
+    rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
+  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     auto sourceOp = cast<SourceOp>(op);
     return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
   }
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto sourceOp = cast<SourceOp>(op);
+    return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
+                           rewriter);
+  }
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
   /// overridden by the derived pattern class.
@@ -623,6 +655,12 @@ class OpConversionPattern : public ConversionPattern {
                        ConversionPatternRewriter &rewriter) const {
     llvm_unreachable("must override matchAndRewrite or a rewrite method");
   }
+  virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const {
+    SmallVector<Value> oneToOneOperands =
+        getOneToOneAdaptorOperands(adaptor.getOperands());
+    rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
@@ -631,6 +669,13 @@ class OpConversionPattern : public ConversionPattern {
     rewrite(op, adaptor, rewriter);
     return success();
   }
+  virtual LogicalResult
+  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const {
+    SmallVector<Value> oneToOneOperands =
+        getOneToOneAdaptorOperands(adaptor.getOperands());
+    return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+  }
 
 private:
   using ConversionPattern::matchAndRewrite;
@@ -656,11 +701,20 @@ class OpInterfaceConversionPattern : public ConversionPattern {
                ConversionPatternRewriter &rewriter) const final {
     rewrite(cast<SourceOp>(op), operands, rewriter);
   }
+  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
+               ConversionPatternRewriter &rewriter) const final {
+    rewrite(cast<SourceOp>(op), operands, rewriter);
+  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
   }
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
+  }
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
   /// overridden by the derived pattern class.
@@ -668,6 +722,10 @@ class OpInterfaceConversionPattern : public ConversionPattern {
                        ConversionPatternRewriter &rewriter) const {
     llvm_unreachable("must override matchAndRewrite or a rewrite method");
   }
+  virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
+                       ConversionPatternRewriter &rewriter) const {
+    rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const {
@@ -676,6 +734,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
     rewrite(op, operands, rewriter);
     return success();
   }
+  virtual LogicalResult
+  matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const {
+    return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+  }
 
 private:
   using ConversionPattern::matchAndRewrite;
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index a08764326a80b6..03be00328bda33 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
@@ -13,40 +13,6 @@
 using namespace mlir;
 using namespace mlir::func;
 
-//===----------------------------------------------------------------------===//
-// Helper functions
-//===----------------------------------------------------------------------===//
-
-/// If the given value can be decomposed with the type converter, decompose it.
-/// Otherwise, return the given value.
-// TODO: Value decomposition should happen automatically through a 1:N adaptor.
-// This function will disappear when the 1:1 and 1:N drivers are merged.
-static SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc,
-                                         Value value,
-                                         const TypeConverter *converter) {
-  // Try to convert the given value's type. If that fails, just return the
-  // given value.
-  SmallVector<Type> convertedTypes;
-  if (failed(converter->convertType(value.getType(), convertedTypes)))
-    return {value};
-  if (convertedTypes.empty())
-    return {};
-
-  // If the given value's type is already legal, just return the given value.
-  TypeRange convertedTypeRange(convertedTypes);
-  if (convertedTypeRange == TypeRange(value.getType()))
-    return {value};
-
-  // Try to materialize a target conversion. If the materialization did not
-  // produce values of the requested type, the materialization failed. Just
-  // return the given value in that case.
-  SmallVector<Value> result = converter->materializeTargetConversion(
-      builder, loc, convertedTypeRange, value);
-  if (result.empty())
-    return {value};
-  return result;
-}
-
 //===----------------------------------------------------------------------===//
 // DecomposeCallGraphTypesForFuncArgs
 //===----------------------------------------------------------------------===//
@@ -102,16 +68,11 @@ struct DecomposeCallGraphTypesForReturnOp
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
+  matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     SmallVector<Value, 2> newOperands;
-    for (Value operand : adaptor.getOperands()) {
-      // TODO: We can directly take the values from the adaptor once this is a
-      // 1:N conversion pattern.
-      llvm::append_range(newOperands,
-                         decomposeValue(rewriter, operand.getLoc(), operand,
-                                        getTypeConverter()));
-    }
+    for (ValueRange operand : adaptor.getOperands())
+      llvm::append_range(newOperands, operand);
     rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
     return success();
   }
@@ -128,18 +89,13 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(CallOp op, OpAdaptor adaptor,
+  matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
 
     // Create the operands list of the new `CallOp`.
     SmallVector<Value, 2> newOperands;
-    for (Value operand : adaptor.getOperands()) {
-      // TODO: We can directly take the values from the adaptor once this is a
-      // 1:N conversion pattern.
-      llvm::append_range(newOperands,
-                         decomposeValue(rewriter, operand.getLoc(), operand,
-                                        getTypeConverter()));
-    }
+    for (ValueRange operand : adaptor.getOperands())
+      llvm::append_range(newOperands, operand);
 
     // Create the new result types for the new `CallOp` and track the number of
     // replacement types for each original op result.
diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
index eb444d665ff260..d81f822f7d4b51 100644
--- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
@@ -21,7 +21,7 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
 
   /// Hook for derived classes to implement combined matching and rewriting.
   LogicalResult
-  matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
+  matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Convert the original function results.
     SmallVector<Type, 1> convertedResults;
@@ -37,7 +37,8 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
     // Substitute with the new result types from the corresponding FuncType
     // conversion.
     rewriter.replaceOpWithNewOp<CallOp>(
-        callOp, callOp.getCallee(), convertedResults, adaptor.getOperands());
+        callOp, callOp.getCallee(), convertedResults,
+        getOneToOneAdaptorOperands(adaptor.getOperands()));
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 93a78056db1944..c0589044c26ecb 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -16,20 +16,18 @@ using namespace mlir::scf;
 
 namespace {
 
-// Unpacks the single unrealized_conversion_cast using the list of inputs
-// e.g., return [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d)
-static void unpackUnrealizedConversionCast(Value v,
-                                           SmallVectorImpl<Value> &unpacked) {
-  if (auto cast =
-          dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp())) {
-    if (cast.getInputs().size() != 1) {
-      // 1 : N type conversion.
-      unpacked.append(cast.getInputs().begin(), cast.getInputs().end());
-      return;
-    }
-  }
-  // 1 : 1 type conversion.
-  unpacked.push_back(v);
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+  SmallVector<Value> result;
+  for (const auto &vals : values)
+    llvm::append_range(result, vals);
+  return result;
+}
+
+/// Assert that the given value range contains a single value and return it.
+static Value getSingleValue(ValueRange values) {
+  assert(values.size() == 1 && "expected single value");
+  return values.front();
 }
 
 // CRTP
@@ -40,19 +38,21 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
 public:
   using OpConversionPattern<SourceOp>::typeConverter;
   using OpConversionPattern<SourceOp>::OpConversionPattern;
-  using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor;
+  using OneToNOpAdaptor =
+      typename OpConversionPattern<SourceOp>::OneToNOpAdaptor;
 
   //
   // Derived classes should provide the following method which performs the
   // actual conversion. It should return std::nullopt upon conversion failure
   // and return the converted operation upon success.
   //
-  // std::optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor,
-  //                                    ConversionPatternRewriter &rewriter,
-  //                                    TypeRange dstTypes) const;
+  // std::optional<SourceOp> convertSourceOp(
+  //     SourceOp op, OneToNOpAdaptor adaptor,
+  //     ConversionPatternRewriter &rewriter,
+  //     TypeRange dstTypes) const;
 
   LogicalResult
-  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     SmallVector<Type> dstTypes;
     SmallVector<unsigned> offsets;
@@ -73,28 +73,15 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
       return rewriter.notifyMatchFailure(op, "could not convert operation");
 
     // Packs the return value.
-    SmallVector<Value> packedRets;
+    SmallVector<ValueRange> packedRets;
     for (unsigned i = 1, e = offsets.size(); i < e; i++) {
       unsigned start = offsets[i - 1], end = offsets[i];
       unsigned len = end - start;
       ValueRange mappedValue = newOp->getResults().slice(start, len);
-      if (len != 1) {
-        // 1 : N type conversion.
-        Type origType = op.getResultTypes()[i - 1];
-        Value mat = typeConverter->materializeSourceConversion(
-            rewriter, op.getLoc(), origType, mappedValue);
-        if (!mat) {
-          return rewriter.notifyMatchFailure(
-              op, "Failed to materialize 1:N type conversion");
-        }
-        packedRets.push_back(mat);
-      } else {
-        // 1 : 1 type conversion.
-        packedRets.push_back(mappedValue.front());
-      }
+      packedRets.push_back(mappedValue);
     }
 
-    rewriter.replaceOp(op, packedRets);
+    rewriter.replaceOpWithMultiple(op, packedRets);
     return success();
   }
 };
@@ -105,7 +92,7 @@ class ConvertForOpTypes
   using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
 
   // The callback required by CRTP.
-  std::optional<ForOp> convertSourceOp(ForOp op, OpAdaptor adaptor,
+  std::optional<ForOp> convertSourceOp(ForOp op, OneToNOpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        TypeRange dstTypes) const {
     // Create a empty new op and inline the regions from the old op.
@@ -129,16 +116,13 @@ class ConvertForOpTypes
     if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter)))
       return std::nullopt;
 
-    // Unpacked the iteration arguments.
-    SmallVector<Value> flatArgs;
-    for (Value arg : adaptor.getInitArgs())
-      unpackUnrealizedConversionCast(arg, flatArgs);
-
     // ...
[truncated]

Copy link

github-actions bot commented Nov 17, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Dec 11, 2024
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Dec 11, 2024
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Dec 11, 2024
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Dec 12, 2024
MaheshRavishankar pushed a commit to iree-org/llvm-project that referenced this pull request Dec 12, 2024
MaheshRavishankar added a commit to iree-org/iree that referenced this pull request Dec 12, 2024
Carrying the following reverts

- llvm/llvm-project#116470
- llvm/llvm-project#117424
- llvm/llvm-project#119671

First two are carry over from previous integrate. It is being fixed in
#19451 . The last one is a new
failure.

---------

Signed-off-by: MaheshRavishankar <[email protected]>
mikeurbach added a commit to llvm/circt that referenced this pull request Dec 13, 2024
This is not a totally routine bump.

Python support is changing in preparation for a switch to nanobind:

https://discourse.llvm.org/t/psa-python-binding-dependencies-changing/83376

There were also some more DialectConversion changes:

llvm/llvm-project#116470

`InstanceOpConversion`'s 1:1 in FlattenIO generalizes to the 1:N pattern, so we can drop the 1:1 pattern.

Secondly, it seems like the biggest violating thing here is the fact that we were missing a materialization for exploding struct outputs of `hw.module.extern`... This was actually a pretty glaring mistake in the existing IR, which had a check for invalid IR - woops!

Combined with folders, this also seems to have removed some (not all!) of the redundant struct_create/struct_explode patterns in the output.

With the 1:N operand adaptor, we no longer have to manually filter i0 operands inside a conversion pattern. Instead, this information is already implicitly available via the adaptor (i.e. that an operand was removed via. materialization). This also implies that 1:N patterns need to handle the case where the `OneToNOpAdaptor` is empty.

In general, it feels like one would _very_ rarely have to use both the `OneToNOpAdaptor` and `OpAdaptor` overloads at the same time - this should be reserved for when there is truly a difference between 1:1 and 1:N patterns. However, in practice - as this little fixing excercise has demonstrated - most patterns where `OneToNOpAdaptor` is relevant, are generalized 1:N patterns, and doesn't need a specific `1:1` overload.

---------

Co-authored-by: Morten Borup Petersen <[email protected]>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Dec 13, 2024
raikonenfnu pushed a commit to iree-org/llvm-project that referenced this pull request Dec 16, 2024
raikonenfnu added a commit to iree-org/iree that referenced this pull request Dec 17, 2024
Update LLVM to llvm/llvm-project@3f136f7
(#19479)
Carrying the following reverts

- llvm/llvm-project#116470
- llvm/llvm-project#117424
- llvm/llvm-project#119671
- llvm/llvm-project#119970

First two are carry over from previous-previous integrate. It is being
fixed in
#19451 . The last one is a from the
previous integrate.
The last one is a new error being tracked in
#19498

---------

Signed-off-by: Stanley Winata <[email protected]>
raikonenfnu added a commit to raikonenfnu/iree that referenced this pull request Dec 17, 2024
Update LLVM to llvm/llvm-project@b07e7b76c5d532a6 (llvm/llvm-project#120002)
Carrying the following reverts

- llvm/llvm-project#116470
- llvm/llvm-project#117424
- llvm/llvm-project#119671
- llvm/llvm-project#119970

First two are carry over from previous-previous integrate. It is being fixed in
iree-org#19451 . The last one is a from the previous integrate.
The last one is a new error being tracked in iree-org#19498

Signed-off-by: Stanley Winata <[email protected]>
raikonenfnu added a commit to raikonenfnu/iree that referenced this pull request Dec 17, 2024
Update LLVM to llvm/llvm-project@b07e7b76c5d532a6 (llvm/llvm-project#120002)
Carrying the following reverts

- llvm/llvm-project#116470
- llvm/llvm-project#117424
- llvm/llvm-project#119671
- llvm/llvm-project#119970

First two are carry over from previous-previous integrate. It is being fixed in
iree-org#19451 . The last one is a from the previous integrate.
The last one is a new error being tracked in iree-org#19498

Signed-off-by: Stanley Winata <[email protected]>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Dec 17, 2024
MaheshRavishankar added a commit to iree-org/iree that referenced this pull request Dec 17, 2024
…9451)

The upstream change llvm/llvm-project@3f136f7
allows `ConvertToStream` to better handle the 1:N type conversion,
specifically the type conversion of a `tensor<...>` to
`!stream.resource<*>, index`. Now instead of trying to work around
`builtin.unrealized_conversion_cast`s the conversion can get the
converted values directly using the `OneToNAdaptor` and can also replace
a `tensor<..>` directly with multiple values using the
`ConversionPatternRewriter::replaceOpWithMultiple`.

These changes are required to drop the revert of
llvm/llvm-project#116470 in the IREE ToM. The
change drops these reverts as well.

Fixes #19448

---------

Signed-off-by: MaheshRavishankar <[email protected]>
matthias-springer added a commit that referenced this pull request Jan 3, 2025
…116524)

This commit updates the internal `ConversionValueMapping` data structure
in the dialect conversion driver to support 1:N replacements. This is
the last major commit for adding 1:N support to the dialect conversion
driver.

Since #116470, the infrastructure already supports 1:N replacements. But
the `ConversionValueMapping` still stored 1:1 value mappings. To that
end, the driver inserted temporary argument materializations (converting
N SSA values into 1 value). This is no longer the case. Argument
materializations are now entirely gone. (They will be deleted from the
type converter after some time, when we delete the old 1:N dialect
conversion driver.)

Note for LLVM integration: Replace all occurrences of
`addArgumentMaterialization` (except for 1:N dialect conversion passes)
with `addSourceMaterialization`.

---------

Co-authored-by: Markus Böck <[email protected]>
antiagainst pushed a commit to triton-lang/triton that referenced this pull request Jan 31, 2025
### TL;DR (too long, didn't review)

This PR re-enables the `tritonamdgpu-canonicalize-pointers` pass[^1].
The PR is effectively a complete rewrite of the original pass, which
walked the AST and mutated IR in-place, using the new [`1:N` dialect
conversion framework](llvm/llvm-project#116470).
Recall a "fat pointer" is a tuple-like `(%baseptr, %offsetptr)` - the
current (original) pass keeps this tuple in a global data structure
while the new/rewritten pass emits this tuple into the IR as an
`unrealized_cast(%baseptr, %offsetptr)`[^2].

Note, this PR also rewrites the existing lit test (see [this comment
below](#5329 (comment))).

### Pass outline

The pass structure/action is roughly:

1. Perform an approximate sparse dataflow analysis to find all
transitive uses for `tt.func` args that are `tt.ptr`s; legalize only
these ops;
2. Rewrite all operations' `use`s and `result`s to be `(%baseptr,
%offsetptr)` using `ConversionPattern`s that takes the new
`OneToNOpAdaptor`, which automatically forwards both `%baseptr` and
`%offsetptr` through `adaptor.getOperands()`[^3];
3. Clean up remaining `unrealized_casts` (currently only handling one
category of such remaining casts but can be extended to handle all; see
bullet 1 in TODOs).

### Some pre-emptive call outs

Right up front I'll say this took a long time to figure out because
**a)** the conversion framework is hugely complex **b)** it's being
currently rewritten to be more robust/stable. As a consequence, the
implementation is complex but I've tried hard to **a)** simplify as much
as possible **b)** comment/note subtleties **c)** put in ample `assert`s
and checks to clarify intent and gracefully fail. So some things to call
out:

1. I called the dataflow analysis approximate because it does not
actually use
[DataFlow/SparseAnalysis](https://github.com/llvm/llvm-project/blob/main/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp)
and instead computes a forward slice using the heuristic "transfer
function" that users of an op with a `tt.ptr` operand should be
rewritten. This heuristic works because the forward slice starts from
`tt.ptr` args on a `tt.func` and ends at `tt.store`, which has no
results. Note, there's no reason why this component of the pass can't be
a true `SparseAnalysis` implementation, it's just that this rewrite has
already taken way longer than I expected (so I leave that for a possible
follow-up).
2. The pass uses no global `TypeConverter` but uses local
`TypeConverter`s, in `BranchInterface`/`RegionInterface` patterns. This
is because **a)** we are not actually converting operand/result types
(we are converting number of operands/results) **b)** the conversion
framework expects/handles this lack of a `TypeConverter` exactly [the
way we
want](https://github.com/llvm/llvm-project/blob/399c3a78a2577c6fc68bba7f301901a0e66e87ed/mlir/lib/Transforms/Utils/DialectConversion.cpp#L1179-L1185).
The local type converters are used for ops that couple to basic blocks
(`^bb`s) that need to have their signatures rewritten (i.e., the ops for
which we need to do `rewriter.applySignatureConversion(block,
*conversion, &localTypeConverter)`). That's `scf.for`, `scf.while`,
`cf.br` and `cf.cond_br` (not needed for `scf.if` which has no `bb`
args).
3. `tt.func` is handled differently from all of the other ops - it is
not rewritten at all. Instead, for every `%arg: tt.ptr` arg, we insert
into the new body `%c0 = arith.constant 0 : i32` and `%newarg =
unrealized_cast(%arg, %c0) : tt.ptr` (manually, not done by the
conversion framework) and replace all uses of `%arg` by `%newarg`. These
are then unpacked to `(%arg, %c0)` using `replaceOpWithMultiple` so that
they "magically" appear in `adaptor.getOperands()`. Then at the end,
currently, these are the only unreconciled casts (because they are the
only ones **not** inserted by the conversion framework) and we
materialize them by just replacing uses of `%newarg` with `%arg`.
4. `scf.if` needs to be handled specially; since it has no operands but
can `yield` results, we need to rewrite it only after its `yield`s have
been rewritten. This is not straightforward because the dialect
conversion [does a preorder
walk](https://github.com/llvm/llvm-project/blob/6ab8401f53deddbd79f930ba2bec2f824c9567e3/mlir/lib/Transforms/Utils/DialectConversion.cpp#L2705).
To work around this we define legality for `scf.if` to be dependent on
whether its `yield`s have been rewritten (using two `UnitAttr`s on those
`yield`s). Thus, `scf.if` is "legal" and not rewritten until after the
results of the `yield`s are known.


[^1]: I haven't actually moved it out of the flag but it's now usable
with the `AMDGCN_USE_BUFFER_OPS` flag whereas it wasn't prior.
[^2]: In reality it's the conversion framework that materializes this
tuple as `unrealized_cast(%baseptr, %offsetptr)` and then
reconciles/DCEs all the casts automatically.
[^3]: The `unrealized_cast`s are completely "transparent" to the
patterns, see
[`ConversionPatternRewriterImpl::remapValues`](https://github.com/llvm/llvm-project/blob/399c3a78a2577c6fc68bba7f301901a0e66e87ed/mlir/lib/Transforms/Utils/DialectConversion.cpp#L1161).
AlexAUT pushed a commit to AlexAUT/triton that referenced this pull request Feb 6, 2025
…n-lang#5329)

### TL;DR (too long, didn't review)

This PR re-enables the `tritonamdgpu-canonicalize-pointers` pass[^1].
The PR is effectively a complete rewrite of the original pass, which
walked the AST and mutated IR in-place, using the new [`1:N` dialect
conversion framework](llvm/llvm-project#116470).
Recall a "fat pointer" is a tuple-like `(%baseptr, %offsetptr)` - the
current (original) pass keeps this tuple in a global data structure
while the new/rewritten pass emits this tuple into the IR as an
`unrealized_cast(%baseptr, %offsetptr)`[^2].

Note, this PR also rewrites the existing lit test (see [this comment
below](triton-lang#5329 (comment))).

### Pass outline

The pass structure/action is roughly:

1. Perform an approximate sparse dataflow analysis to find all
transitive uses for `tt.func` args that are `tt.ptr`s; legalize only
these ops;
2. Rewrite all operations' `use`s and `result`s to be `(%baseptr,
%offsetptr)` using `ConversionPattern`s that takes the new
`OneToNOpAdaptor`, which automatically forwards both `%baseptr` and
`%offsetptr` through `adaptor.getOperands()`[^3];
3. Clean up remaining `unrealized_casts` (currently only handling one
category of such remaining casts but can be extended to handle all; see
bullet 1 in TODOs).

### Some pre-emptive call outs

Right up front I'll say this took a long time to figure out because
**a)** the conversion framework is hugely complex **b)** it's being
currently rewritten to be more robust/stable. As a consequence, the
implementation is complex but I've tried hard to **a)** simplify as much
as possible **b)** comment/note subtleties **c)** put in ample `assert`s
and checks to clarify intent and gracefully fail. So some things to call
out:

1. I called the dataflow analysis approximate because it does not
actually use
[DataFlow/SparseAnalysis](https://github.com/llvm/llvm-project/blob/main/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp)
and instead computes a forward slice using the heuristic "transfer
function" that users of an op with a `tt.ptr` operand should be
rewritten. This heuristic works because the forward slice starts from
`tt.ptr` args on a `tt.func` and ends at `tt.store`, which has no
results. Note, there's no reason why this component of the pass can't be
a true `SparseAnalysis` implementation, it's just that this rewrite has
already taken way longer than I expected (so I leave that for a possible
follow-up).
2. The pass uses no global `TypeConverter` but uses local
`TypeConverter`s, in `BranchInterface`/`RegionInterface` patterns. This
is because **a)** we are not actually converting operand/result types
(we are converting number of operands/results) **b)** the conversion
framework expects/handles this lack of a `TypeConverter` exactly [the
way we
want](https://github.com/llvm/llvm-project/blob/399c3a78a2577c6fc68bba7f301901a0e66e87ed/mlir/lib/Transforms/Utils/DialectConversion.cpp#L1179-L1185).
The local type converters are used for ops that couple to basic blocks
(`^bb`s) that need to have their signatures rewritten (i.e., the ops for
which we need to do `rewriter.applySignatureConversion(block,
*conversion, &localTypeConverter)`). That's `scf.for`, `scf.while`,
`cf.br` and `cf.cond_br` (not needed for `scf.if` which has no `bb`
args).
3. `tt.func` is handled differently from all of the other ops - it is
not rewritten at all. Instead, for every `%arg: tt.ptr` arg, we insert
into the new body `%c0 = arith.constant 0 : i32` and `%newarg =
unrealized_cast(%arg, %c0) : tt.ptr` (manually, not done by the
conversion framework) and replace all uses of `%arg` by `%newarg`. These
are then unpacked to `(%arg, %c0)` using `replaceOpWithMultiple` so that
they "magically" appear in `adaptor.getOperands()`. Then at the end,
currently, these are the only unreconciled casts (because they are the
only ones **not** inserted by the conversion framework) and we
materialize them by just replacing uses of `%newarg` with `%arg`.
4. `scf.if` needs to be handled specially; since it has no operands but
can `yield` results, we need to rewrite it only after its `yield`s have
been rewritten. This is not straightforward because the dialect
conversion [does a preorder
walk](https://github.com/llvm/llvm-project/blob/6ab8401f53deddbd79f930ba2bec2f824c9567e3/mlir/lib/Transforms/Utils/DialectConversion.cpp#L2705).
To work around this we define legality for `scf.if` to be dependent on
whether its `yield`s have been rewritten (using two `UnitAttr`s on those
`yield`s). Thus, `scf.if` is "legal" and not rewritten until after the
results of the `yield`s are known.


[^1]: I haven't actually moved it out of the flag but it's now usable
with the `AMDGCN_USE_BUFFER_OPS` flag whereas it wasn't prior.
[^2]: In reality it's the conversion framework that materializes this
tuple as `unrealized_cast(%baseptr, %offsetptr)` and then
reconciles/DCEs all the casts automatically.
[^3]: The `unrealized_cast`s are completely "transparent" to the
patterns, see
[`ConversionPatternRewriterImpl::remapValues`](https://github.com/llvm/llvm-project/blob/399c3a78a2577c6fc68bba7f301901a0e66e87ed/mlir/lib/Transforms/Utils/DialectConversion.cpp#L1161).
makslevental added a commit to makslevental/triton that referenced this pull request Feb 19, 2025
…n-lang#5329)

### TL;DR (too long, didn't review)

This PR re-enables the `tritonamdgpu-canonicalize-pointers` pass[^1].
The PR is effectively a complete rewrite of the original pass, which
walked the AST and mutated IR in-place, using the new [`1:N` dialect
conversion framework](llvm/llvm-project#116470).
Recall a "fat pointer" is a tuple-like `(%baseptr, %offsetptr)` - the
current (original) pass keeps this tuple in a global data structure
while the new/rewritten pass emits this tuple into the IR as an
`unrealized_cast(%baseptr, %offsetptr)`[^2].

Note, this PR also rewrites the existing lit test (see [this comment
below](triton-lang#5329 (comment))).

### Pass outline

The pass structure/action is roughly:

1. Perform an approximate sparse dataflow analysis to find all
transitive uses for `tt.func` args that are `tt.ptr`s; legalize only
these ops;
2. Rewrite all operations' `use`s and `result`s to be `(%baseptr,
%offsetptr)` using `ConversionPattern`s that takes the new
`OneToNOpAdaptor`, which automatically forwards both `%baseptr` and
`%offsetptr` through `adaptor.getOperands()`[^3];
3. Clean up remaining `unrealized_casts` (currently only handling one
category of such remaining casts but can be extended to handle all; see
bullet 1 in TODOs).

### Some pre-emptive call outs

Right up front I'll say this took a long time to figure out because
**a)** the conversion framework is hugely complex **b)** it's being
currently rewritten to be more robust/stable. As a consequence, the
implementation is complex but I've tried hard to **a)** simplify as much
as possible **b)** comment/note subtleties **c)** put in ample `assert`s
and checks to clarify intent and gracefully fail. So some things to call
out:

1. I called the dataflow analysis approximate because it does not
actually use
[DataFlow/SparseAnalysis](https://github.com/llvm/llvm-project/blob/main/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp)
and instead computes a forward slice using the heuristic "transfer
function" that users of an op with a `tt.ptr` operand should be
rewritten. This heuristic works because the forward slice starts from
`tt.ptr` args on a `tt.func` and ends at `tt.store`, which has no
results. Note, there's no reason why this component of the pass can't be
a true `SparseAnalysis` implementation, it's just that this rewrite has
already taken way longer than I expected (so I leave that for a possible
follow-up).
2. The pass uses no global `TypeConverter` but uses local
`TypeConverter`s, in `BranchInterface`/`RegionInterface` patterns. This
is because **a)** we are not actually converting operand/result types
(we are converting number of operands/results) **b)** the conversion
framework expects/handles this lack of a `TypeConverter` exactly [the
way we
want](https://github.com/llvm/llvm-project/blob/399c3a78a2577c6fc68bba7f301901a0e66e87ed/mlir/lib/Transforms/Utils/DialectConversion.cpp#L1179-L1185).
The local type converters are used for ops that couple to basic blocks
(`^bb`s) that need to have their signatures rewritten (i.e., the ops for
which we need to do `rewriter.applySignatureConversion(block,
*conversion, &localTypeConverter)`). That's `scf.for`, `scf.while`,
`cf.br` and `cf.cond_br` (not needed for `scf.if` which has no `bb`
args).
3. `tt.func` is handled differently from all of the other ops - it is
not rewritten at all. Instead, for every `%arg: tt.ptr` arg, we insert
into the new body `%c0 = arith.constant 0 : i32` and `%newarg =
unrealized_cast(%arg, %c0) : tt.ptr` (manually, not done by the
conversion framework) and replace all uses of `%arg` by `%newarg`. These
are then unpacked to `(%arg, %c0)` using `replaceOpWithMultiple` so that
they "magically" appear in `adaptor.getOperands()`. Then at the end,
currently, these are the only unreconciled casts (because they are the
only ones **not** inserted by the conversion framework) and we
materialize them by just replacing uses of `%newarg` with `%arg`.
4. `scf.if` needs to be handled specially; since it has no operands but
can `yield` results, we need to rewrite it only after its `yield`s have
been rewritten. This is not straightforward because the dialect
conversion [does a preorder
walk](https://github.com/llvm/llvm-project/blob/6ab8401f53deddbd79f930ba2bec2f824c9567e3/mlir/lib/Transforms/Utils/DialectConversion.cpp#L2705).
To work around this we define legality for `scf.if` to be dependent on
whether its `yield`s have been rewritten (using two `UnitAttr`s on those
`yield`s). Thus, `scf.if` is "legal" and not rewritten until after the
results of the `yield`s are known.


[^1]: I haven't actually moved it out of the flag but it's now usable
with the `AMDGCN_USE_BUFFER_OPS` flag whereas it wasn't prior.
[^2]: In reality it's the conversion framework that materializes this
tuple as `unrealized_cast(%baseptr, %offsetptr)` and then
reconciles/DCEs all the casts automatically.
[^3]: The `unrealized_cast`s are completely "transparent" to the
patterns, see
[`ConversionPatternRewriterImpl::remapValues`](https://github.com/llvm/llvm-project/blob/399c3a78a2577c6fc68bba7f301901a0e66e87ed/mlir/lib/Transforms/Utils/DialectConversion.cpp#L1161).
AaronStGeorge added a commit to llvm/torch-mlir that referenced this pull request Feb 28, 2025
This PR updates AdjustCallingConventionsPass to the dialect conversion
framework API updates introduced in
llvm/llvm-project#116470. This may not be an
optimal use of the new API, but it is functional. Suggestions welcome!

fixes #3983
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:func mlir:llvm mlir:scf mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants