From 85ce594b9fa4156c8a561db57f38f1de60919dcb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20Frenot?= Date: Thu, 7 Nov 2024 15:04:32 +0000 Subject: [PATCH 1/2] [mlir][LLVM] Add exact flag --- .../mlir/Dialect/LLVMIR/LLVMInterfaces.td | 27 +++++++++++++++++++ mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 26 +++++++++++++++--- .../include/mlir/Target/LLVMIR/ModuleImport.h | 5 ++++ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 3 +++ mlir/lib/Target/LLVMIR/ModuleImport.cpp | 8 ++++++ mlir/test/Dialect/LLVMIR/roundtrip.mlir | 10 +++++++ mlir/test/Target/LLVMIR/Import/exact.ll | 14 ++++++++++ mlir/test/Target/LLVMIR/exact.mlir | 14 ++++++++++ 8 files changed, 103 insertions(+), 4 deletions(-) create mode 100644 mlir/test/Target/LLVMIR/Import/exact.ll create mode 100644 mlir/test/Target/LLVMIR/exact.mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td index 7e38e0b27fd96..12c430df20892 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td @@ -87,6 +87,33 @@ def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> ]; } +def ExactFlagInterface : OpInterface<"ExactFlagInterface"> { + let description = [{ + This interface defines an LLVM operation with an exact flag and + provides a uniform API for accessing it. + }]; + + let cppNamespace = "::mlir::LLVM"; + + let methods = [ + InterfaceMethod<[{ + Get the exact flag for the operation. + }], "bool", "getIsExact", (ins), [{}], [{ + return $_op.getProperties().isExact; + }]>, + InterfaceMethod<[{ + Set the exact flag for the operation. + }], "void", "setIsExact", (ins "bool":$isExact), [{}], [{ + $_op.getProperties().isExact = isExact; + }]>, + StaticInterfaceMethod<[{ + Get the attribute name of the isExact property. + }], "StringRef", "getIsExactName", (ins), [{}], [{ + return "isExact"; + }]>, + ]; +} + def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> { let description = [{ An interface for operations that can carry branch weights metadata. It diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index d5def510a904d..b7ce126dbf54d 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -76,6 +76,24 @@ class LLVM_IntArithmeticOpWithOverflowFlag traits = []> : + LLVM_ArithmeticOpBase], traits)> { + let arguments = !con(commonArgs, (ins UnitAttr:$isExact)); + + string mlirBuilder = [{ + auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); + moduleImport.setExactFlag(inst, op); + $res = op; + }]; + let assemblyFormat = [{ + (`exact` $isExact^)? $lhs `,` $rhs custom(attr-dict) `:` type($res) + }]; + string llvmBuilder = + "$res = builder.Create" # instName # + "($lhs, $rhs, /*Name=*/\"\", op.getIsExact());"; +} class LLVM_FloatArithmeticOp traits = []> : LLVM_ArithmeticOpBase; def LLVM_MulOp : LLVM_IntArithmeticOpWithOverflowFlag<"mul", "Mul", [Commutative]>; -def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "UDiv">; -def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">; +def LLVM_UDivOp : LLVM_IntArithmeticOpWithIsExact<"udiv", "UDiv">; +def LLVM_SDivOp : LLVM_IntArithmeticOpWithIsExact<"sdiv", "SDiv">; def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">; def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">; def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">; @@ -128,8 +146,8 @@ def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">; def LLVM_ShlOp : LLVM_IntArithmeticOpWithOverflowFlag<"shl", "Shl", []> { let hasFolder = 1; } -def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">; -def LLVM_AShrOp : LLVM_IntArithmeticOp<"ashr", "AShr">; +def LLVM_LShrOp : LLVM_IntArithmeticOpWithIsExact<"lshr", "LShr">; +def LLVM_AShrOp : LLVM_IntArithmeticOpWithIsExact<"ashr", "AShr">; // Base class for compare operations. A compare operation takes two operands // of the same type and returns a boolean result. If the operands are diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index bbb7af58d2739..6c3a500f20e3a 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -187,6 +187,11 @@ class ModuleImport { /// operation does not implement the integer overflow flag interface. void setIntegerOverflowFlags(llvm::Instruction *inst, Operation *op) const; + /// Sets the exact flag attribute for the imported operation `op` given + /// the original instruction `inst`. Asserts if the operation does not + /// implement the exact flag interface. + void setExactFlag(llvm::Instruction *inst, Operation *op) const; + /// Sets the fastmath flags attribute for the imported operation `op` given /// the original instruction `inst`. Asserts if the operation does not /// implement the fastmath interface. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index c9bc9533ca2a6..6b2d8943bf488 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -143,6 +143,9 @@ static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, if (auto iface = dyn_cast(op)) { printer.printOptionalAttrDict( filteredAttrs, /*elidedAttrs=*/{iface.getOverflowFlagsAttrName()}); + } else if (auto iface = dyn_cast(op)) { + printer.printOptionalAttrDict(filteredAttrs, + /*elidedAttrs=*/{iface.getIsExactName()}); } else { printer.printOptionalAttrDict(filteredAttrs); } diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 1f63519373eca..ccec2034a298b 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -683,6 +683,14 @@ void ModuleImport::setIntegerOverflowFlags(llvm::Instruction *inst, iface.setOverflowFlags(value); } +void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const { + auto iface = cast(op); + + bool value = inst->isExact(); + + iface.setIsExact(value); +} + void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst, Operation *op) const { auto iface = cast(op); diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index b8ce7db795a1d..9daad2ef5b0b1 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -49,6 +49,16 @@ func.func @ops(%arg0: i32, %arg1: f32, %mul_flag = llvm.mul %arg0, %arg0 overflow : i32 %shl_flag = llvm.shl %arg0, %arg0 overflow : i32 +// Integer exact +// CHECK: {{.*}} = llvm.sdiv exact %[[I32]], %[[I32]] : i32 +// CHECK: {{.*}} = llvm.udiv exact %[[I32]], %[[I32]] : i32 +// CHECK: {{.*}} = llvm.ashr exact %[[I32]], %[[I32]] : i32 +// CHECK: {{.*}} = llvm.lshr exact %[[I32]], %[[I32]] : i32 + %sdiv_flag = llvm.sdiv exact %arg0, %arg0 : i32 + %udiv_flag = llvm.udiv exact %arg0, %arg0 : i32 + %ashr_flag = llvm.ashr exact %arg0, %arg0 : i32 + %lshr_flag = llvm.lshr exact %arg0, %arg0 : i32 + // Floating point binary operations. // // CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32 diff --git a/mlir/test/Target/LLVMIR/Import/exact.ll b/mlir/test/Target/LLVMIR/Import/exact.ll new file mode 100644 index 0000000000000..528fee5091d2d --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/exact.ll @@ -0,0 +1,14 @@ +; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s + +; CHECK-LABEL: @exactflag_inst +define void @exactflag_inst(i64 %arg1, i64 %arg2) { + ; CHECK: llvm.udiv exact %{{.*}}, %{{.*}} : i64 + %1 = udiv exact i64 %arg1, %arg2 + ; CHECK: llvm.sdiv exact %{{.*}}, %{{.*}} : i64 + %2 = sdiv exact i64 %arg1, %arg2 + ; CHECK: llvm.lshr exact %{{.*}}, %{{.*}} : i64 + %3 = lshr exact i64 %arg1, %arg2 + ; CHECK: llvm.ashr exact %{{.*}}, %{{.*}} : i64 + %4 = ashr exact i64 %arg1, %arg2 + ret void +} diff --git a/mlir/test/Target/LLVMIR/exact.mlir b/mlir/test/Target/LLVMIR/exact.mlir new file mode 100644 index 0000000000000..b6c378c2fdcc9 --- /dev/null +++ b/mlir/test/Target/LLVMIR/exact.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: define void @exactflag_func +llvm.func @exactflag_func(%arg0: i64, %arg1: i64) { + // CHECK: %{{.*}} = udiv exact i64 %{{.*}}, %{{.*}} + %0 = llvm.udiv exact %arg0, %arg1 : i64 + // CHECK: %{{.*}} = sdiv exact i64 %{{.*}}, %{{.*}} + %1 = llvm.sdiv exact %arg0, %arg1 : i64 + // CHECK: %{{.*}} = lshr exact i64 %{{.*}}, %{{.*}} + %2 = llvm.lshr exact %arg0, %arg1 : i64 + // CHECK: %{{.*}} = ashr exact i64 %{{.*}}, %{{.*}} + %3 = llvm.ashr exact %arg0, %arg1 : i64 + llvm.return +} From b8dcb5fea15067a6e224a69b128b205b2269ed2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20Frenot?= Date: Fri, 8 Nov 2024 09:20:01 +0000 Subject: [PATCH 2/2] nit fixes --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 10 +++++----- mlir/lib/Target/LLVMIR/ModuleImport.cpp | 4 +--- mlir/test/Dialect/LLVMIR/roundtrip.mlir | 2 +- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index b7ce126dbf54d..315af2594047a 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -76,7 +76,7 @@ class LLVM_IntArithmeticOpWithOverflowFlag traits = []> : LLVM_ArithmeticOpBase], traits)> { @@ -134,8 +134,8 @@ def LLVM_AddOp : LLVM_IntArithmeticOpWithOverflowFlag<"add", "Add", def LLVM_SubOp : LLVM_IntArithmeticOpWithOverflowFlag<"sub", "Sub", []>; def LLVM_MulOp : LLVM_IntArithmeticOpWithOverflowFlag<"mul", "Mul", [Commutative]>; -def LLVM_UDivOp : LLVM_IntArithmeticOpWithIsExact<"udiv", "UDiv">; -def LLVM_SDivOp : LLVM_IntArithmeticOpWithIsExact<"sdiv", "SDiv">; +def LLVM_UDivOp : LLVM_IntArithmeticOpWithExactFlag<"udiv", "UDiv">; +def LLVM_SDivOp : LLVM_IntArithmeticOpWithExactFlag<"sdiv", "SDiv">; def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">; def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">; def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">; @@ -146,8 +146,8 @@ def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">; def LLVM_ShlOp : LLVM_IntArithmeticOpWithOverflowFlag<"shl", "Shl", []> { let hasFolder = 1; } -def LLVM_LShrOp : LLVM_IntArithmeticOpWithIsExact<"lshr", "LShr">; -def LLVM_AShrOp : LLVM_IntArithmeticOpWithIsExact<"ashr", "AShr">; +def LLVM_LShrOp : LLVM_IntArithmeticOpWithExactFlag<"lshr", "LShr">; +def LLVM_AShrOp : LLVM_IntArithmeticOpWithExactFlag<"ashr", "AShr">; // Base class for compare operations. A compare operation takes two operands // of the same type and returns a boolean result. If the operands are diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index ccec2034a298b..70881ed5fd677 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -686,9 +686,7 @@ void ModuleImport::setIntegerOverflowFlags(llvm::Instruction *inst, void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const { auto iface = cast(op); - bool value = inst->isExact(); - - iface.setIsExact(value); + iface.setIsExact(inst->isExact()); } void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst, diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index 9daad2ef5b0b1..682780c5f0a7d 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -49,7 +49,7 @@ func.func @ops(%arg0: i32, %arg1: f32, %mul_flag = llvm.mul %arg0, %arg0 overflow : i32 %shl_flag = llvm.shl %arg0, %arg0 overflow : i32 -// Integer exact +// Integer exact flag. // CHECK: {{.*}} = llvm.sdiv exact %[[I32]], %[[I32]] : i32 // CHECK: {{.*}} = llvm.udiv exact %[[I32]], %[[I32]] : i32 // CHECK: {{.*}} = llvm.ashr exact %[[I32]], %[[I32]] : i32