Skip to content

Commit 0f79066

Browse files
authored
[mlir][intrange] Fix arith.shl inference in case of overflow (#91737)
When an overflow happens during shift left, i.e. the last sign bit or the most significant data bit gets shifted out, the current approach of inferring the range of results does not work anymore. This patch checks for possible overflow and returns the max range in that case. Fix #82158
1 parent cf40c93 commit 0f79066

File tree

2 files changed

+46
-2
lines changed

2 files changed

+46
-2
lines changed

mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -544,15 +544,30 @@ mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
544544
ConstantIntRanges
545545
mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
546546
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
547+
const APInt &lhsSMin = lhs.smin(), &lhsSMax = lhs.smax(),
548+
&lhsUMax = lhs.umax(), &rhsUMin = rhs.umin(),
549+
&rhsUMax = rhs.umax();
550+
547551
ConstArithFn shl = [](const APInt &l,
548552
const APInt &r) -> std::optional<APInt> {
549553
return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
550554
};
555+
556+
// The minMax inference does not work when there is danger of overflow. In the
557+
// signed case, this leads to the obvious problem that the sign bit might
558+
// change. In the unsigned case, it also leads to problems because the largest
559+
// LHS shifted by the largest RHS does not necessarily result in the largest
560+
// result anymore.
561+
assert(rhsUMax.isNonNegative() && "Unexpected negative shift count");
562+
if (rhsUMax.uge(lhsSMin.getNumSignBits()) ||
563+
rhsUMax.uge(lhsSMax.getNumSignBits()))
564+
return ConstantIntRanges::maxRange(lhsUMax.getBitWidth());
565+
551566
ConstantIntRanges urange =
552-
minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
567+
minMaxBy(shl, {lhs.umin(), lhsUMax}, {rhsUMin, rhsUMax},
553568
/*isSigned=*/false);
554569
ConstantIntRanges srange =
555-
minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
570+
minMaxBy(shl, {lhsSMin, lhsSMax}, {rhsUMin, rhsUMax},
556571
/*isSigned=*/true);
557572
return urange.intersection(srange);
558573
}

mlir/test/Dialect/Arith/int-range-opts.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,32 @@ func.func @test() -> i1 {
7171
%1 = arith.cmpi sle, %0, %cst1 : index
7272
return %1: i1
7373
}
74+
75+
// -----
76+
77+
// CHECK-LABEL: func @test
78+
// CHECK: test.reflect_bounds {smax = 24 : index, smin = 0 : index, umax = 24 : index, umin = 0 : index}
79+
func.func @test() -> index {
80+
%cst1 = arith.constant 1 : i8
81+
%0 = test.with_bounds { umin = 0 : index, umax = 12 : index, smin = 0 : index, smax = 12 : index }
82+
%i8val = arith.index_cast %0 : index to i8
83+
%shifted = arith.shli %i8val, %cst1 : i8
84+
%si = arith.index_cast %shifted : i8 to index
85+
%1 = test.reflect_bounds %si
86+
return %1: index
87+
}
88+
89+
// -----
90+
91+
// CHECK-LABEL: func @test
92+
// CHECK: test.reflect_bounds {smax = 127 : index, smin = -128 : index, umax = -1 : index, umin = 0 : index}
93+
func.func @test() -> index {
94+
%cst1 = arith.constant 1 : i8
95+
%0 = test.with_bounds { umin = 0 : index, umax = 127 : index, smin = 0 : index, smax = 127 : index }
96+
%i8val = arith.index_cast %0 : index to i8
97+
%shifted = arith.shli %i8val, %cst1 : i8
98+
%si = arith.index_cast %shifted : i8 to index
99+
%1 = test.reflect_bounds %si
100+
return %1: index
101+
}
102+

0 commit comments

Comments
 (0)