Skip to content

[mlir][arith] Align shift Ops with LLVM instructions on allowed shift amounts #82133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
The `shli` operation shifts the integer value of the first operand to the left
by the integer value of the second operand. The second operand is interpreted as
unsigned. The low order bits are filled with zeros. If the value of the second
operand is greater than the bitwidth of the first operand, then the
operand is greater or equal than the bitwidth of the first operand, then the
operation returns poison.

This op supports `nuw`/`nsw` overflow flags which stands stand for
Expand Down Expand Up @@ -818,8 +818,8 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
The `shrui` operation shifts an integer value of the first operand to the right
by the value of the second operand. The first operand is interpreted as unsigned,
and the second operand is interpreted as unsigned. The high order bits are always
filled with zeros. If the value of the second operand is greater than the bitwidth
of the first operand, then the operation returns poison.
filled with zeros. If the value of the second operand is greater or equal than the
bitwidth of the first operand, then the operation returns poison.

Example:

Expand All @@ -844,8 +844,8 @@ def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> {
and the second operand is interpreter as unsigned. The high order bits in the
output are filled with copies of the most-significant bit of the shifted value
(which means that the sign of the value is preserved). If the value of the second
operand is greater than bitwidth of the first operand, then the operation returns
poison.
operand is greater or equal than bitwidth of the first operand, then the operation
returns poison.

Example:

Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2379,11 +2379,11 @@ OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
// shli(x, 0) -> x
if (matchPattern(adaptor.getRhs(), m_Zero()))
return getLhs();
// Don't fold if shifting more than the bit width.
// Don't fold if shifting more or equal than the bit width.
bool bounded = false;
auto result = constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
bounded = b.ule(b.getBitWidth());
bounded = b.ult(b.getBitWidth());
return a.shl(b);
});
return bounded ? result : Attribute();
Expand All @@ -2397,11 +2397,11 @@ OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
// shrui(x, 0) -> x
if (matchPattern(adaptor.getRhs(), m_Zero()))
return getLhs();
// Don't fold if shifting more than the bit width.
// Don't fold if shifting more or equal than the bit width.
bool bounded = false;
auto result = constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
bounded = b.ule(b.getBitWidth());
bounded = b.ult(b.getBitWidth());
return a.lshr(b);
});
return bounded ? result : Attribute();
Expand All @@ -2415,11 +2415,11 @@ OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
// shrsi(x, 0) -> x
if (matchPattern(adaptor.getRhs(), m_Zero()))
return getLhs();
// Don't fold if shifting more than the bit width.
// Don't fold if shifting more or equal than the bit width.
bool bounded = false;
auto result = constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
bounded = b.ule(b.getBitWidth());
bounded = b.ult(b.getBitWidth());
return a.ashr(b);
});
return bounded ? result : Attribute();
Expand Down
33 changes: 33 additions & 0 deletions mlir/test/Dialect/Arith/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2179,6 +2179,17 @@ func.func @nofoldShl2() -> i64 {
return %r : i64
}

// CHECK-LABEL: @nofoldShl3(
// CHECK: %[[res:.+]] = arith.shli
// CHECK: return %[[res]]
func.func @nofoldShl3() -> i64 {
%c1 = arith.constant 1 : i64
%c64 = arith.constant 64 : i64
// Note that this should return Poison in the future.
%r = arith.shli %c1, %c64 : i64
return %r : i64
}

// CHECK-LABEL: @foldShru(
// CHECK: %[[res:.+]] = arith.constant 2 : i64
// CHECK: return %[[res]]
Expand Down Expand Up @@ -2219,6 +2230,17 @@ func.func @nofoldShru2() -> i64 {
return %r : i64
}

// CHECK-LABEL: @nofoldShru3(
// CHECK: %[[res:.+]] = arith.shrui
// CHECK: return %[[res]]
func.func @nofoldShru3() -> i64 {
%c1 = arith.constant 8 : i64
%c64 = arith.constant 64 : i64
// Note that this should return Poison in the future.
%r = arith.shrui %c1, %c64 : i64
return %r : i64
}

// CHECK-LABEL: @foldShrs(
// CHECK: %[[res:.+]] = arith.constant 2 : i64
// CHECK: return %[[res]]
Expand Down Expand Up @@ -2259,6 +2281,17 @@ func.func @nofoldShrs2() -> i64 {
return %r : i64
}

// CHECK-LABEL: @nofoldShrs3(
// CHECK: %[[res:.+]] = arith.shrsi
// CHECK: return %[[res]]
func.func @nofoldShrs3() -> i64 {
%c1 = arith.constant 8 : i64
%c64 = arith.constant 64 : i64
// Note that this should return Poison in the future.
%r = arith.shrsi %c1, %c64 : i64
return %r : i64
}

// -----

// CHECK-LABEL: @test_negf(
Expand Down