Skip to content

[mlir][arith] Fix canon pattern for large ints in chained arith #68900

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 13, 2023
Merged

[mlir][arith] Fix canon pattern for large ints in chained arith #68900

merged 5 commits into from
Oct 13, 2023

Conversation

rikhuijzer
Copy link
Member

The logic for chained basic arithmetic operations in the arith dialect was using getInt() on IntegerAttr. This is a problem for very large integers. Specifically, in #64774 the following assertion failed:

Assertion failed: (getSignificantBits() <= 64 && "Too many bits for int64_t"), function getSExtValue, file APInt.h, line 1510.

According to a comment on getInt(), calls to getInt() should be replaced by getValue():

// TODO: Change callers to use getValue instead.
int64_t getInt() const;

This patch fixes #64774 by doing such a replacement.

@llvmbot
Copy link
Member

llvmbot commented Oct 12, 2023

@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Rik Huijzer (rikhuijzer)

Changes

The logic for chained basic arithmetic operations in the arith dialect was using getInt() on IntegerAttr. This is a problem for very large integers. Specifically, in #64774 the following assertion failed:

Assertion failed: (getSignificantBits() &lt;= 64 &amp;&amp; "Too many bits for int64_t"), function getSExtValue, file APInt.h, line 1510.

According to a comment on getInt(), calls to getInt() should be replaced by getValue():

// TODO: Change callers to use getValue instead.
int64_t getInt() const;

This patch fixes #64774 by doing such a replacement.


Full diff: https://github.com/llvm/llvm-project/pull/68900.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+17-8)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+10)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 0ecc288f3b07701..25578b1c52f331b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -39,26 +39,35 @@ using namespace mlir::arith;
 static IntegerAttr
 applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
                     Attribute rhs,
-                    function_ref<int64_t(int64_t, int64_t)> binFn) {
-  return builder.getIntegerAttr(res.getType(),
-                                binFn(llvm::cast<IntegerAttr>(lhs).getInt(),
-                                      llvm::cast<IntegerAttr>(rhs).getInt()));
+                    function_ref<APInt(APInt, APInt&)> binFn) {
+  auto lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
+  auto rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
+  auto value = binFn(lhsVal, rhsVal);
+  return IntegerAttr::get(res.getType(), value);
 }
 
 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
                                    Attribute lhs, Attribute rhs) {
-  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<int64_t>());
+  auto binFn = [](APInt a, APInt& b) -> APInt {
+    return std::move(a) + b;
+  };
+  return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
 }
 
 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
                                    Attribute lhs, Attribute rhs) {
-  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<int64_t>());
+  auto binFn = [](APInt a, APInt& b) -> APInt {
+    return std::move(a) - b;
+  };
+  return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
 }
 
 static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
                                    Attribute lhs, Attribute rhs) {
-  return applyToIntegerAttrs(builder, res, lhs, rhs,
-                             std::multiplies<int64_t>());
+  auto binFn = [](APInt a, APInt& b) -> APInt {
+    return std::move(a) * b;
+  };
+  return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
 }
 
 /// Invert an integer comparison predicate.
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 1b0547c9e8f804a..b18f5cfcb3f9a12 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -985,6 +985,16 @@ func.func @tripleMulIMulII32(%arg0: i32) -> i32 {
   return %mul2 : i32
 }
 
+// CHECK-LABEL: @tripleMulLargeInt
+//       CHECK:   return
+func.func @tripleMulLargeInt(%arg0: i256) -> i256 {
+  %0 = arith.constant 3618502788666131213697322783095070105623107215331596699973092056135872020481 : i256
+  %c5 = arith.constant 5 : i256
+  %mul1 = arith.muli %arg0, %0 : i256
+  %mul2 = arith.muli %mul1, %c5 : i256
+  return %mul2 : i256
+}
+
 // CHECK-LABEL: @addiMuliToSubiRhsI32
 //  CHECK-SAME:   (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
 //       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32

@github-actions
Copy link

github-actions bot commented Oct 12, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@rikhuijzer
Copy link
Member Author

I’ll fix the formatting issue tomorrow.

Question to anyone who reviews this. Should we remove the added test? It’s a trivial test and tests are typically not removed which will lead to long test times over time.

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much for tackling this! I only left some comments regarding some code details.

I also think that you should definitely keep the test. Given that this case used to assert, we seemingly didn't have any test coverage for larger than 64 bit integers, which is why having the test is definitely valuable.

@rikhuijzer rikhuijzer requested a review from zero9178 October 13, 2023 08:50
Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Flag --canonicalize is broken in mlir-opt with LLVM 16
3 participants