Skip to content

Commit 1a8c613

Browse files
authored
[mlir][arith] Align shift Ops with LLVM instructions on allowed shift amounts (#82133)
This patch aligns the shift Ops in `arith` with respective LLVM instructions. Specifically, shifting by an amount equal to the bitwidth of the operand is now defined to return poison. Relevant discussion: https://discourse.llvm.org/t/some-question-on-the-semantics-of-the-arith-dialect/74861/10 Relevant issue: #80960
1 parent 833fea4 commit 1a8c613

File tree

3 files changed

+44
-11
lines changed

3 files changed

+44
-11
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
788788
The `shli` operation shifts the integer value of the first operand to the left
789789
by the integer value of the second operand. The second operand is interpreted as
790790
unsigned. The low order bits are filled with zeros. If the value of the second
791-
operand is greater than the bitwidth of the first operand, then the
791+
operand is greater or equal than the bitwidth of the first operand, then the
792792
operation returns poison.
793793

794794
This op supports `nuw`/`nsw` overflow flags which stands stand for
@@ -818,8 +818,8 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
818818
The `shrui` operation shifts an integer value of the first operand to the right
819819
by the value of the second operand. The first operand is interpreted as unsigned,
820820
and the second operand is interpreted as unsigned. The high order bits are always
821-
filled with zeros. If the value of the second operand is greater than the bitwidth
822-
of the first operand, then the operation returns poison.
821+
filled with zeros. If the value of the second operand is greater or equal than the
822+
bitwidth of the first operand, then the operation returns poison.
823823

824824
Example:
825825

@@ -844,8 +844,8 @@ def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> {
844844
and the second operand is interpreter as unsigned. The high order bits in the
845845
output are filled with copies of the most-significant bit of the shifted value
846846
(which means that the sign of the value is preserved). If the value of the second
847-
operand is greater than bitwidth of the first operand, then the operation returns
848-
poison.
847+
operand is greater or equal than bitwidth of the first operand, then the operation
848+
returns poison.
849849

850850
Example:
851851

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2379,11 +2379,11 @@ OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
23792379
// shli(x, 0) -> x
23802380
if (matchPattern(adaptor.getRhs(), m_Zero()))
23812381
return getLhs();
2382-
// Don't fold if shifting more than the bit width.
2382+
// Don't fold if shifting more or equal than the bit width.
23832383
bool bounded = false;
23842384
auto result = constFoldBinaryOp<IntegerAttr>(
23852385
adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2386-
bounded = b.ule(b.getBitWidth());
2386+
bounded = b.ult(b.getBitWidth());
23872387
return a.shl(b);
23882388
});
23892389
return bounded ? result : Attribute();
@@ -2397,11 +2397,11 @@ OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
23972397
// shrui(x, 0) -> x
23982398
if (matchPattern(adaptor.getRhs(), m_Zero()))
23992399
return getLhs();
2400-
// Don't fold if shifting more than the bit width.
2400+
// Don't fold if shifting more or equal than the bit width.
24012401
bool bounded = false;
24022402
auto result = constFoldBinaryOp<IntegerAttr>(
24032403
adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2404-
bounded = b.ule(b.getBitWidth());
2404+
bounded = b.ult(b.getBitWidth());
24052405
return a.lshr(b);
24062406
});
24072407
return bounded ? result : Attribute();
@@ -2415,11 +2415,11 @@ OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
24152415
// shrsi(x, 0) -> x
24162416
if (matchPattern(adaptor.getRhs(), m_Zero()))
24172417
return getLhs();
2418-
// Don't fold if shifting more than the bit width.
2418+
// Don't fold if shifting more or equal than the bit width.
24192419
bool bounded = false;
24202420
auto result = constFoldBinaryOp<IntegerAttr>(
24212421
adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2422-
bounded = b.ule(b.getBitWidth());
2422+
bounded = b.ult(b.getBitWidth());
24232423
return a.ashr(b);
24242424
});
24252425
return bounded ? result : Attribute();

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2179,6 +2179,17 @@ func.func @nofoldShl2() -> i64 {
21792179
return %r : i64
21802180
}
21812181

2182+
// CHECK-LABEL: @nofoldShl3(
2183+
// CHECK: %[[res:.+]] = arith.shli
2184+
// CHECK: return %[[res]]
2185+
func.func @nofoldShl3() -> i64 {
2186+
%c1 = arith.constant 1 : i64
2187+
%c64 = arith.constant 64 : i64
2188+
// Note that this should return Poison in the future.
2189+
%r = arith.shli %c1, %c64 : i64
2190+
return %r : i64
2191+
}
2192+
21822193
// CHECK-LABEL: @foldShru(
21832194
// CHECK: %[[res:.+]] = arith.constant 2 : i64
21842195
// CHECK: return %[[res]]
@@ -2219,6 +2230,17 @@ func.func @nofoldShru2() -> i64 {
22192230
return %r : i64
22202231
}
22212232

2233+
// CHECK-LABEL: @nofoldShru3(
2234+
// CHECK: %[[res:.+]] = arith.shrui
2235+
// CHECK: return %[[res]]
2236+
func.func @nofoldShru3() -> i64 {
2237+
%c1 = arith.constant 8 : i64
2238+
%c64 = arith.constant 64 : i64
2239+
// Note that this should return Poison in the future.
2240+
%r = arith.shrui %c1, %c64 : i64
2241+
return %r : i64
2242+
}
2243+
22222244
// CHECK-LABEL: @foldShrs(
22232245
// CHECK: %[[res:.+]] = arith.constant 2 : i64
22242246
// CHECK: return %[[res]]
@@ -2259,6 +2281,17 @@ func.func @nofoldShrs2() -> i64 {
22592281
return %r : i64
22602282
}
22612283

2284+
// CHECK-LABEL: @nofoldShrs3(
2285+
// CHECK: %[[res:.+]] = arith.shrsi
2286+
// CHECK: return %[[res]]
2287+
func.func @nofoldShrs3() -> i64 {
2288+
%c1 = arith.constant 8 : i64
2289+
%c64 = arith.constant 64 : i64
2290+
// Note that this should return Poison in the future.
2291+
%r = arith.shrsi %c1, %c64 : i64
2292+
return %r : i64
2293+
}
2294+
22622295
// -----
22632296

22642297
// CHECK-LABEL: @test_negf(

0 commit comments

Comments
 (0)