Skip to content

Commit f7019ea

Browse files
committed
[mlir][intrange] Fix arith.shl inference in case of overflow
When an overflow happens during shift left, i.e. the last sign bit or the most significant data bit gets shifted out, the current approach of inferring the range of results does not work anymore. This patch checks for possible overflow and returns the max range in that case. Fix #82158
1 parent febd89c commit f7019ea

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,18 @@ mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
548548
const APInt &r) -> std::optional<APInt> {
549549
return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
550550
};
551+
552+
// The minMax inference does not work when there is danger of overflow. In the
553+
// signed case, this leads to the obvious problem that the sign bit might
554+
// change. In the unsigned case, it also leads to problems because the largest
555+
// LHS shifted by the largest RHS does not necessarily result in the largest
556+
// result anymore.
557+
bool signbitSafe =
558+
(lhs.smin().getNumSignBits() > rhs.umax().getZExtValue()) &&
559+
(lhs.smax().getNumSignBits() > rhs.umax().getZExtValue());
560+
if (!signbitSafe)
561+
return ConstantIntRanges::maxRange(lhs.umax().getBitWidth());
562+
551563
ConstantIntRanges urange =
552564
minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
553565
/*isSigned=*/false);

mlir/test/Dialect/Arith/int-range-opts.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,32 @@ func.func @test() -> i1 {
7171
%1 = arith.cmpi sle, %0, %cst1 : index
7272
return %1: i1
7373
}
74+
75+
// -----
76+
77+
// CHECK-LABEL: func @test
78+
// CHECK: test.reflect_bounds {smax = 24 : index, smin = 0 : index, umax = 24 : index, umin = 0 : index}
79+
func.func @test() -> index {
80+
%cst1 = arith.constant 1 : i8
81+
%0 = test.with_bounds { umin = 0 : index, umax = 12 : index, smin = 0 : index, smax = 12 : index }
82+
%i8val = arith.index_cast %0 : index to i8
83+
%shifted = arith.shli %i8val, %cst1 : i8
84+
%si = arith.index_cast %shifted : i8 to index
85+
%1 = test.reflect_bounds %si
86+
return %1: index
87+
}
88+
89+
// -----
90+
91+
// CHECK-LABEL: func @test
92+
// CHECK: test.reflect_bounds {smax = 127 : index, smin = -128 : index, umax = -1 : index, umin = 0 : index}
93+
func.func @test() -> index {
94+
%cst1 = arith.constant 1 : i8
95+
%0 = test.with_bounds { umin = 0 : index, umax = 127 : index, smin = 0 : index, smax = 127 : index }
96+
%i8val = arith.index_cast %0 : index to i8
97+
%shifted = arith.shli %i8val, %cst1 : i8
98+
%si = arith.index_cast %shifted : i8 to index
99+
%1 = test.reflect_bounds %si
100+
return %1: index
101+
}
102+

0 commit comments

Comments
 (0)