-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][intrange] Fix arith.shl
inference in case of overflow
#91737
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
When an overflow happens during shift left, i.e. the last sign bit or the most significant data bit gets shifted out, the current approach of inferring the range of results does not work anymore. This patch checks for possible overflow and returns the max range in that case. Fix llvm#82158
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-arith Author: Felix Schneider (ubfx) ChangesWhen an overflow happens during shift left, i.e. the last sign bit or the most significant data bit gets shifted out, the current approach of inferring the range of results does not work anymore. This patch checks for possible overflow and returns the max range in that case. This bug shows up most clearly when we peform // mlir-opt --int-range-optimizations
func.func @<!-- -->test() -> index {
%cst1 = arith.constant 1 : i8
%0 = test.with_bounds { umin = 0 : index, umax = 255 : index, smin = -128 : index, smax = 127 : index }
%i8val = arith.index_cast %0 : index to i8
%shifted = arith.shli %i8val, %cst1 : i8
%si = arith.index_cast %shifted : i8 to index
%1 = test.reflect_bounds %si
// gets optimized to
// %4 = test.reflect_bounds {smax = 0 : index, smin = -2 : index, umax = -1 : index, umin = 0 : index} %3
return %1: index
} The result range will get vastly underestimated to [-2, 0] which leads to wrong optimizations. Fix #82158 Full diff: https://github.com/llvm/llvm-project/pull/91737.diff 2 Files Affected:
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 2b2d937d55d80..12fae495b761e 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -548,6 +548,18 @@ mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
const APInt &r) -> std::optional<APInt> {
return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
};
+
+ // The minMax inference does not work when there is danger of overflow. In the
+ // signed case, this leads to the obvious problem that the sign bit might
+ // change. In the unsigned case, it also leads to problems because the largest
+ // LHS shifted by the largest RHS does not necessarily result in the largest
+ // result anymore.
+ bool signbitSafe =
+ (lhs.smin().getNumSignBits() > rhs.umax().getZExtValue()) &&
+ (lhs.smax().getNumSignBits() > rhs.umax().getZExtValue());
+ if (!signbitSafe)
+ return ConstantIntRanges::maxRange(lhs.umax().getBitWidth());
+
ConstantIntRanges urange =
minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
/*isSigned=*/false);
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index be0a7e8ccd70b..4c3c0854ed026 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -71,3 +71,32 @@ func.func @test() -> i1 {
%1 = arith.cmpi sle, %0, %cst1 : index
return %1: i1
}
+
+// -----
+
+// CHECK-LABEL: func @test
+// CHECK: test.reflect_bounds {smax = 24 : index, smin = 0 : index, umax = 24 : index, umin = 0 : index}
+func.func @test() -> index {
+ %cst1 = arith.constant 1 : i8
+ %0 = test.with_bounds { umin = 0 : index, umax = 12 : index, smin = 0 : index, smax = 12 : index }
+ %i8val = arith.index_cast %0 : index to i8
+ %shifted = arith.shli %i8val, %cst1 : i8
+ %si = arith.index_cast %shifted : i8 to index
+ %1 = test.reflect_bounds %si
+ return %1: index
+}
+
+// -----
+
+// CHECK-LABEL: func @test
+// CHECK: test.reflect_bounds {smax = 127 : index, smin = -128 : index, umax = -1 : index, umin = 0 : index}
+func.func @test() -> index {
+ %cst1 = arith.constant 1 : i8
+ %0 = test.with_bounds { umin = 0 : index, umax = 127 : index, smin = 0 : index, smax = 127 : index }
+ %i8val = arith.index_cast %0 : index to i8
+ %shifted = arith.shli %i8val, %cst1 : i8
+ %si = arith.index_cast %shifted : i8 to index
+ %1 = test.reflect_bounds %si
+ return %1: index
+}
+
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good bugfix,. approved!
(Also, would you be open to plumbing in support for nsw
and nuw
flags while you're here? That is, inferring them and/or using them to improve inference precision?
I've had that on my vague list of things that'd be nice to do but haven't had the time)
... wait, hold on, does that use of SExtValue()/ZExtValue() cause problems with, say, |
True, that might be a problem. I'll look into it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approved
This broke the bot, seem like a missing dependency in CMake:
|
I think it was the commit before this one (cf40c93), which has been reverted? Or should I revert this too? |
When an overflow happens during shift left, i.e. the last sign bit or the most significant data bit gets shifted out, the current approach of inferring the range of results does not work anymore.
This patch checks for possible overflow and returns the max range in that case.
This bug shows up most clearly when we peform
arith.shli
on a full-range integer, so for example an i8 between 0 and 0xff:The result range will get vastly underestimated to [-2, 0] which leads to wrong optimizations.
Fix #82158