Skip to content

Commit afa178d

Browse files
authored
[mlir][LLVM] Add exact flag (#115327)
The implementation is mostly based on the one existing for the nsw and nuw flags. If the exact flag is present, the corresponding operation returns a poison value when the result is not exact. (For a division, if rounding happens; for a right shift, if a non-zero bit is shifted out.)
1 parent 724b432 commit afa178d

File tree

8 files changed

+101
-4
lines changed

8 files changed

+101
-4
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,33 @@ def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface">
8787
];
8888
}
8989

90+
def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
91+
let description = [{
92+
This interface defines an LLVM operation with an exact flag and
93+
provides a uniform API for accessing it.
94+
}];
95+
96+
let cppNamespace = "::mlir::LLVM";
97+
98+
let methods = [
99+
InterfaceMethod<[{
100+
Get the exact flag for the operation.
101+
}], "bool", "getIsExact", (ins), [{}], [{
102+
return $_op.getProperties().isExact;
103+
}]>,
104+
InterfaceMethod<[{
105+
Set the exact flag for the operation.
106+
}], "void", "setIsExact", (ins "bool":$isExact), [{}], [{
107+
$_op.getProperties().isExact = isExact;
108+
}]>,
109+
StaticInterfaceMethod<[{
110+
Get the attribute name of the isExact property.
111+
}], "StringRef", "getIsExactName", (ins), [{}], [{
112+
return "isExact";
113+
}]>,
114+
];
115+
}
116+
90117
def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
91118
let description = [{
92119
An interface for operations that can carry branch weights metadata. It

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,24 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
7676
"$res = builder.Create" # instName #
7777
"($lhs, $rhs, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());";
7878
}
79+
class LLVM_IntArithmeticOpWithExactFlag<string mnemonic, string instName,
80+
list<Trait> traits = []> :
81+
LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
82+
!listconcat([DeclareOpInterfaceMethods<ExactFlagInterface>], traits)> {
83+
let arguments = !con(commonArgs, (ins UnitAttr:$isExact));
84+
85+
string mlirBuilder = [{
86+
auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
87+
moduleImport.setExactFlag(inst, op);
88+
$res = op;
89+
}];
90+
let assemblyFormat = [{
91+
(`exact` $isExact^)? $lhs `,` $rhs custom<LLVMOpAttrs>(attr-dict) `:` type($res)
92+
}];
93+
string llvmBuilder =
94+
"$res = builder.Create" # instName #
95+
"($lhs, $rhs, /*Name=*/\"\", op.getIsExact());";
96+
}
7997
class LLVM_FloatArithmeticOp<string mnemonic, string instName,
8098
list<Trait> traits = []> :
8199
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
@@ -116,8 +134,8 @@ def LLVM_AddOp : LLVM_IntArithmeticOpWithOverflowFlag<"add", "Add",
116134
def LLVM_SubOp : LLVM_IntArithmeticOpWithOverflowFlag<"sub", "Sub", []>;
117135
def LLVM_MulOp : LLVM_IntArithmeticOpWithOverflowFlag<"mul", "Mul",
118136
[Commutative]>;
119-
def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "UDiv">;
120-
def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">;
137+
def LLVM_UDivOp : LLVM_IntArithmeticOpWithExactFlag<"udiv", "UDiv">;
138+
def LLVM_SDivOp : LLVM_IntArithmeticOpWithExactFlag<"sdiv", "SDiv">;
121139
def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
122140
def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">;
123141
def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">;
@@ -128,8 +146,8 @@ def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
128146
def LLVM_ShlOp : LLVM_IntArithmeticOpWithOverflowFlag<"shl", "Shl", []> {
129147
let hasFolder = 1;
130148
}
131-
def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">;
132-
def LLVM_AShrOp : LLVM_IntArithmeticOp<"ashr", "AShr">;
149+
def LLVM_LShrOp : LLVM_IntArithmeticOpWithExactFlag<"lshr", "LShr">;
150+
def LLVM_AShrOp : LLVM_IntArithmeticOpWithExactFlag<"ashr", "AShr">;
133151

134152
// Base class for compare operations. A compare operation takes two operands
135153
// of the same type and returns a boolean result. If the operands are

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,11 @@ class ModuleImport {
187187
/// operation does not implement the integer overflow flag interface.
188188
void setIntegerOverflowFlags(llvm::Instruction *inst, Operation *op) const;
189189

190+
/// Sets the exact flag attribute for the imported operation `op` given
191+
/// the original instruction `inst`. Asserts if the operation does not
192+
/// implement the exact flag interface.
193+
void setExactFlag(llvm::Instruction *inst, Operation *op) const;
194+
190195
/// Sets the fastmath flags attribute for the imported operation `op` given
191196
/// the original instruction `inst`. Asserts if the operation does not
192197
/// implement the fastmath interface.

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
143143
if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) {
144144
printer.printOptionalAttrDict(
145145
filteredAttrs, /*elidedAttrs=*/{iface.getOverflowFlagsAttrName()});
146+
} else if (auto iface = dyn_cast<ExactFlagInterface>(op)) {
147+
printer.printOptionalAttrDict(filteredAttrs,
148+
/*elidedAttrs=*/{iface.getIsExactName()});
146149
} else {
147150
printer.printOptionalAttrDict(filteredAttrs);
148151
}

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,12 @@ void ModuleImport::setIntegerOverflowFlags(llvm::Instruction *inst,
683683
iface.setOverflowFlags(value);
684684
}
685685

686+
void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const {
687+
auto iface = cast<ExactFlagInterface>(op);
688+
689+
iface.setIsExact(inst->isExact());
690+
}
691+
686692
void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
687693
Operation *op) const {
688694
auto iface = cast<FastmathFlagsInterface>(op);

mlir/test/Dialect/LLVMIR/roundtrip.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,16 @@ func.func @ops(%arg0: i32, %arg1: f32,
4949
%mul_flag = llvm.mul %arg0, %arg0 overflow<nsw, nuw> : i32
5050
%shl_flag = llvm.shl %arg0, %arg0 overflow<nuw, nsw> : i32
5151

52+
// Integer exact flag.
53+
// CHECK: {{.*}} = llvm.sdiv exact %[[I32]], %[[I32]] : i32
54+
// CHECK: {{.*}} = llvm.udiv exact %[[I32]], %[[I32]] : i32
55+
// CHECK: {{.*}} = llvm.ashr exact %[[I32]], %[[I32]] : i32
56+
// CHECK: {{.*}} = llvm.lshr exact %[[I32]], %[[I32]] : i32
57+
%sdiv_flag = llvm.sdiv exact %arg0, %arg0 : i32
58+
%udiv_flag = llvm.udiv exact %arg0, %arg0 : i32
59+
%ashr_flag = llvm.ashr exact %arg0, %arg0 : i32
60+
%lshr_flag = llvm.lshr exact %arg0, %arg0 : i32
61+
5262
// Floating point binary operations.
5363
//
5464
// CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
2+
3+
; CHECK-LABEL: @exactflag_inst
4+
define void @exactflag_inst(i64 %arg1, i64 %arg2) {
5+
; CHECK: llvm.udiv exact %{{.*}}, %{{.*}} : i64
6+
%1 = udiv exact i64 %arg1, %arg2
7+
; CHECK: llvm.sdiv exact %{{.*}}, %{{.*}} : i64
8+
%2 = sdiv exact i64 %arg1, %arg2
9+
; CHECK: llvm.lshr exact %{{.*}}, %{{.*}} : i64
10+
%3 = lshr exact i64 %arg1, %arg2
11+
; CHECK: llvm.ashr exact %{{.*}}, %{{.*}} : i64
12+
%4 = ashr exact i64 %arg1, %arg2
13+
ret void
14+
}

mlir/test/Target/LLVMIR/exact.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// CHECK-LABEL: define void @exactflag_func
4+
llvm.func @exactflag_func(%arg0: i64, %arg1: i64) {
5+
// CHECK: %{{.*}} = udiv exact i64 %{{.*}}, %{{.*}}
6+
%0 = llvm.udiv exact %arg0, %arg1 : i64
7+
// CHECK: %{{.*}} = sdiv exact i64 %{{.*}}, %{{.*}}
8+
%1 = llvm.sdiv exact %arg0, %arg1 : i64
9+
// CHECK: %{{.*}} = lshr exact i64 %{{.*}}, %{{.*}}
10+
%2 = llvm.lshr exact %arg0, %arg1 : i64
11+
// CHECK: %{{.*}} = ashr exact i64 %{{.*}}, %{{.*}}
12+
%3 = llvm.ashr exact %arg0, %arg1 : i64
13+
llvm.return
14+
}

0 commit comments

Comments
 (0)