diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td index f05230526c21f..ec835e05258d8 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td @@ -428,6 +428,29 @@ def DISubprogramFlags : I32BitEnumAttr< let printBitEnumPrimaryGroups = 1; } +//===----------------------------------------------------------------------===// +// IntegerOverflowFlags +//===----------------------------------------------------------------------===// + +def IOFnone : I32BitEnumAttrCaseNone<"none">; +def IOFnsw : I32BitEnumAttrCaseBit<"nsw", 0>; +def IOFnuw : I32BitEnumAttrCaseBit<"nuw", 1>; + +def IntegerOverflowFlags : I32BitEnumAttr< + "IntegerOverflowFlags", + "LLVM integer overflow flags", + [IOFnone, IOFnsw, IOFnuw]> { + let separator = ", "; + let cppNamespace = "::mlir::LLVM"; + let genSpecializedAttr = 0; + let printBitEnumPrimaryGroups = 1; +} + +def LLVM_IntegerOverflowFlagsAttr : + EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + //===----------------------------------------------------------------------===// // FastmathFlags //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td index c5d65f792254e..81589eaf5fd0a 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td @@ -48,6 +48,63 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> { ]; } +def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> { + let description = [{ + Access to op integer overflow flags. + }]; + + let cppNamespace = "::mlir::LLVM"; + + let methods = [ + InterfaceMethod< + /*desc=*/ "Returns an IntegerOverflowFlagsAttr attribute for the operation", + /*returnType=*/ "IntegerOverflowFlagsAttr", + /*methodName=*/ "getOverflowAttr", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + auto op = cast(this->getOperation()); + return op.getOverflowFlagsAttr(); + }] + >, + InterfaceMethod< + /*desc=*/ "Returns whether the operation has the No Unsigned Wrap keyword", + /*returnType=*/ "bool", + /*methodName=*/ "hasNoUnsignedWrap", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + auto op = cast(this->getOperation()); + IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue(); + return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw); + }] + >, + InterfaceMethod< + /*desc=*/ "Returns whether the operation has the No Signed Wrap keyword", + /*returnType=*/ "bool", + /*methodName=*/ "hasNoSignedWrap", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + auto op = cast(this->getOperation()); + IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue(); + return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw); + }] + >, + StaticInterfaceMethod< + /*desc=*/ [{Returns the name of the IntegerOveflowFlagsAttr attribute + for the operation}], + /*returnType=*/ "StringRef", + /*methodName=*/ "getIntegerOverflowAttrName", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + return "overflowFlags"; + }] + > + ]; +} + def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> { let description = [{ An interface for operations that can carry branch weights metadata. It diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 8f166f0cc7cf5..40f9aa1ce33e5 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -55,6 +55,26 @@ class LLVM_IntArithmeticOp($_location, $lhs, $rhs); }]; } +class LLVM_IntArithmeticOpWithOverflowFlag traits = []> : + LLVM_ArithmeticOpBase], traits)> { + dag iofArg = ( + ins DefaultValuedAttr:$overflowFlags); + let arguments = !con(commonArgs, iofArg); + string mlirBuilder = [{ + auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); + moduleImport.setIntegerOverflowFlagsAttr(inst, op); + $res = op; + }]; + let assemblyFormat = [{ + $lhs `,` $rhs (`overflow` `` $overflowFlags^)? + custom(attr-dict) `:` type($res) + }]; + string llvmBuilder = + "$res = builder.Create" # instName # + "($lhs, $rhs, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());"; +} class LLVM_FloatArithmeticOp traits = []> : LLVM_ArithmeticOpBase; -def LLVM_SubOp : LLVM_IntArithmeticOp<"sub", "Sub">; -def LLVM_MulOp : LLVM_IntArithmeticOp<"mul", "Mul", [Commutative]>; +def LLVM_AddOp : LLVM_IntArithmeticOpWithOverflowFlag<"add", "Add", + [Commutative]>; +def LLVM_SubOp : LLVM_IntArithmeticOpWithOverflowFlag<"sub", "Sub", []>; +def LLVM_MulOp : LLVM_IntArithmeticOpWithOverflowFlag<"mul", "Mul", + [Commutative]>; def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "UDiv">; def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">; def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">; @@ -102,7 +124,7 @@ def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> { let hasFolder = 1; } def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">; -def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "Shl"> { +def LLVM_ShlOp : LLVM_IntArithmeticOpWithOverflowFlag<"shl", "Shl", []> { let hasFolder = 1; } def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">; diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index b8e449dc11df1..b49d2f539453e 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -172,6 +172,12 @@ class ModuleImport { /// attributes of LLVMFuncOp `funcOp`. void processFunctionAttributes(llvm::Function *func, LLVMFuncOp funcOp); + /// Sets the integer overflow flags (nsw/nuw) attribute for the imported + /// operation `op` given the original instruction `inst`. Asserts if the + /// operation does not implement the integer overflow flag interface. + void setIntegerOverflowFlagsAttr(llvm::Instruction *inst, + Operation *op) const; + /// Sets the fastmath flags attribute for the imported operation `op` given /// the original instruction `inst`. Asserts if the operation does not /// implement the fastmath interface. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 28445945f07d6..9beb0d6cc3323 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -69,7 +69,13 @@ static ParseResult parseLLVMOpAttrs(OpAsmParser &parser, static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, DictionaryAttr attrs) { - printer.printOptionalAttrDict(processFMFAttr(attrs.getValue())); + auto filteredAttrs = processFMFAttr(attrs.getValue()); + if (auto iface = dyn_cast(op)) + printer.printOptionalAttrDict( + filteredAttrs, + /*elidedAttrs=*/{iface.getIntegerOverflowAttrName()}); + else + printer.printOptionalAttrDict(filteredAttrs); } /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 2d1aaa9229cd2..ad67a7fbc030d 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -619,6 +619,19 @@ void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst, } } +void ModuleImport::setIntegerOverflowFlagsAttr(llvm::Instruction *inst, + Operation *op) const { + auto iface = cast(op); + + IntegerOverflowFlags value = {}; + value = bitEnumSet(value, IntegerOverflowFlags::nsw, inst->hasNoSignedWrap()); + value = + bitEnumSet(value, IntegerOverflowFlags::nuw, inst->hasNoUnsignedWrap()); + + auto attr = IntegerOverflowFlagsAttr::get(op->getContext(), value); + iface->setAttr(iface.getIntegerOverflowAttrName(), attr); +} + void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst, Operation *op) const { auto iface = cast(op); diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index ee724a482cfb5..7962f2514cb19 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -34,6 +34,16 @@ func.func @ops(%arg0: i32, %arg1: f32, %vptrcmp = llvm.icmp "ne" %arg5, %arg5 : !llvm.vec<2 x ptr> %typecheck_vptrcmp = llvm.add %vptrcmp, %vptrcmp : vector<2 x i1> +// Integer overflow flags +// CHECK: {{.*}} = llvm.add %[[I32]], %[[I32]] overflow : i32 +// CHECK: {{.*}} = llvm.sub %[[I32]], %[[I32]] overflow : i32 +// CHECK: {{.*}} = llvm.mul %[[I32]], %[[I32]] overflow : i32 +// CHECK: {{.*}} = llvm.shl %[[I32]], %[[I32]] overflow : i32 + %add_flag = llvm.add %arg0, %arg0 overflow : i32 + %sub_flag = llvm.sub %arg0, %arg0 overflow : i32 + %mul_flag = llvm.mul %arg0, %arg0 overflow : i32 + %shl_flag = llvm.shl %arg0, %arg0 overflow : i32 + // Floating point binary operations. // // CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32 diff --git a/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll new file mode 100644 index 0000000000000..d08098a5e5dfe --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll @@ -0,0 +1,14 @@ +; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s + +; CHECK-LABEL: @intflag_inst +define void @intflag_inst(i64 %arg1, i64 %arg2) { + ; CHECK: llvm.add %{{.*}}, %{{.*}} overflow : i64 + %1 = add nsw i64 %arg1, %arg2 + ; CHECK: llvm.sub %{{.*}}, %{{.*}} overflow : i64 + %2 = sub nuw i64 %arg1, %arg2 + ; CHECK: llvm.mul %{{.*}}, %{{.*}} overflow : i64 + %3 = mul nsw nuw i64 %arg1, %arg2 + ; CHECK: llvm.shl %{{.*}}, %{{.*}} overflow : i64 + %4 = shl nuw nsw i64 %arg1, %arg2 + ret void +} diff --git a/mlir/test/Target/LLVMIR/nsw_nuw.mlir b/mlir/test/Target/LLVMIR/nsw_nuw.mlir new file mode 100644 index 0000000000000..6843c2ef0299c --- /dev/null +++ b/mlir/test/Target/LLVMIR/nsw_nuw.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: define void @intflags_func +llvm.func @intflags_func(%arg0: i64, %arg1: i64) { + // CHECK: %{{.*}} = add nsw i64 %{{.*}}, %{{.*}} + %0 = llvm.add %arg0, %arg1 overflow : i64 + // CHECK: %{{.*}} = sub nuw i64 %{{.*}}, %{{.*}} + %1 = llvm.sub %arg0, %arg1 overflow : i64 + // CHECK: %{{.*}} = mul nuw nsw i64 %{{.*}}, %{{.*}} + %2 = llvm.mul %arg0, %arg1 overflow : i64 + // CHECK: %{{.*}} = shl nuw nsw i64 %{{.*}}, %{{.*}} + %3 = llvm.shl %arg0, %arg1 overflow : i64 + llvm.return +}