Skip to content

[VP] Correct lowering of predicated fma and faddmul to avoid strictfp. #85272

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 17, 2024

Conversation

kpneal
Copy link
Member

@kpneal kpneal commented Mar 14, 2024

Correct missing cases in a switch that result in @llvm.vp.fma.v4f32 getting lowered to a constrained fma intrinsic. Vector predicated lowering to contrained intrinsics is not supported currently, and there's no consensus on the path forward. We certainly shouldn't be introducing constrained intrinsics into a function that isn't strictfp.

Problem found with D146845.

Correct missing cases in a switch that result in @llvm.vp.fma.v4f32
getting lowered to a constrained fma intrinsic. Vector predicated
lowering to contrained intrinsics is not supported currently, and
there's no consensus on the path forward. We certainly shouldn't be
introducing constrained intrinsics into a function that isn't strictfp.

Problem found with D146845.
@llvmbot
Copy link
Member

llvmbot commented Mar 14, 2024

@llvm/pr-subscribers-llvm-ir

Author: Kevin P. Neal (kpneal)

Changes

Correct missing cases in a switch that result in @llvm.vp.fma.v4f32 getting lowered to a constrained fma intrinsic. Vector predicated lowering to contrained intrinsics is not supported currently, and there's no consensus on the path forward. We certainly shouldn't be introducing constrained intrinsics into a function that isn't strictfp.

Problem found with D146845.


Full diff: https://github.com/llvm/llvm-project/pull/85272.diff

4 Files Affected:

  • (modified) llvm/include/llvm/IR/Intrinsics.h (+4)
  • (modified) llvm/lib/CodeGen/ExpandVectorPredication.cpp (+10-2)
  • (modified) llvm/lib/IR/Function.cpp (+13-9)
  • (added) llvm/test/CodeGen/Generic/expand-vp-fp-intrinsics.ll (+176)
diff --git a/llvm/include/llvm/IR/Intrinsics.h b/llvm/include/llvm/IR/Intrinsics.h
index 0dfe9f029f9b1a..92eae344ce729e 100644
--- a/llvm/include/llvm/IR/Intrinsics.h
+++ b/llvm/include/llvm/IR/Intrinsics.h
@@ -105,6 +105,10 @@ namespace Intrinsic {
   /// Map a MS builtin name to an intrinsic ID.
   ID getIntrinsicForMSBuiltin(const char *Prefix, StringRef BuiltinName);
 
+  /// Returns true if the intrinsic ID is for one of the "Constrained
+  /// Floating-Point Intrinsics".
+  bool isConstrainedFPIntrinsic(ID QID);
+
   /// This is a type descriptor which explains the type requirements of an
   /// intrinsic. This is returned by getIntrinsicInfoTableEntries.
   struct IITDescriptor {
diff --git a/llvm/lib/CodeGen/ExpandVectorPredication.cpp b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
index 0fe4cfefdb1600..8e623c85b737b0 100644
--- a/llvm/lib/CodeGen/ExpandVectorPredication.cpp
+++ b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
@@ -340,6 +340,8 @@ Value *CachingVPExpander::expandPredicationToFPCall(
     replaceOperation(*NewOp, VPI);
     return NewOp;
   }
+  case Intrinsic::fma:
+  case Intrinsic::fmuladd:
   case Intrinsic::experimental_constrained_fma:
   case Intrinsic::experimental_constrained_fmuladd: {
     Value *Op0 = VPI.getOperand(0);
@@ -347,8 +349,12 @@ Value *CachingVPExpander::expandPredicationToFPCall(
     Value *Op2 = VPI.getOperand(2);
     Function *Fn = Intrinsic::getDeclaration(
         VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});
-    Value *NewOp =
-        Builder.CreateConstrainedFPCall(Fn, {Op0, Op1, Op2}, VPI.getName());
+    Value *NewOp;
+    if (Intrinsic::isConstrainedFPIntrinsic(UnpredicatedIntrinsicID))
+      NewOp =
+          Builder.CreateConstrainedFPCall(Fn, {Op0, Op1, Op2}, VPI.getName());
+    else
+      NewOp = Builder.CreateCall(Fn, {Op0, Op1, Op2}, VPI.getName());
     replaceOperation(*NewOp, VPI);
     return NewOp;
   }
@@ -731,6 +737,8 @@ Value *CachingVPExpander::expandPredication(VPIntrinsic &VPI) {
   case Intrinsic::vp_minnum:
   case Intrinsic::vp_maximum:
   case Intrinsic::vp_minimum:
+  case Intrinsic::vp_fma:
+  case Intrinsic::vp_fmuladd:
     return expandPredicationToFPCall(Builder, VPI,
                                      VPI.getFunctionalIntrinsicID().value());
   case Intrinsic::vp_load:
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index d22e1c12311189..539d7ab5384ea2 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -485,15 +485,7 @@ static MutableArrayRef<Argument> makeArgArray(Argument *Args, size_t Count) {
 }
 
 bool Function::isConstrainedFPIntrinsic() const {
-  switch (getIntrinsicID()) {
-#define INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC)                         \
-  case Intrinsic::INTRINSIC:
-#include "llvm/IR/ConstrainedOps.def"
-    return true;
-#undef INSTRUCTION
-  default:
-    return false;
-  }
+  return Intrinsic::isConstrainedFPIntrinsic(getIntrinsicID());
 }
 
 void Function::clearArguments() {
@@ -1468,6 +1460,18 @@ Function *Intrinsic::getDeclaration(Module *M, ID id, ArrayRef<Type*> Tys) {
 #include "llvm/IR/IntrinsicImpl.inc"
 #undef GET_LLVM_INTRINSIC_FOR_MS_BUILTIN
 
+bool Intrinsic::isConstrainedFPIntrinsic(ID QID) {
+  switch (QID) {
+#define INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC)                         \
+  case Intrinsic::INTRINSIC:
+#include "llvm/IR/ConstrainedOps.def"
+    return true;
+#undef INSTRUCTION
+  default:
+    return false;
+  }
+}
+
 using DeferredIntrinsicMatchPair =
     std::pair<Type *, ArrayRef<Intrinsic::IITDescriptor>>;
 
diff --git a/llvm/test/CodeGen/Generic/expand-vp-fp-intrinsics.ll b/llvm/test/CodeGen/Generic/expand-vp-fp-intrinsics.ll
new file mode 100644
index 00000000000000..bc89ddea6b85aa
--- /dev/null
+++ b/llvm/test/CodeGen/Generic/expand-vp-fp-intrinsics.ll
@@ -0,0 +1,176 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -expandvp -S < %s | FileCheck %s
+
+define void @vp_fadd_v4f32(<4 x float> %a0, <4 x float> %a1, ptr %out, i32 %vp) nounwind {
+; CHECK-LABEL: define void @vp_fadd_v4f32(
+; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], ptr [[OUT:%.*]], i32 [[VP:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:    [[RES1:%.*]] = fadd <4 x float> [[A0]], [[A1]]
+; CHECK-NEXT:    store <4 x float> [[RES1]], ptr [[OUT]], align 16
+; CHECK-NEXT:    ret void
+;
+  %res = call <4 x float> @llvm.vp.fadd.v4f32(<4 x float> %a0, <4 x float> %a1, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 %vp)
+  store <4 x float> %res, ptr %out
+  ret void
+}
+declare <4 x float> @llvm.vp.fadd.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32)
+
+define void @vp_fsub_v4f32(<4 x float> %a0, <4 x float> %a1, ptr %out, i32 %vp) nounwind {
+; CHECK-LABEL: define void @vp_fsub_v4f32(
+; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], ptr [[OUT:%.*]], i32 [[VP:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:    [[RES1:%.*]] = fsub <4 x float> [[A0]], [[A1]]
+; CHECK-NEXT:    store <4 x float> [[RES1]], ptr [[OUT]], align 16
+; CHECK-NEXT:    ret void
+;
+  %res = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %a0, <4 x float> %a1, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 %vp)
+  store <4 x float> %res, ptr %out
+  ret void
+}
+declare <4 x float> @llvm.vp.fsub.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32)
+
+define void @vp_fmul_v4f32(<4 x float> %a0, <4 x float> %a1, ptr %out, i32 %vp) nounwind {
+; CHECK-LABEL: define void @vp_fmul_v4f32(
+; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], ptr [[OUT:%.*]], i32 [[VP:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:    [[RES1:%.*]] = fmul <4 x float> [[A0]], [[A1]]
+; CHECK-NEXT:    store <4 x float> [[RES1]], ptr [[OUT]], align 16
+; CHECK-NEXT:    ret void
+;
+  %res = call <4 x float> @llvm.vp.fmul.v4f32(<4 x float> %a0, <4 x float> %a1, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 %vp)
+  store <4 x float> %res, ptr %out
+  ret void
+}
+declare <4 x float> @llvm.vp.fmul.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32)
+
+define void @vp_fdiv_v4f32(<4 x float> %a0, <4 x float> %a1, ptr %out, i32 %vp) nounwind {
+; CHECK-LABEL: define void @vp_fdiv_v4f32(
+; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], ptr [[OUT:%.*]], i32 [[VP:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:    [[RES1:%.*]] = fdiv <4 x float> [[A0]], [[A1]]
+; CHECK-NEXT:    store <4 x float> [[RES1]], ptr [[OUT]], align 16
+; CHECK-NEXT:    ret void
+;
+  %res = call <4 x float> @llvm.vp.fdiv.v4f32(<4 x float> %a0, <4 x float> %a1, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 %vp)
+  store <4 x float> %res, ptr %out
+  ret void
+}
+declare <4 x float> @llvm.vp.fdiv.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32)
+
+define void @vp_frem_v4f32(<4 x float> %a0, <4 x float> %a1, ptr %out, i32 %vp) nounwind {
+; CHECK-LABEL: define void @vp_frem_v4f32(
+; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], ptr [[OUT:%.*]], i32 [[VP:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:    [[RES1:%.*]] = frem <4 x float> [[A0]], [[A1]]
+; CHECK-NEXT:    store <4 x float> [[RES1]], ptr [[OUT]], align 16
+; CHECK-NEXT:    ret void
+;
+  %res = call <4 x float> @llvm.vp.frem.v4f32(<4 x float> %a0, <4 x float> %a1, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 %vp)
+  store <4 x float> %res, ptr %out
+  ret void
+}
+declare <4 x float> @llvm.vp.frem.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32)
+
+define void @vp_fabs_v4f32(<4 x float> %a0, <4 x float> %a1, ptr %out, i32 %vp) nounwind {
+; CHECK-LABEL: define void @vp_fabs_v4f32(
+; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], ptr [[OUT:%.*]], i32 [[VP:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:    [[RES1:%.*]] = call <4 x float> @llvm.fabs.v4f32(<4 x float> [[A0]])
+; CHECK-NEXT:    store <4 x float> [[RES1]], ptr [[OUT]], align 16
+; CHECK-NEXT:    ret void
+;
+  %res = call <4 x float> @llvm.vp.fabs.v4f32(<4 x float> %a0, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 %vp)
+  store <4 x float> %res, ptr %out
+  ret void
+}
+declare <4 x float> @llvm.vp.fabs.v4f32(<4 x float>, <4 x i1>, i32)
+
+define void @vp_sqrt_v4f32(<4 x float> %a0, <4 x float> %a1, ptr %out, i32 %vp) nounwind {
+; CHECK-LABEL: define void @vp_sqrt_v4f32(
+; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], ptr [[OUT:%.*]], i32 [[VP:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:    [[RES1:%.*]] = call <4 x float> @llvm.sqrt.v4f32(<4 x float> [[A0]])
+; CHECK-NEXT:    store <4 x float> [[RES1]], ptr [[OUT]], align 16
+; CHECK-NEXT:    ret void
+;
+  %res = call <4 x float> @llvm.vp.sqrt.v4f32(<4 x float> %a0, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 %vp)
+  store <4 x float> %res, ptr %out
+  ret void
+}
+declare <4 x float> @llvm.vp.sqrt.v4f32(<4 x float>, <4 x i1>, i32)
+
+define void @vp_fneg_v4f32(<4 x float> %a0, <4 x float> %a1, ptr %out, i32 %vp) nounwind {
+; CHECK-LABEL: define void @vp_fneg_v4f32(
+; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], ptr [[OUT:%.*]], i32 [[VP:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:    [[RES1:%.*]] = fneg <4 x float> [[A0]]
+; CHECK-NEXT:    store <4 x float> [[RES1]], ptr [[OUT]], align 16
+; CHECK-NEXT:    ret void
+;
+  %res = call <4 x float> @llvm.vp.fneg.v4f32(<4 x float> %a0, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 %vp)
+  store <4 x float> %res, ptr %out
+  ret void
+}
+declare <4 x float> @llvm.vp.fneg.v4f32(<4 x float>, <4 x i1>, i32)
+
+define void @vp_fma_v4f32(<4 x float> %a0, <4 x float> %a1, ptr %out, i4 %a5) nounwind {
+; CHECK-LABEL: define void @vp_fma_v4f32(
+; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], ptr [[OUT:%.*]], i4 [[A5:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:    [[RES1:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> [[A0]], <4 x float> [[A1]], <4 x float> [[A1]])
+; CHECK-NEXT:    store <4 x float> [[RES1]], ptr [[OUT]], align 16
+; CHECK-NEXT:    ret void
+;
+  %res = call <4 x float> @llvm.vp.fma.v4f32(<4 x float> %a0, <4 x float> %a1, <4 x float> %a1, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 4)
+  store <4 x float> %res, ptr %out
+  ret void
+}
+declare <4 x float> @llvm.vp.fma.v4f32(<4 x float>, <4 x float>, <4 x float>, <4 x i1>, i32)
+
+define void @vp_fmuladd_v4f32(<4 x float> %a0, <4 x float> %a1, ptr %out, i4 %a5) nounwind {
+; CHECK-LABEL: define void @vp_fmuladd_v4f32(
+; CHECK-SAME: <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], ptr [[OUT:%.*]], i4 [[A5:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:    [[RES1:%.*]] = call <4 x float> @llvm.fmuladd.v4f32(<4 x float> [[A0]], <4 x float> [[A1]], <4 x float> [[A1]])
+; CHECK-NEXT:    store <4 x float> [[RES1]], ptr [[OUT]], align 16
+; CHECK-NEXT:    ret void
+;
+  %res = call <4 x float> @llvm.vp.fmuladd.v4f32(<4 x float> %a0, <4 x float> %a1, <4 x float> %a1, <4 x i1> <i1 -1, i1 -1, i1 -1, i1 -1>, i32 4)
+  store <4 x float> %res, ptr %out
+  ret void
+}
+declare <4 x float> @llvm.vp.fmuladd.v4f32(<4 x float>, <4 x float>, <4 x float>, <4 x i1>, i32)
+
+declare <4 x float> @llvm.vp.maxnum.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32)
+define <4 x float> @vfmax_vv_v4f32(<4 x float> %va, <4 x float> %vb, <4 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: define <4 x float> @vfmax_vv_v4f32(
+; CHECK-SAME: <4 x float> [[VA:%.*]], <4 x float> [[VB:%.*]], <4 x i1> [[M:%.*]], i32 zeroext [[EVL:%.*]]) {
+; CHECK-NEXT:    [[V1:%.*]] = call <4 x float> @llvm.maxnum.v4f32(<4 x float> [[VA]], <4 x float> [[VB]])
+; CHECK-NEXT:    ret <4 x float> [[V1]]
+;
+  %v = call <4 x float> @llvm.vp.maxnum.v4f32(<4 x float> %va, <4 x float> %vb, <4 x i1> %m, i32 %evl)
+  ret <4 x float> %v
+}
+
+declare <8 x float> @llvm.vp.maxnum.v8f32(<8 x float>, <8 x float>, <8 x i1>, i32)
+define <8 x float> @vfmax_vv_v8f32(<8 x float> %va, <8 x float> %vb, <8 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: define <8 x float> @vfmax_vv_v8f32(
+; CHECK-SAME: <8 x float> [[VA:%.*]], <8 x float> [[VB:%.*]], <8 x i1> [[M:%.*]], i32 zeroext [[EVL:%.*]]) {
+; CHECK-NEXT:    [[V1:%.*]] = call <8 x float> @llvm.maxnum.v8f32(<8 x float> [[VA]], <8 x float> [[VB]])
+; CHECK-NEXT:    ret <8 x float> [[V1]]
+;
+  %v = call <8 x float> @llvm.vp.maxnum.v8f32(<8 x float> %va, <8 x float> %vb, <8 x i1> %m, i32 %evl)
+  ret <8 x float> %v
+}
+
+declare <4 x float> @llvm.vp.minnum.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32)
+define <4 x float> @vfmin_vv_v4f32(<4 x float> %va, <4 x float> %vb, <4 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: define <4 x float> @vfmin_vv_v4f32(
+; CHECK-SAME: <4 x float> [[VA:%.*]], <4 x float> [[VB:%.*]], <4 x i1> [[M:%.*]], i32 zeroext [[EVL:%.*]]) {
+; CHECK-NEXT:    [[V1:%.*]] = call <4 x float> @llvm.minnum.v4f32(<4 x float> [[VA]], <4 x float> [[VB]])
+; CHECK-NEXT:    ret <4 x float> [[V1]]
+;
+  %v = call <4 x float> @llvm.vp.minnum.v4f32(<4 x float> %va, <4 x float> %vb, <4 x i1> %m, i32 %evl)
+  ret <4 x float> %v
+}
+
+declare <8 x float> @llvm.vp.minnum.v8f32(<8 x float>, <8 x float>, <8 x i1>, i32)
+define <8 x float> @vfmin_vv_v8f32(<8 x float> %va, <8 x float> %vb, <8 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: define <8 x float> @vfmin_vv_v8f32(
+; CHECK-SAME: <8 x float> [[VA:%.*]], <8 x float> [[VB:%.*]], <8 x i1> [[M:%.*]], i32 zeroext [[EVL:%.*]]) {
+; CHECK-NEXT:    [[V1:%.*]] = call <8 x float> @llvm.minnum.v8f32(<8 x float> [[VA]], <8 x float> [[VB]])
+; CHECK-NEXT:    ret <8 x float> [[V1]]
+;
+  %v = call <8 x float> @llvm.vp.minnum.v8f32(<8 x float> %va, <8 x float> %vb, <8 x i1> %m, i32 %evl)
+  ret <8 x float> %v
+}

@kpneal
Copy link
Member Author

kpneal commented Mar 26, 2024

Ping

1 similar comment
@kpneal
Copy link
Member Author

kpneal commented Apr 12, 2024

Ping

Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@kpneal kpneal merged commit 79726ef into llvm:main Apr 17, 2024
@kpneal kpneal deleted the vp branch April 17, 2024 12:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants