Skip to content

Commit 66b76fa

Browse files
authored
AMDGPU: Directly emit sqrt intrinsic when folding rootn(x, 2) (#92598)
This avoids depending on pre/post link runs. Depends on #92595.
1 parent e411c88 commit 66b76fa

File tree

3 files changed

+64
-47
lines changed

3 files changed

+64
-47
lines changed

llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "llvm/IR/IRBuilder.h"
2323
#include "llvm/IR/IntrinsicInst.h"
2424
#include "llvm/IR/IntrinsicsAMDGPU.h"
25+
#include "llvm/IR/MDBuilder.h"
2526
#include "llvm/IR/PatternMatch.h"
2627
#include "llvm/InitializePasses.h"
2728
#include <cmath>
@@ -1175,17 +1176,30 @@ bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
11751176
return true;
11761177
}
11771178

1178-
Module *M = Parent->getParent();
1179-
if (ci_opr1 == 2) { // rootn(x, 2) = sqrt(x)
1180-
if (FunctionCallee FPExpr =
1181-
getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
1182-
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> sqrt(" << *opr0
1183-
<< ")\n");
1184-
Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2sqrt");
1185-
replaceCall(FPOp, nval);
1186-
return true;
1187-
}
1188-
} else if (ci_opr1 == 3) { // rootn(x, 3) = cbrt(x)
1179+
Module *M = B.GetInsertBlock()->getModule();
1180+
1181+
CallInst *CI = cast<CallInst>(FPOp);
1182+
if (ci_opr1 == 2 &&
1183+
shouldReplaceLibcallWithIntrinsic(CI,
1184+
/*AllowMinSizeF32=*/true,
1185+
/*AllowF64=*/true)) {
1186+
// rootn(x, 2) = sqrt(x)
1187+
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> sqrt(" << *opr0 << ")\n");
1188+
1189+
CallInst *NewCall = B.CreateUnaryIntrinsic(Intrinsic::sqrt, opr0, CI);
1190+
NewCall->takeName(CI);
1191+
1192+
// OpenCL rootn has a looser ulp of 2 requirement than sqrt, so add some
1193+
// metadata.
1194+
MDBuilder MDHelper(M->getContext());
1195+
MDNode *FPMD = MDHelper.createFPMath(std::max(FPOp->getFPAccuracy(), 2.0f));
1196+
NewCall->setMetadata(LLVMContext::MD_fpmath, FPMD);
1197+
1198+
replaceCall(CI, NewCall);
1199+
return true;
1200+
}
1201+
1202+
if (ci_opr1 == 3) { // rootn(x, 3) = cbrt(x)
11891203
if (FunctionCallee FPExpr =
11901204
getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_CBRT, FInfo))) {
11911205
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> cbrt(" << *opr0

llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-rootn.ll

Lines changed: 38 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,8 @@ define half @test_rootn_f16_1(half %x) {
272272
define half @test_rootn_f16_2(half %x) {
273273
; CHECK-LABEL: define half @test_rootn_f16_2(
274274
; CHECK-SAME: half [[X:%.*]]) {
275-
; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call half @_Z4sqrtDh(half [[X]])
276-
; CHECK-NEXT: ret half [[__ROOTN2SQRT]]
275+
; CHECK-NEXT: [[CALL:%.*]] = call half @llvm.sqrt.f16(half [[X]]), !fpmath [[META0:![0-9]+]]
276+
; CHECK-NEXT: ret half [[CALL]]
277277
;
278278
%call = tail call half @_Z5rootnDhi(half %x, i32 2)
279279
ret half %call
@@ -351,8 +351,8 @@ define <2 x half> @test_rootn_v2f16_1(<2 x half> %x) {
351351
define <2 x half> @test_rootn_v2f16_2(<2 x half> %x) {
352352
; CHECK-LABEL: define <2 x half> @test_rootn_v2f16_2(
353353
; CHECK-SAME: <2 x half> [[X:%.*]]) {
354-
; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call <2 x half> @_Z4sqrtDv2_Dh(<2 x half> [[X]])
355-
; CHECK-NEXT: ret <2 x half> [[__ROOTN2SQRT]]
354+
; CHECK-NEXT: [[CALL:%.*]] = call <2 x half> @llvm.sqrt.v2f16(<2 x half> [[X]]), !fpmath [[META0]]
355+
; CHECK-NEXT: ret <2 x half> [[CALL]]
356356
;
357357
%call = tail call <2 x half> @_Z5rootnDv2_DhDv2_i(<2 x half> %x, <2 x i32> <i32 2, i32 2>)
358358
ret <2 x half> %call
@@ -612,8 +612,8 @@ define float @test_rootn_f32__y_2(float %x) {
612612
; CHECK-LABEL: define float @test_rootn_f32__y_2(
613613
; CHECK-SAME: float [[X:%.*]]) {
614614
; CHECK-NEXT: entry:
615-
; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call float @_Z4sqrtf(float [[X]])
616-
; CHECK-NEXT: ret float [[__ROOTN2SQRT]]
615+
; CHECK-NEXT: [[CALL:%.*]] = call float @llvm.sqrt.f32(float [[X]]), !fpmath [[META0]]
616+
; CHECK-NEXT: ret float [[CALL]]
617617
;
618618
entry:
619619
%call = tail call float @_Z5rootnfi(float %x, i32 2)
@@ -624,8 +624,8 @@ define float @test_rootn_f32__y_2_flags(float %x) {
624624
; CHECK-LABEL: define float @test_rootn_f32__y_2_flags(
625625
; CHECK-SAME: float [[X:%.*]]) {
626626
; CHECK-NEXT: entry:
627-
; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call nnan nsz float @_Z4sqrtf(float [[X]])
628-
; CHECK-NEXT: ret float [[__ROOTN2SQRT]]
627+
; CHECK-NEXT: [[CALL:%.*]] = call nnan nsz float @llvm.sqrt.f32(float [[X]]), !fpmath [[META0]]
628+
; CHECK-NEXT: ret float [[CALL]]
629629
;
630630
entry:
631631
%call = tail call nnan nsz float @_Z5rootnfi(float %x, i32 2)
@@ -637,8 +637,8 @@ define float @test_rootn_f32__y_2_fpmath_3(float %x) {
637637
; CHECK-LABEL: define float @test_rootn_f32__y_2_fpmath_3(
638638
; CHECK-SAME: float [[X:%.*]]) {
639639
; CHECK-NEXT: entry:
640-
; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call nnan nsz float @_Z4sqrtf(float [[X]])
641-
; CHECK-NEXT: ret float [[__ROOTN2SQRT]]
640+
; CHECK-NEXT: [[CALL:%.*]] = call nnan nsz float @llvm.sqrt.f32(float [[X]]), !fpmath [[META1:![0-9]+]]
641+
; CHECK-NEXT: ret float [[CALL]]
642642
;
643643
entry:
644644
%call = tail call nnan nsz float @_Z5rootnfi(float %x, i32 2), !fpmath !0
@@ -649,8 +649,8 @@ define <2 x float> @test_rootn_v2f32__y_2_flags(<2 x float> %x) {
649649
; CHECK-LABEL: define <2 x float> @test_rootn_v2f32__y_2_flags(
650650
; CHECK-SAME: <2 x float> [[X:%.*]]) {
651651
; CHECK-NEXT: entry:
652-
; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call nnan nsz <2 x float> @_Z4sqrtDv2_f(<2 x float> [[X]])
653-
; CHECK-NEXT: ret <2 x float> [[__ROOTN2SQRT]]
652+
; CHECK-NEXT: [[CALL:%.*]] = call nnan nsz <2 x float> @llvm.sqrt.v2f32(<2 x float> [[X]]), !fpmath [[META0]]
653+
; CHECK-NEXT: ret <2 x float> [[CALL]]
654654
;
655655
entry:
656656
%call = tail call nnan nsz <2 x float> @_Z5rootnDv2_fDv2_i(<2 x float> %x, <2 x i32> <i32 2, i32 2>)
@@ -661,8 +661,8 @@ define <3 x float> @test_rootn_v3f32__y_2(<3 x float> %x) {
661661
; CHECK-LABEL: define <3 x float> @test_rootn_v3f32__y_2(
662662
; CHECK-SAME: <3 x float> [[X:%.*]]) {
663663
; CHECK-NEXT: entry:
664-
; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call <3 x float> @_Z4sqrtDv3_f(<3 x float> [[X]])
665-
; CHECK-NEXT: ret <3 x float> [[__ROOTN2SQRT]]
664+
; CHECK-NEXT: [[CALL:%.*]] = call <3 x float> @llvm.sqrt.v3f32(<3 x float> [[X]]), !fpmath [[META0]]
665+
; CHECK-NEXT: ret <3 x float> [[CALL]]
666666
;
667667
entry:
668668
%call = tail call <3 x float> @_Z5rootnDv3_fDv3_i(<3 x float> %x, <3 x i32> <i32 2, i32 2, i32 2>)
@@ -673,8 +673,8 @@ define <3 x float> @test_rootn_v3f32__y_2_undef(<3 x float> %x) {
673673
; CHECK-LABEL: define <3 x float> @test_rootn_v3f32__y_2_undef(
674674
; CHECK-SAME: <3 x float> [[X:%.*]]) {
675675
; CHECK-NEXT: entry:
676-
; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call <3 x float> @_Z4sqrtDv3_f(<3 x float> [[X]])
677-
; CHECK-NEXT: ret <3 x float> [[__ROOTN2SQRT]]
676+
; CHECK-NEXT: [[CALL:%.*]] = call <3 x float> @llvm.sqrt.v3f32(<3 x float> [[X]]), !fpmath [[META0]]
677+
; CHECK-NEXT: ret <3 x float> [[CALL]]
678678
;
679679
entry:
680680
%call = tail call <3 x float> @_Z5rootnDv3_fDv3_i(<3 x float> %x, <3 x i32> <i32 2, i32 poison, i32 2>)
@@ -685,8 +685,8 @@ define <4 x float> @test_rootn_v4f32__y_2(<4 x float> %x) {
685685
; CHECK-LABEL: define <4 x float> @test_rootn_v4f32__y_2(
686686
; CHECK-SAME: <4 x float> [[X:%.*]]) {
687687
; CHECK-NEXT: entry:
688-
; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call <4 x float> @_Z4sqrtDv4_f(<4 x float> [[X]])
689-
; CHECK-NEXT: ret <4 x float> [[__ROOTN2SQRT]]
688+
; CHECK-NEXT: [[CALL:%.*]] = call <4 x float> @llvm.sqrt.v4f32(<4 x float> [[X]]), !fpmath [[META0]]
689+
; CHECK-NEXT: ret <4 x float> [[CALL]]
690690
;
691691
entry:
692692
%call = tail call <4 x float> @_Z5rootnDv4_fDv4_i(<4 x float> %x, <4 x i32> <i32 2, i32 2, i32 2, i32 2>)
@@ -697,8 +697,8 @@ define <8 x float> @test_rootn_v8f32__y_2(<8 x float> %x) {
697697
; CHECK-LABEL: define <8 x float> @test_rootn_v8f32__y_2(
698698
; CHECK-SAME: <8 x float> [[X:%.*]]) {
699699
; CHECK-NEXT: entry:
700-
; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call <8 x float> @_Z4sqrtDv8_f(<8 x float> [[X]])
701-
; CHECK-NEXT: ret <8 x float> [[__ROOTN2SQRT]]
700+
; CHECK-NEXT: [[CALL:%.*]] = call <8 x float> @llvm.sqrt.v8f32(<8 x float> [[X]]), !fpmath [[META0]]
701+
; CHECK-NEXT: ret <8 x float> [[CALL]]
702702
;
703703
entry:
704704
%call = tail call <8 x float> @_Z5rootnDv8_fDv8_i(<8 x float> %x, <8 x i32> <i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2>)
@@ -709,8 +709,8 @@ define <16 x float> @test_rootn_v16f32__y_2(<16 x float> %x) {
709709
; CHECK-LABEL: define <16 x float> @test_rootn_v16f32__y_2(
710710
; CHECK-SAME: <16 x float> [[X:%.*]]) {
711711
; CHECK-NEXT: entry:
712-
; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call <16 x float> @_Z4sqrtDv16_f(<16 x float> [[X]])
713-
; CHECK-NEXT: ret <16 x float> [[__ROOTN2SQRT]]
712+
; CHECK-NEXT: [[CALL:%.*]] = call <16 x float> @llvm.sqrt.v16f32(<16 x float> [[X]]), !fpmath [[META0]]
713+
; CHECK-NEXT: ret <16 x float> [[CALL]]
714714
;
715715
entry:
716716
%call = tail call <16 x float> @_Z5rootnDv16_fDv16_i(<16 x float> %x, <16 x i32> <i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2>)
@@ -757,8 +757,8 @@ define <2 x float> @test_rootn_v2f32__y_nonsplat_2_poison(<2 x float> %x) {
757757
; CHECK-LABEL: define <2 x float> @test_rootn_v2f32__y_nonsplat_2_poison(
758758
; CHECK-SAME: <2 x float> [[X:%.*]]) {
759759
; CHECK-NEXT: entry:
760-
; CHECK-NEXT: [[__ROOTN2SQRT:%.*]] = call <2 x float> @_Z4sqrtDv2_f(<2 x float> [[X]])
761-
; CHECK-NEXT: ret <2 x float> [[__ROOTN2SQRT]]
760+
; CHECK-NEXT: [[CALL:%.*]] = call <2 x float> @llvm.sqrt.v2f32(<2 x float> [[X]]), !fpmath [[META0]]
761+
; CHECK-NEXT: ret <2 x float> [[CALL]]
762762
;
763763
entry:
764764
%call = tail call <2 x float> @_Z5rootnDv2_fDv2_i(<2 x float> %x, <2 x i32> <i32 2, i32 poison>)
@@ -913,7 +913,7 @@ define float @test_rootn_f32__y_neg2__nobuiltin(float %x) {
913913
; CHECK-LABEL: define float @test_rootn_f32__y_neg2__nobuiltin(
914914
; CHECK-SAME: float [[X:%.*]]) {
915915
; CHECK-NEXT: entry:
916-
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR2:[0-9]+]]
916+
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR3:[0-9]+]]
917917
; CHECK-NEXT: ret float [[CALL]]
918918
;
919919
entry:
@@ -1125,7 +1125,7 @@ define float @test_rootn_fast_f32_nobuiltin(float %x, i32 %y) {
11251125
; CHECK-LABEL: define float @test_rootn_fast_f32_nobuiltin(
11261126
; CHECK-SAME: float [[X:%.*]], i32 [[Y:%.*]]) {
11271127
; CHECK-NEXT: entry:
1128-
; CHECK-NEXT: [[CALL:%.*]] = tail call fast float @_Z5rootnfi(float [[X]], i32 [[Y]]) #[[ATTR2]]
1128+
; CHECK-NEXT: [[CALL:%.*]] = tail call fast float @_Z5rootnfi(float [[X]], i32 [[Y]]) #[[ATTR3]]
11291129
; CHECK-NEXT: ret float [[CALL]]
11301130
;
11311131
entry:
@@ -1420,7 +1420,7 @@ entry:
14201420
define float @test_rootn_f32__y_0_nobuiltin(float %x) {
14211421
; CHECK-LABEL: define float @test_rootn_f32__y_0_nobuiltin(
14221422
; CHECK-SAME: float [[X:%.*]]) {
1423-
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 0) #[[ATTR2]]
1423+
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 0) #[[ATTR3]]
14241424
; CHECK-NEXT: ret float [[CALL]]
14251425
;
14261426
%call = tail call float @_Z5rootnfi(float %x, i32 0) #0
@@ -1430,7 +1430,7 @@ define float @test_rootn_f32__y_0_nobuiltin(float %x) {
14301430
define float @test_rootn_f32__y_1_nobuiltin(float %x) {
14311431
; CHECK-LABEL: define float @test_rootn_f32__y_1_nobuiltin(
14321432
; CHECK-SAME: float [[X:%.*]]) {
1433-
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 1) #[[ATTR2]]
1433+
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 1) #[[ATTR3]]
14341434
; CHECK-NEXT: ret float [[CALL]]
14351435
;
14361436
%call = tail call float @_Z5rootnfi(float %x, i32 1) #0
@@ -1440,7 +1440,7 @@ define float @test_rootn_f32__y_1_nobuiltin(float %x) {
14401440
define float @test_rootn_f32__y_2_nobuiltin(float %x) {
14411441
; CHECK-LABEL: define float @test_rootn_f32__y_2_nobuiltin(
14421442
; CHECK-SAME: float [[X:%.*]]) {
1443-
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 2) #[[ATTR2]]
1443+
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 2) #[[ATTR3]]
14441444
; CHECK-NEXT: ret float [[CALL]]
14451445
;
14461446
%call = tail call float @_Z5rootnfi(float %x, i32 2) #0
@@ -1450,7 +1450,7 @@ define float @test_rootn_f32__y_2_nobuiltin(float %x) {
14501450
define float @test_rootn_f32__y_3_nobuiltin(float %x) {
14511451
; CHECK-LABEL: define float @test_rootn_f32__y_3_nobuiltin(
14521452
; CHECK-SAME: float [[X:%.*]]) {
1453-
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 3) #[[ATTR2]]
1453+
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 3) #[[ATTR3]]
14541454
; CHECK-NEXT: ret float [[CALL]]
14551455
;
14561456
%call = tail call float @_Z5rootnfi(float %x, i32 3) #0
@@ -1460,7 +1460,7 @@ define float @test_rootn_f32__y_3_nobuiltin(float %x) {
14601460
define float @test_rootn_f32__y_neg1_nobuiltin(float %x) {
14611461
; CHECK-LABEL: define float @test_rootn_f32__y_neg1_nobuiltin(
14621462
; CHECK-SAME: float [[X:%.*]]) {
1463-
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -1) #[[ATTR2]]
1463+
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -1) #[[ATTR3]]
14641464
; CHECK-NEXT: ret float [[CALL]]
14651465
;
14661466
%call = tail call float @_Z5rootnfi(float %x, i32 -1) #0
@@ -1470,7 +1470,7 @@ define float @test_rootn_f32__y_neg1_nobuiltin(float %x) {
14701470
define float @test_rootn_f32__y_neg2_nobuiltin(float %x) {
14711471
; CHECK-LABEL: define float @test_rootn_f32__y_neg2_nobuiltin(
14721472
; CHECK-SAME: float [[X:%.*]]) {
1473-
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR2]]
1473+
; CHECK-NEXT: [[CALL:%.*]] = tail call float @_Z5rootnfi(float [[X]], i32 -2) #[[ATTR3]]
14741474
; CHECK-NEXT: ret float [[CALL]]
14751475
;
14761476
%call = tail call float @_Z5rootnfi(float %x, i32 -2) #0
@@ -1485,6 +1485,10 @@ attributes #2 = { noinline }
14851485
!0 = !{float 3.0}
14861486
;.
14871487
; CHECK: attributes #[[ATTR0]] = { strictfp }
1488-
; CHECK: attributes #[[ATTR1:[0-9]+]] = { nounwind memory(read) }
1489-
; CHECK: attributes #[[ATTR2]] = { nobuiltin }
1488+
; CHECK: attributes #[[ATTR1:[0-9]+]] = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
1489+
; CHECK: attributes #[[ATTR2:[0-9]+]] = { nounwind memory(read) }
1490+
; CHECK: attributes #[[ATTR3]] = { nobuiltin }
1491+
;.
1492+
; CHECK: [[META0]] = !{float 2.000000e+00}
1493+
; CHECK: [[META1]] = !{float 3.000000e+00}
14901494
;.

llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -475,8 +475,7 @@ entry:
475475
declare float @_Z5rootnfi(float, i32)
476476

477477
; GCN-LABEL: {{^}}define amdgpu_kernel void @test_rootn_2
478-
; GCN-POSTLINK: call fast float @_Z5rootnfi(float %tmp, i32 2)
479-
; GCN-PRELINK: %__rootn2sqrt = tail call fast float @llvm.sqrt.f32(float %tmp)
478+
; GCN: call fast float @llvm.sqrt.f32(float %tmp)
480479
define amdgpu_kernel void @test_rootn_2(ptr addrspace(1) nocapture %a) {
481480
entry:
482481
%tmp = load float, ptr addrspace(1) %a, align 4

0 commit comments

Comments
 (0)