Skip to content

[mlir] fix infinite while in 1:N dialect conversion #123122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

makslevental
Copy link
Contributor

#116524 introduced a "subtle" bug; I can't give a repro unless you're willing to build triton-lang/triton#5329 but here's a sketch:

tt.func @conversion1(%arg0: !tt.ptr<f32>) -> tensor<1024xf32> {
    ... 
}
class ConvertFuncOp : public OpConversionPattern<tt::FuncOp> {
public:
  using PointerCanonicalizationPattern::PointerCanonicalizationPattern;

  LogicalResult
  matchAndRewrite_(tt::FuncOp funcOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
    ...
    // REPLACE tt.ptr WITH (tt.ptr, arith.constant 0) SAME tt.ptr!!!!
    ...
    return success();
  }
};

Then in

// mlir/lib/Transforms/Utils/DialectConversion.cpp
ConversionPatternRewriterImpl::remapValues(...) {
    if (!currentTypeConverter) {
      // The current pattern does not have a type converter. I.e., it does not
      // distinguish between legal and illegal types. For each operand, simply
      // pass through the most recently mapped values.
      remapped.push_back(mapping.lookupOrDefault(operand));
      continue;
    }
}

Then

ConversionValueMapping::lookupOrDefault(...) {
  ...
  do {
    ValueVector next;
    for (Value v : current) {
      // ALWAYS FINDS tt.ptr AT INDEX 0
      auto it = mapping.find({v});
      if (it != mapping.end()) {
        // ALWAYS REPLACES tt.ptr AT INDEX 0 
        // WITH (tt.ptr, arith.constant)
        llvm::append_range(next, it->second);
      } else {
        next.push_back(v);
      }
      ...
    }
  } while (true);

result: current grows without bounds, looks like (tt.ptr, arith.constant ,arith.constant, arith.constant....) and while loops forever.

The fix is to check whether we're "deepening" on the same Value.

Note, I can't add a test for this because the fail is an infinite loop.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jan 15, 2025
@llvmbot
Copy link
Member

llvmbot commented Jan 15, 2025

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

#116524 introduced a "subtle" bug; I can't give a repro unless you're willing to build triton-lang/triton#5329 but here's a sketch:

tt.func @<!-- -->conversion1(%arg0: !tt.ptr&lt;f32&gt;) -&gt; tensor&lt;1024xf32&gt; {
    ... 
}
class ConvertFuncOp : public OpConversionPattern&lt;tt::FuncOp&gt; {
public:
  using PointerCanonicalizationPattern::PointerCanonicalizationPattern;

  LogicalResult
  matchAndRewrite_(tt::FuncOp funcOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &amp;rewriter) const override {
    ...
    // REPLACE tt.ptr WITH (tt.ptr, arith.constant 0) SAME tt.ptr!!!!
    ...
    return success();
  }
};

Then in

// mlir/lib/Transforms/Utils/DialectConversion.cpp
ConversionPatternRewriterImpl::remapValues(...) {
    if (!currentTypeConverter) {
      // The current pattern does not have a type converter. I.e., it does not
      // distinguish between legal and illegal types. For each operand, simply
      // pass through the most recently mapped values.
      remapped.push_back(mapping.lookupOrDefault(operand));
      continue;
    }
}

Then

ConversionValueMapping::lookupOrDefault(...) {
  ...
  do {
    ValueVector next;
    for (Value v : current) {
      // ALWAYS FINDS tt.ptr AT INDEX 0
      auto it = mapping.find({v});
      if (it != mapping.end()) {
        // ALWAYS REPLACES tt.ptr AT INDEX 0 
        // WITH (tt.ptr, arith.constant)
        llvm::append_range(next, it-&gt;second);
      } else {
        next.push_back(v);
      }
      ...
    }
  } while (true);

result: current grows without bounds, looks like (tt.ptr, arith.constant ,arith.constant, arith.constant....) and while loops forever.

The fix is to check whether we're "deepening" on the same Value.

Note, I can't add a test for this because the fail is an infinite loop.


Full diff: https://github.com/llvm/llvm-project/pull/123122.diff

1 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+3-2)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 403321d40d53c9..83d66dbe342d3f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -201,8 +201,9 @@ ConversionValueMapping::lookupOrDefault(Value from,
     // If possible, Replace each value with (one or multiple) mapped values.
     ValueVector next;
     for (Value v : current) {
-      auto it = mapping.find({v});
-      if (it != mapping.end()) {
+      ValueVector vv{v};
+      auto it = mapping.find(vv);
+      if (it != mapping.end() && it->first != vv) {
         llvm::append_range(next, it->second);
       } else {
         next.push_back(v);

@matthias-springer
Copy link
Member

Did you close this PR on purpose?

@matthias-springer
Copy link
Member

    // REPLACE tt.ptr WITH (tt.ptr, arith.constant 0) SAME tt.ptr!!!!

Can you show the C++ code for that?

There are two 2 ways to replace a value: replaceOp[WithMultiple] and applySignatureConversion. Both erase the old op / block, so the nested tt.ptr should be a different one.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants