Skip to content

Commit 4290e34

Browse files
authored
[flang][AMDGPU] Convert math ops to AMD GPU library calls instead of libm calls (#99517)
This patch invokes a pass when compiling for an AMDGPU target to lower math operations to AMD GPU library calls library calls instead of libm calls.
1 parent f58312e commit 4290e34

File tree

3 files changed

+196
-1
lines changed

3 files changed

+196
-1
lines changed

flang/lib/Optimizer/CodeGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ add_flang_library(FIRCodeGen
2626
MLIRMathToFuncs
2727
MLIRMathToLLVM
2828
MLIRMathToLibm
29+
MLIRMathToROCDL
2930
MLIROpenMPToLLVM
3031
MLIROpenACCDialect
3132
MLIRBuiltinToLLVMIRTranslation

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
3737
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
3838
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
39+
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
3940
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
4041
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
4142
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -3671,6 +3672,14 @@ class FIRToLLVMLowering
36713672
// as passes here.
36723673
mlir::OpPassManager mathConvertionPM("builtin.module");
36733674

3675+
bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN();
3676+
// If compiling for AMD target some math operations must be lowered to AMD
3677+
// GPU library calls, the rest can be converted to LLVM intrinsics, which
3678+
// is handled in the mathToLLVM conversion. The lowering to libm calls is
3679+
// not needed since all math operations are handled this way.
3680+
if (isAMDGCN)
3681+
mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
3682+
36743683
// Convert math::FPowI operations to inline implementation
36753684
// only if the exponent's width is greater than 32, otherwise,
36763685
// it will be lowered to LLVM intrinsic operation by a later conversion.
@@ -3710,7 +3719,8 @@ class FIRToLLVMLowering
37103719
pattern);
37113720
// Math operations that have not been converted yet must be converted
37123721
// to Libm.
3713-
mlir::populateMathToLibmConversionPatterns(pattern);
3722+
if (!isAMDGCN)
3723+
mlir::populateMathToLibmConversionPatterns(pattern);
37143724
mlir::populateComplexToLLVMConversionPatterns(typeConverter, pattern);
37153725
mlir::populateVectorToLLVMConversionPatterns(typeConverter, pattern);
37163726

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
!REQUIRES: amdgpu-registered-target
2+
!RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-llvm -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
3+
4+
subroutine omp_pow_f32(x, y)
5+
!$omp declare target
6+
real :: x, y
7+
!CHECK: call float @__ocml_pow_f32(float {{.*}}, float {{.*}})
8+
y = x ** x
9+
end subroutine omp_pow_f32
10+
11+
subroutine omp_pow_f64(x, y)
12+
!$omp declare target
13+
real(8) :: x, y
14+
!CHECK: call double @__ocml_pow_f64(double {{.*}}, double {{.*}})
15+
y = x ** x
16+
end subroutine omp_pow_f64
17+
18+
subroutine omp_sin_f32(x, y)
19+
!$omp declare target
20+
real :: x, y
21+
!CHECK: call float @__ocml_sin_f32(float {{.*}})
22+
y = sin(x)
23+
end subroutine omp_sin_f32
24+
25+
subroutine omp_sin_f64(x, y)
26+
!$omp declare target
27+
real(8) :: x, y
28+
!CHECK: call double @__ocml_sin_f64(double {{.*}})
29+
y = sin(x)
30+
end subroutine omp_sin_f64
31+
32+
subroutine omp_abs_f32(x, y)
33+
!$omp declare target
34+
real :: x, y
35+
!CHECK: call contract float @llvm.fabs.f32(float {{.*}})
36+
y = abs(x)
37+
end subroutine omp_abs_f32
38+
39+
subroutine omp_abs_f64(x, y)
40+
!$omp declare target
41+
real(8) :: x, y
42+
!CHECK: call contract double @llvm.fabs.f64(double {{.*}})
43+
y = abs(x)
44+
end subroutine omp_abs_f64
45+
46+
subroutine omp_atan_f32(x, y)
47+
!$omp declare target
48+
real :: x, y
49+
!CHECK: call float @__ocml_atan_f32(float {{.*}})
50+
y = atan(x)
51+
end subroutine omp_atan_f32
52+
53+
subroutine omp_atan_f64(x, y)
54+
!$omp declare target
55+
real(8) :: x, y
56+
!CHECK: call double @__ocml_atan_f64(double {{.*}})
57+
y = atan(x)
58+
end subroutine omp_atan_f64
59+
60+
subroutine omp_atan2_f32(x, y)
61+
!$omp declare target
62+
real :: x, y
63+
!CHECK: call float @__ocml_atan2_f32(float {{.*}}, float {{.*}})
64+
y = atan2(x, x)
65+
end subroutine omp_atan2_f32
66+
67+
subroutine omp_atan2_f64(x, y)
68+
!$omp declare target
69+
real(8) :: x, y
70+
!CHECK: call double @__ocml_atan2_f64(double {{.*}}, double {{.*}})
71+
y = atan2(x ,x)
72+
end subroutine omp_atan2_f64
73+
74+
subroutine omp_cos_f32(x, y)
75+
!$omp declare target
76+
real :: x, y
77+
!CHECK: call float @__ocml_cos_f32(float {{.*}})
78+
y = cos(x)
79+
end subroutine omp_cos_f32
80+
81+
subroutine omp_cos_f64(x, y)
82+
!$omp declare target
83+
real(8) :: x, y
84+
!CHECK: call double @__ocml_cos_f64(double {{.*}})
85+
y = cos(x)
86+
end subroutine omp_cos_f64
87+
88+
subroutine omp_erf_f32(x, y)
89+
!$omp declare target
90+
real :: x, y
91+
!CHECK: call float @__ocml_erf_f32(float {{.*}})
92+
y = erf(x)
93+
end subroutine omp_erf_f32
94+
95+
subroutine omp_erf_f64(x, y)
96+
!$omp declare target
97+
real(8) :: x, y
98+
!CHECK: call double @__ocml_erf_f64(double {{.*}})
99+
y = erf(x)
100+
end subroutine omp_erf_f64
101+
102+
subroutine omp_exp_f32(x, y)
103+
!$omp declare target
104+
real :: x, y
105+
!CHECK: call contract float @llvm.exp.f32(float {{.*}})
106+
y = exp(x)
107+
end subroutine omp_exp_f32
108+
109+
subroutine omp_exp_f64(x, y)
110+
!$omp declare target
111+
real(8) :: x, y
112+
!CHECK: call double @__ocml_exp_f64(double {{.*}})
113+
y = exp(x)
114+
end subroutine omp_exp_f64
115+
116+
subroutine omp_log_f32(x, y)
117+
!$omp declare target
118+
real :: x, y
119+
!CHECK: call contract float @llvm.log.f32(float {{.*}})
120+
y = log(x)
121+
end subroutine omp_log_f32
122+
123+
subroutine omp_log_f64(x, y)
124+
!$omp declare target
125+
real(8) :: x, y
126+
!CHECK: call double @__ocml_log_f64(double {{.*}})
127+
y = log(x)
128+
end subroutine omp_log_f64
129+
130+
subroutine omp_log10_f32(x, y)
131+
!$omp declare target
132+
real :: x, y
133+
!CHECK: call float @__ocml_log10_f32(float {{.*}})
134+
y = log10(x)
135+
end subroutine omp_log10_f32
136+
137+
subroutine omp_log10_f64(x, y)
138+
!$omp declare target
139+
real(8) :: x, y
140+
!CHECK: call double @__ocml_log10_f64(double {{.*}})
141+
y = log10(x)
142+
end subroutine omp_log10_f64
143+
144+
subroutine omp_sqrt_f32(x, y)
145+
!$omp declare target
146+
real :: x, y
147+
!CHECK: call contract float @llvm.sqrt.f32(float {{.*}})
148+
y = sqrt(x)
149+
end subroutine omp_sqrt_f32
150+
151+
subroutine omp_sqrt_f64(x, y)
152+
!$omp declare target
153+
real(8) :: x, y
154+
!CHECK: call contract double @llvm.sqrt.f64(double {{.*}})
155+
y = sqrt(x)
156+
end subroutine omp_sqrt_f64
157+
158+
subroutine omp_tan_f32(x, y)
159+
!$omp declare target
160+
real :: x, y
161+
!CHECK: call float @__ocml_tan_f32(float {{.*}})
162+
y = tan(x)
163+
end subroutine omp_tan_f32
164+
165+
subroutine omp_tan_f64(x, y)
166+
!$omp declare target
167+
real(8) :: x, y
168+
!CHECK: call double @__ocml_tan_f64(double {{.*}})
169+
y = tan(x)
170+
end subroutine omp_tan_f64
171+
172+
subroutine omp_tanh_f32(x, y)
173+
!$omp declare target
174+
real :: x, y
175+
!CHECK: call float @__ocml_tanh_f32(float {{.*}})
176+
y = tanh(x)
177+
end subroutine omp_tanh_f32
178+
179+
subroutine omp_tanh_f64(x, y)
180+
!$omp declare target
181+
real(8) :: x, y
182+
!CHECK: call double @__ocml_tanh_f64(double {{.*}})
183+
y = tanh(x)
184+
end subroutine omp_tanh_f64

0 commit comments

Comments
 (0)