Skip to content

[mlir][LLVM] Add exact flag #115327

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,33 @@ def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface">
];
}

def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
let description = [{
This interface defines an LLVM operation with an exact flag and
provides a uniform API for accessing it.
}];

let cppNamespace = "::mlir::LLVM";

let methods = [
InterfaceMethod<[{
Get the exact flag for the operation.
}], "bool", "getIsExact", (ins), [{}], [{
return $_op.getProperties().isExact;
}]>,
InterfaceMethod<[{
Set the exact flag for the operation.
}], "void", "setIsExact", (ins "bool":$isExact), [{}], [{
$_op.getProperties().isExact = isExact;
}]>,
StaticInterfaceMethod<[{
Get the attribute name of the isExact property.
}], "StringRef", "getIsExactName", (ins), [{}], [{
return "isExact";
}]>,
];
}

def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
let description = [{
An interface for operations that can carry branch weights metadata. It
Expand Down
26 changes: 22 additions & 4 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,24 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
"$res = builder.Create" # instName #
"($lhs, $rhs, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());";
}
class LLVM_IntArithmeticOpWithExactFlag<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
!listconcat([DeclareOpInterfaceMethods<ExactFlagInterface>], traits)> {
let arguments = !con(commonArgs, (ins UnitAttr:$isExact));

string mlirBuilder = [{
auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
moduleImport.setExactFlag(inst, op);
$res = op;
}];
let assemblyFormat = [{
(`exact` $isExact^)? $lhs `,` $rhs custom<LLVMOpAttrs>(attr-dict) `:` type($res)
}];
string llvmBuilder =
"$res = builder.Create" # instName #
"($lhs, $rhs, /*Name=*/\"\", op.getIsExact());";
}
class LLVM_FloatArithmeticOp<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
Expand Down Expand Up @@ -116,8 +134,8 @@ def LLVM_AddOp : LLVM_IntArithmeticOpWithOverflowFlag<"add", "Add",
def LLVM_SubOp : LLVM_IntArithmeticOpWithOverflowFlag<"sub", "Sub", []>;
def LLVM_MulOp : LLVM_IntArithmeticOpWithOverflowFlag<"mul", "Mul",
[Commutative]>;
def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "UDiv">;
def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">;
def LLVM_UDivOp : LLVM_IntArithmeticOpWithExactFlag<"udiv", "UDiv">;
def LLVM_SDivOp : LLVM_IntArithmeticOpWithExactFlag<"sdiv", "SDiv">;
def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">;
def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">;
Expand All @@ -128,8 +146,8 @@ def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
def LLVM_ShlOp : LLVM_IntArithmeticOpWithOverflowFlag<"shl", "Shl", []> {
let hasFolder = 1;
}
def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">;
def LLVM_AShrOp : LLVM_IntArithmeticOp<"ashr", "AShr">;
def LLVM_LShrOp : LLVM_IntArithmeticOpWithExactFlag<"lshr", "LShr">;
def LLVM_AShrOp : LLVM_IntArithmeticOpWithExactFlag<"ashr", "AShr">;

// Base class for compare operations. A compare operation takes two operands
// of the same type and returns a boolean result. If the operands are
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,11 @@ class ModuleImport {
/// operation does not implement the integer overflow flag interface.
void setIntegerOverflowFlags(llvm::Instruction *inst, Operation *op) const;

/// Sets the exact flag attribute for the imported operation `op` given
/// the original instruction `inst`. Asserts if the operation does not
/// implement the exact flag interface.
void setExactFlag(llvm::Instruction *inst, Operation *op) const;

/// Sets the fastmath flags attribute for the imported operation `op` given
/// the original instruction `inst`. Asserts if the operation does not
/// implement the fastmath interface.
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) {
printer.printOptionalAttrDict(
filteredAttrs, /*elidedAttrs=*/{iface.getOverflowFlagsAttrName()});
} else if (auto iface = dyn_cast<ExactFlagInterface>(op)) {
printer.printOptionalAttrDict(filteredAttrs,
/*elidedAttrs=*/{iface.getIsExactName()});
} else {
printer.printOptionalAttrDict(filteredAttrs);
}
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,12 @@ void ModuleImport::setIntegerOverflowFlags(llvm::Instruction *inst,
iface.setOverflowFlags(value);
}

void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const {
auto iface = cast<ExactFlagInterface>(op);

iface.setIsExact(inst->isExact());
}

void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
Operation *op) const {
auto iface = cast<FastmathFlagsInterface>(op);
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Dialect/LLVMIR/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@ func.func @ops(%arg0: i32, %arg1: f32,
%mul_flag = llvm.mul %arg0, %arg0 overflow<nsw, nuw> : i32
%shl_flag = llvm.shl %arg0, %arg0 overflow<nuw, nsw> : i32

// Integer exact flag.
// CHECK: {{.*}} = llvm.sdiv exact %[[I32]], %[[I32]] : i32
// CHECK: {{.*}} = llvm.udiv exact %[[I32]], %[[I32]] : i32
// CHECK: {{.*}} = llvm.ashr exact %[[I32]], %[[I32]] : i32
// CHECK: {{.*}} = llvm.lshr exact %[[I32]], %[[I32]] : i32
%sdiv_flag = llvm.sdiv exact %arg0, %arg0 : i32
%udiv_flag = llvm.udiv exact %arg0, %arg0 : i32
%ashr_flag = llvm.ashr exact %arg0, %arg0 : i32
%lshr_flag = llvm.lshr exact %arg0, %arg0 : i32

// Floating point binary operations.
//
// CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Target/LLVMIR/Import/exact.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s

; CHECK-LABEL: @exactflag_inst
define void @exactflag_inst(i64 %arg1, i64 %arg2) {
; CHECK: llvm.udiv exact %{{.*}}, %{{.*}} : i64
%1 = udiv exact i64 %arg1, %arg2
; CHECK: llvm.sdiv exact %{{.*}}, %{{.*}} : i64
%2 = sdiv exact i64 %arg1, %arg2
; CHECK: llvm.lshr exact %{{.*}}, %{{.*}} : i64
%3 = lshr exact i64 %arg1, %arg2
; CHECK: llvm.ashr exact %{{.*}}, %{{.*}} : i64
%4 = ashr exact i64 %arg1, %arg2
ret void
}
14 changes: 14 additions & 0 deletions mlir/test/Target/LLVMIR/exact.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

// CHECK-LABEL: define void @exactflag_func
llvm.func @exactflag_func(%arg0: i64, %arg1: i64) {
// CHECK: %{{.*}} = udiv exact i64 %{{.*}}, %{{.*}}
%0 = llvm.udiv exact %arg0, %arg1 : i64
// CHECK: %{{.*}} = sdiv exact i64 %{{.*}}, %{{.*}}
%1 = llvm.sdiv exact %arg0, %arg1 : i64
// CHECK: %{{.*}} = lshr exact i64 %{{.*}}, %{{.*}}
%2 = llvm.lshr exact %arg0, %arg1 : i64
// CHECK: %{{.*}} = ashr exact i64 %{{.*}}, %{{.*}}
%3 = llvm.ashr exact %arg0, %arg1 : i64
llvm.return
}
Loading