From edf7291756d1616fd85067fce5afdf6d726b8024 Mon Sep 17 00:00:00 2001 From: Jan Leyonberg Date: Thu, 18 Jul 2024 11:06:51 -0400 Subject: [PATCH 1/2] [flang][AMDGPU] Convert math ops to AMD GPU library calls instead of libm calls This patch invokes a pass when compiling for an AMDGPU target to lower math operations to AMD GPU library calls instead of libm calls. --- flang/lib/Optimizer/CodeGen/CMakeLists.txt | 1 + flang/lib/Optimizer/CodeGen/CodeGen.cpp | 12 +- flang/test/Lower/OpenMP/math-amdgpu.f90 | 184 +++++++++++++++++++++ 3 files changed, 196 insertions(+), 1 deletion(-) create mode 100644 flang/test/Lower/OpenMP/math-amdgpu.f90 diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt index 650448eee1099..646621cb01c15 100644 --- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt +++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt @@ -26,6 +26,7 @@ add_flang_library(FIRCodeGen MLIRMathToFuncs MLIRMathToLLVM MLIRMathToLibm + MLIRMathToROCDL MLIROpenMPToLLVM MLIROpenACCDialect MLIRBuiltinToLLVMIRTranslation diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index ac521ae95df39..88293bcf36a78 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -36,6 +36,7 @@ #include "mlir/Conversion/MathToFuncs/MathToFuncs.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MathToLibm/MathToLibm.h" +#include "mlir/Conversion/MathToROCDL/MathToROCDL.h" #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -3671,6 +3672,14 @@ class FIRToLLVMLowering // as passes here. 
mlir::OpPassManager mathConvertionPM("builtin.module"); + bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN(); + // If compiling for AMD target some math operations must be lowered to AMD + // GPU library calls, the rest can be converted to LLVM intrinsics, which + // is handled in the mathToLLVM conversion. The lowering to libm calls is + // not needed since all math operations are handled this way. + if (isAMDGCN) + mathConvertionPM.addPass(mlir::createConvertMathToROCDL()); + // Convert math::FPowI operations to inline implementation // only if the exponent's width is greater than 32, otherwise, // it will be lowered to LLVM intrinsic operation by a later conversion. @@ -3710,7 +3719,8 @@ class FIRToLLVMLowering pattern); // Math operations that have not been converted yet must be converted // to Libm. - mlir::populateMathToLibmConversionPatterns(pattern); + if (!isAMDGCN) + mlir::populateMathToLibmConversionPatterns(pattern); mlir::populateComplexToLLVMConversionPatterns(typeConverter, pattern); mlir::populateVectorToLLVMConversionPatterns(typeConverter, pattern); diff --git a/flang/test/Lower/OpenMP/math-amdgpu.f90 b/flang/test/Lower/OpenMP/math-amdgpu.f90 new file mode 100644 index 0000000000000..b455b42d3ed34 --- /dev/null +++ b/flang/test/Lower/OpenMP/math-amdgpu.f90 @@ -0,0 +1,184 @@ +!REQUIRES: amdgpu-registered-target +!RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-llvm -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s + +subroutine omp_pow_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__ocml_pow_f32(float {{.*}}, float {{.*}}) + y = x ** x +end subroutine omp_pow_f32 + +subroutine omp_pow_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__ocml_pow_f64(double {{.*}}, double {{.*}}) + y = x ** x +end subroutine omp_pow_f64 + +subroutine omp_sin_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__ocml_sin_f32(float {{.*}}) + y = sin(x) +end subroutine omp_sin_f32 + +subroutine 
omp_sin_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__ocml_sin_f64(double {{.*}}) + y = sin(x) +end subroutine omp_sin_f64 + +subroutine omp_abs_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__ocml_fabs_f32(float {{.*}}) + y = abs(x) +end subroutine omp_abs_f32 + +subroutine omp_abs_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__ocml_fabs_f64(double {{.*}}) + y = abs(x) +end subroutine omp_abs_f64 + +subroutine omp_atan_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__ocml_atan_f32(float {{.*}}) + y = atan(x) +end subroutine omp_atan_f32 + +subroutine omp_atan_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__ocml_atan_f64(double {{.*}}) + y = atan(x) +end subroutine omp_atan_f64 + +subroutine omp_atan2_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__ocml_atan2_f32(float {{.*}}, float {{.*}}) + y = atan2(x, x) +end subroutine omp_atan2_f32 + +subroutine omp_atan2_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__ocml_atan2_f64(double {{.*}}, double {{.*}}) + y = atan2(x ,x) +end subroutine omp_atan2_f64 + +subroutine omp_cos_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__ocml_cos_f32(float {{.*}}) + y = cos(x) +end subroutine omp_cos_f32 + +subroutine omp_cos_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__ocml_cos_f64(double {{.*}}) + y = cos(x) +end subroutine omp_cos_f64 + +subroutine omp_erf_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__ocml_erf_f32(float {{.*}}) + y = erf(x) +end subroutine omp_erf_f32 + +subroutine omp_erf_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__ocml_erf_f64(double {{.*}}) + y = erf(x) +end subroutine omp_erf_f64 + +subroutine omp_exp_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__ocml_exp_f32(float {{.*}}) + y = exp(x) +end subroutine omp_exp_f32 + 
+subroutine omp_exp_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__ocml_exp_f64(double {{.*}}) + y = exp(x) +end subroutine omp_exp_f64 + +subroutine omp_log_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__ocml_log_f32(float {{.*}}) + y = log(x) +end subroutine omp_log_f32 + +subroutine omp_log_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__ocml_log_f64(double {{.*}}) + y = log(x) +end subroutine omp_log_f64 + +subroutine omp_log10_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__ocml_log10_f32(float {{.*}}) + y = log10(x) +end subroutine omp_log10_f32 + +subroutine omp_log10_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__ocml_log10_f64(double {{.*}}) + y = log10(x) +end subroutine omp_log10_f64 + +subroutine omp_sqrt_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__ocml_sqrt_f32(float {{.*}}) + y = sqrt(x) +end subroutine omp_sqrt_f32 + +subroutine omp_sqrt_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__ocml_sqrt_f64(double {{.*}}) + y = sqrt(x) +end subroutine omp_sqrt_f64 + +subroutine omp_tan_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__ocml_tan_f32(float {{.*}}) + y = tan(x) +end subroutine omp_tan_f32 + +subroutine omp_tan_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__ocml_tan_f64(double {{.*}}) + y = tan(x) +end subroutine omp_tan_f64 + +subroutine omp_tanh_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__ocml_tanh_f32(float {{.*}}) + y = tanh(x) +end subroutine omp_tanh_f32 + +subroutine omp_tanh_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__ocml_tanh_f64(double {{.*}}) + y = tanh(x) +end subroutine omp_tanh_f64 From 1d7ec8e135ebeebea47654e22b878c7e89adf587 Mon Sep 17 00:00:00 2001 From: Jan Leyonberg Date: Thu, 5 Sep 2024 12:04:12 -0400 Subject: [PATCH 2/2] Update test to reflect changes 
in the lowering pass, where some ops are lowered to llvm intrinsics. --- flang/test/Lower/OpenMP/math-amdgpu.f90 | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/flang/test/Lower/OpenMP/math-amdgpu.f90 b/flang/test/Lower/OpenMP/math-amdgpu.f90 index b455b42d3ed34..116768ba9412a 100644 --- a/flang/test/Lower/OpenMP/math-amdgpu.f90 +++ b/flang/test/Lower/OpenMP/math-amdgpu.f90 @@ -32,14 +32,14 @@ end subroutine omp_sin_f64 subroutine omp_abs_f32(x, y) !$omp declare target real :: x, y -!CHECK: call float @__ocml_fabs_f32(float {{.*}}) +!CHECK: call contract float @llvm.fabs.f32(float {{.*}}) y = abs(x) end subroutine omp_abs_f32 subroutine omp_abs_f64(x, y) !$omp declare target real(8) :: x, y -!CHECK: call double @__ocml_fabs_f64(double {{.*}}) +!CHECK: call contract double @llvm.fabs.f64(double {{.*}}) y = abs(x) end subroutine omp_abs_f64 @@ -102,7 +102,7 @@ end subroutine omp_erf_f64 subroutine omp_exp_f32(x, y) !$omp declare target real :: x, y -!CHECK: call float @__ocml_exp_f32(float {{.*}}) +!CHECK: call contract float @llvm.exp.f32(float {{.*}}) y = exp(x) end subroutine omp_exp_f32 @@ -116,7 +116,7 @@ end subroutine omp_exp_f64 subroutine omp_log_f32(x, y) !$omp declare target real :: x, y -!CHECK: call float @__ocml_log_f32(float {{.*}}) +!CHECK: call contract float @llvm.log.f32(float {{.*}}) y = log(x) end subroutine omp_log_f32 @@ -144,14 +144,14 @@ end subroutine omp_log10_f64 subroutine omp_sqrt_f32(x, y) !$omp declare target real :: x, y -!CHECK: call float @__ocml_sqrt_f32(float {{.*}}) +!CHECK: call contract float @llvm.sqrt.f32(float {{.*}}) y = sqrt(x) end subroutine omp_sqrt_f32 subroutine omp_sqrt_f64(x, y) !$omp declare target real(8) :: x, y -!CHECK: call double @__ocml_sqrt_f64(double {{.*}}) +!CHECK: call contract double @llvm.sqrt.f64(double {{.*}}) y = sqrt(x) end subroutine omp_sqrt_f64