Skip to content

[mlir][intrange] Fix arith.shl inference in case of overflow #91737

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 13, 2024

Conversation

ubfx
Copy link
Member

@ubfx ubfx commented May 10, 2024

When an overflow happens during shift left, i.e. the last sign bit or the most significant data bit gets shifted out, the current approach of inferring the range of results does not work anymore.

This patch checks for possible overflow and returns the max range in that case.

This bug shows up most clearly when we peform arith.shli on a full-range integer, so for example an i8 between 0 and 0xff:

// mlir-opt --int-range-optimizations
func.func @test() -> index {
  %cst1 = arith.constant 1 : i8
  %0 = test.with_bounds { umin = 0 : index, umax = 255 : index, smin = -128 : index, smax = 127 : index }
  %i8val = arith.index_cast %0 : index to i8
  %shifted = arith.shli %i8val, %cst1 : i8
  %si = arith.index_cast %shifted : i8 to index
  %1 = test.reflect_bounds %si
// gets optimized to 
// %4 = test.reflect_bounds {smax = 0 : index, smin = -2 : index, umax = -1 : index, umin = 0 : index} %3
  return %1: index
}

The result range will get vastly underestimated to [-2, 0] which leads to wrong optimizations.

Fix #82158

When an overflow happens during shift left, i.e. the last sign bit or
the most significant data bit gets shifted out, the current approach of
inferring the range of results does not work anymore.

This patch checks for possible overflow and returns the max range in that
case.

Fix llvm#82158
@llvmbot
Copy link
Member

llvmbot commented May 10, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-arith

Author: Felix Schneider (ubfx)

Changes

When an overflow happens during shift left, i.e. the last sign bit or the most significant data bit gets shifted out, the current approach of inferring the range of results does not work anymore.

This patch checks for possible overflow and returns the max range in that case.

This bug shows up most clearly when we peform arith.shli on a full-range integer, so for example an i8 between 0 and 0xff:

// mlir-opt --int-range-optimizations
func.func @<!-- -->test() -&gt; index {
  %cst1 = arith.constant 1 : i8
  %0 = test.with_bounds { umin = 0 : index, umax = 255 : index, smin = -128 : index, smax = 127 : index }
  %i8val = arith.index_cast %0 : index to i8
  %shifted = arith.shli %i8val, %cst1 : i8
  %si = arith.index_cast %shifted : i8 to index
  %1 = test.reflect_bounds %si
// gets optimized to 
// %4 = test.reflect_bounds {smax = 0 : index, smin = -2 : index, umax = -1 : index, umin = 0 : index} %3
  return %1: index
}

The result range will get vastly underestimated to [-2, 0] which leads to wrong optimizations.

Fix #82158


Full diff: https://github.com/llvm/llvm-project/pull/91737.diff

2 Files Affected:

  • (modified) mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp (+12)
  • (modified) mlir/test/Dialect/Arith/int-range-opts.mlir (+29)
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 2b2d937d55d80..12fae495b761e 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -548,6 +548,18 @@ mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
                         const APInt &r) -> std::optional<APInt> {
     return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
   };
+
+  // The minMax inference does not work when there is danger of overflow. In the
+  // signed case, this leads to the obvious problem that the sign bit might
+  // change. In the unsigned case, it also leads to problems because the largest
+  // LHS shifted by the largest RHS does not necessarily result in the largest
+  // result anymore.
+  bool signbitSafe =
+      (lhs.smin().getNumSignBits() > rhs.umax().getZExtValue()) &&
+      (lhs.smax().getNumSignBits() > rhs.umax().getZExtValue());
+  if (!signbitSafe)
+    return ConstantIntRanges::maxRange(lhs.umax().getBitWidth());
+
   ConstantIntRanges urange =
       minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
                /*isSigned=*/false);
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index be0a7e8ccd70b..4c3c0854ed026 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -71,3 +71,32 @@ func.func @test() -> i1 {
   %1 = arith.cmpi sle, %0, %cst1 : index
   return %1: i1
 }
+
+// -----
+
+// CHECK-LABEL: func @test
+// CHECK: test.reflect_bounds {smax = 24 : index, smin = 0 : index, umax = 24 : index, umin = 0 : index}
+func.func @test() -> index {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : index, umax = 12 : index, smin = 0 : index, smax = 12 : index }
+  %i8val = arith.index_cast %0 : index to i8
+  %shifted = arith.shli %i8val, %cst1 : i8
+  %si = arith.index_cast %shifted : i8 to index
+  %1 = test.reflect_bounds %si
+  return %1: index
+}
+
+// -----
+
+// CHECK-LABEL: func @test
+// CHECK: test.reflect_bounds {smax = 127 : index, smin = -128 : index, umax = -1 : index, umin = 0 : index}
+func.func @test() -> index {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : index, umax = 127 : index, smin = 0 : index, smax = 127 : index }
+  %i8val = arith.index_cast %0 : index to i8
+  %shifted = arith.shli %i8val, %cst1 : i8
+  %si = arith.index_cast %shifted : i8 to index
+  %1 = test.reflect_bounds %si
+  return %1: index
+}
+

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good bugfix,. approved!

(Also, would you be open to plumbing in support for nsw and nuw flags while you're here? That is, inferring them and/or using them to improve inference precision?

I've had that on my vague list of things that'd be nice to do but haven't had the time)

@krzysz00
Copy link
Contributor

... wait, hold on, does that use of SExtValue()/ZExtValue() cause problems with, say, i128?

@ubfx
Copy link
Member Author

ubfx commented May 10, 2024

... wait, hold on, does that use of SExtValue()/ZExtValue() cause problems with, say, i128?

True, that might be a problem. I'll look into it.

@ubfx ubfx force-pushed the intrange-fix-shl branch from c1253e7 to bf81f62 Compare May 10, 2024 19:04
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved

@ubfx ubfx merged commit 0f79066 into llvm:main May 13, 2024
4 checks passed
@ubfx ubfx deleted the intrange-fix-shl branch May 13, 2024 17:27
@joker-eph
Copy link
Collaborator

This broke the bot, seem like a missing dependency in CMake:

ld.lld: error: undefined symbol: mlir::vector::populateVectorInterleaveToShufflePatterns(mlir::RewritePatternSet&, mlir::PatternBenefit)

@ubfx
Copy link
Member Author

ubfx commented May 13, 2024

This broke the bot,

I think it was the commit before this one (cf40c93), which has been reverted? Or should I revert this too?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Inconsistent results when using --arith-unsigned-when-equivalent --int-range-optimizations
5 participants