Skip to content

Commit 7ef1754

Browse files
rikhuijzerzero9178
andauthored
[mlir][arith] Fix canon pattern for large ints in chained arith (#68900)
The logic for chained basic arithmetic operations in the `arith` dialect was using `getInt()` on `IntegerAttr`. This is a problem for very large integers. Specifically, in #64774 the following assertion failed: ``` Assertion failed: (getSignificantBits() <= 64 && "Too many bits for int64_t"), function getSExtValue, file APInt.h, line 1510. ``` According to a comment on `getInt()`, calls to `getInt()` should be replaced by `getValue()`: https://github.com/llvm/llvm-project/blob/ab6a66dbec61654d0962f6abf6d6c5b776937584/mlir/include/mlir/IR/BuiltinAttributes.td#L707-L708 This patch fixes #64774 by doing such a replacement. --------- Co-authored-by: Markus Böck <[email protected]>
1 parent c6f065d commit 7ef1754

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,26 +39,26 @@ using namespace mlir::arith;
3939
static IntegerAttr
4040
applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
4141
Attribute rhs,
42-
function_ref<int64_t(int64_t, int64_t)> binFn) {
43-
return builder.getIntegerAttr(res.getType(),
44-
binFn(llvm::cast<IntegerAttr>(lhs).getInt(),
45-
llvm::cast<IntegerAttr>(rhs).getInt()));
42+
function_ref<APInt(const APInt &, const APInt &)> binFn) {
43+
APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
44+
APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
45+
APInt value = binFn(lhsVal, rhsVal);
46+
return IntegerAttr::get(res.getType(), value);
4647
}
4748

4849
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
4950
Attribute lhs, Attribute rhs) {
50-
return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<int64_t>());
51+
return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
5152
}
5253

5354
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
5455
Attribute lhs, Attribute rhs) {
55-
return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<int64_t>());
56+
return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
5657
}
5758

5859
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
5960
Attribute lhs, Attribute rhs) {
60-
return applyToIntegerAttrs(builder, res, lhs, rhs,
61-
std::multiplies<int64_t>());
61+
return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
6262
}
6363

6464
/// Invert an integer comparison predicate.

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -909,6 +909,18 @@ func.func @tripleMulIMulII32(%arg0: i32) -> i32 {
909909
return %mul2 : i32
910910
}
911911

912+
// CHECK-LABEL: @tripleMulLargeInt
913+
// CHECK: %[[cres:.+]] = arith.constant 3618502788666131213697322783095070105623107215331596699973092056135872020482 : i256
914+
// CHECK: %[[addi:.+]] = arith.addi %arg0, %[[cres]] : i256
915+
// CHECK: return %[[addi]]
916+
func.func @tripleMulLargeInt(%arg0: i256) -> i256 {
917+
%0 = arith.constant 3618502788666131213697322783095070105623107215331596699973092056135872020481 : i256
918+
%1 = arith.constant 1 : i256
919+
%2 = arith.addi %arg0, %0 : i256
920+
%3 = arith.addi %2, %1 : i256
921+
return %3 : i256
922+
}
923+
912924
// CHECK-LABEL: @addiMuliToSubiRhsI32
913925
// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
914926
// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32

0 commit comments

Comments
 (0)