-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][arith] Fix canon pattern for large ints in chained arith #68900
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir Author: Rik Huijzer (rikhuijzer) ChangesThe logic for chained basic arithmetic operations in the
According to a comment on llvm-project/mlir/include/mlir/IR/BuiltinAttributes.td Lines 707 to 708 in ab6a66d
This patch fixes #64774 by doing such a replacement. Full diff: https://github.com/llvm/llvm-project/pull/68900.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 0ecc288f3b07701..25578b1c52f331b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -39,26 +39,35 @@ using namespace mlir::arith;
static IntegerAttr
applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
Attribute rhs,
- function_ref<int64_t(int64_t, int64_t)> binFn) {
- return builder.getIntegerAttr(res.getType(),
- binFn(llvm::cast<IntegerAttr>(lhs).getInt(),
- llvm::cast<IntegerAttr>(rhs).getInt()));
+ function_ref<APInt(APInt, APInt&)> binFn) {
+ auto lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
+ auto rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
+ auto value = binFn(lhsVal, rhsVal);
+ return IntegerAttr::get(res.getType(), value);
}
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
- return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<int64_t>());
+ auto binFn = [](APInt a, APInt& b) -> APInt {
+ return std::move(a) + b;
+ };
+ return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
}
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
- return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<int64_t>());
+ auto binFn = [](APInt a, APInt& b) -> APInt {
+ return std::move(a) - b;
+ };
+ return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
}
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
- return applyToIntegerAttrs(builder, res, lhs, rhs,
- std::multiplies<int64_t>());
+ auto binFn = [](APInt a, APInt& b) -> APInt {
+ return std::move(a) * b;
+ };
+ return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
}
/// Invert an integer comparison predicate.
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 1b0547c9e8f804a..b18f5cfcb3f9a12 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -985,6 +985,16 @@ func.func @tripleMulIMulII32(%arg0: i32) -> i32 {
return %mul2 : i32
}
+// CHECK-LABEL: @tripleMulLargeInt
+// CHECK: return
+func.func @tripleMulLargeInt(%arg0: i256) -> i256 {
+ %0 = arith.constant 3618502788666131213697322783095070105623107215331596699973092056135872020481 : i256
+ %c5 = arith.constant 5 : i256
+ %mul1 = arith.muli %arg0, %0 : i256
+ %mul2 = arith.muli %mul1, %c5 : i256
+ return %mul2 : i256
+}
+
// CHECK-LABEL: @addiMuliToSubiRhsI32
// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
I’ll fix the formatting issue tomorrow. Question to anyone who reviews this. Should we remove the added test? It’s a trivial test and tests are typically not removed which will lead to long test times over time. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you so much for tackling this! I only left some comments regarding some code details.
I also think that you should definitely keep the test. Given that this case used to assert, we seemingly didn't have any test coverage for larger than 64 bit integers, which is why having the test is definitely valuable.
Co-authored-by: Markus Böck <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
The logic for chained basic arithmetic operations in the
arith
dialect was usinggetInt()
onIntegerAttr
. This is a problem for very large integers. Specifically, in #64774 the following assertion failed:According to a comment on
getInt()
, calls togetInt()
should be replaced bygetValue()
:llvm-project/mlir/include/mlir/IR/BuiltinAttributes.td
Lines 707 to 708 in ab6a66d
This patch fixes #64774 by doing such a replacement.