Skip to content

[flang][AMDGPU] Convert math ops to AMD GPU library calls instead of libm calls #99517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flang/lib/Optimizer/CodeGen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ add_flang_library(FIRCodeGen
MLIRMathToFuncs
MLIRMathToLLVM
MLIRMathToLibm
MLIRMathToROCDL
MLIROpenMPToLLVM
MLIROpenACCDialect
MLIRBuiltinToLLVMIRTranslation
Expand Down
12 changes: 11 additions & 1 deletion flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down Expand Up @@ -3671,6 +3672,14 @@ class FIRToLLVMLowering
// as passes here.
mlir::OpPassManager mathConvertionPM("builtin.module");

bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN();
// If compiling for AMD target some math operations must be lowered to AMD
// GPU library calls, the rest can be converted to LLVM intrinsics, which
// is handled in the mathToLLVM conversion. The lowering to libm calls is
// not needed since all math operations are handled this way.
if (isAMDGCN)
mathConvertionPM.addPass(mlir::createConvertMathToROCDL());

// Convert math::FPowI operations to inline implementation
// only if the exponent's width is greater than 32, otherwise,
// it will be lowered to LLVM intrinsic operation by a later conversion.
Expand Down Expand Up @@ -3710,7 +3719,8 @@ class FIRToLLVMLowering
pattern);
// Math operations that have not been converted yet must be converted
// to Libm.
mlir::populateMathToLibmConversionPatterns(pattern);
if (!isAMDGCN)
mlir::populateMathToLibmConversionPatterns(pattern);
mlir::populateComplexToLLVMConversionPatterns(typeConverter, pattern);
mlir::populateVectorToLLVMConversionPatterns(typeConverter, pattern);

Expand Down
184 changes: 184 additions & 0 deletions flang/test/Lower/OpenMP/math-amdgpu.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
!REQUIRES: amdgpu-registered-target
!RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-llvm -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s

subroutine omp_pow_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_pow_f32(float {{.*}}, float {{.*}})
y = x ** x
end subroutine omp_pow_f32

subroutine omp_pow_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_pow_f64(double {{.*}}, double {{.*}})
y = x ** x
end subroutine omp_pow_f64

subroutine omp_sin_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_sin_f32(float {{.*}})
y = sin(x)
end subroutine omp_sin_f32

subroutine omp_sin_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_sin_f64(double {{.*}})
y = sin(x)
end subroutine omp_sin_f64

subroutine omp_abs_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call contract float @llvm.fabs.f32(float {{.*}})
y = abs(x)
end subroutine omp_abs_f32

subroutine omp_abs_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call contract double @llvm.fabs.f64(double {{.*}})
y = abs(x)
end subroutine omp_abs_f64

subroutine omp_atan_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_atan_f32(float {{.*}})
y = atan(x)
end subroutine omp_atan_f32

subroutine omp_atan_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_atan_f64(double {{.*}})
y = atan(x)
end subroutine omp_atan_f64

subroutine omp_atan2_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_atan2_f32(float {{.*}}, float {{.*}})
y = atan2(x, x)
end subroutine omp_atan2_f32

subroutine omp_atan2_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_atan2_f64(double {{.*}}, double {{.*}})
y = atan2(x ,x)
end subroutine omp_atan2_f64

subroutine omp_cos_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_cos_f32(float {{.*}})
y = cos(x)
end subroutine omp_cos_f32

subroutine omp_cos_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_cos_f64(double {{.*}})
y = cos(x)
end subroutine omp_cos_f64

subroutine omp_erf_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_erf_f32(float {{.*}})
y = erf(x)
end subroutine omp_erf_f32

subroutine omp_erf_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_erf_f64(double {{.*}})
y = erf(x)
end subroutine omp_erf_f64

subroutine omp_exp_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call contract float @llvm.exp.f32(float {{.*}})
y = exp(x)
end subroutine omp_exp_f32

subroutine omp_exp_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_exp_f64(double {{.*}})
y = exp(x)
end subroutine omp_exp_f64

subroutine omp_log_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call contract float @llvm.log.f32(float {{.*}})
y = log(x)
end subroutine omp_log_f32

subroutine omp_log_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_log_f64(double {{.*}})
y = log(x)
end subroutine omp_log_f64

subroutine omp_log10_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_log10_f32(float {{.*}})
y = log10(x)
end subroutine omp_log10_f32

subroutine omp_log10_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_log10_f64(double {{.*}})
y = log10(x)
end subroutine omp_log10_f64

subroutine omp_sqrt_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call contract float @llvm.sqrt.f32(float {{.*}})
y = sqrt(x)
end subroutine omp_sqrt_f32

subroutine omp_sqrt_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call contract double @llvm.sqrt.f64(double {{.*}})
y = sqrt(x)
end subroutine omp_sqrt_f64

subroutine omp_tan_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_tan_f32(float {{.*}})
y = tan(x)
end subroutine omp_tan_f32

subroutine omp_tan_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_tan_f64(double {{.*}})
y = tan(x)
end subroutine omp_tan_f64

subroutine omp_tanh_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_tanh_f32(float {{.*}})
y = tanh(x)
end subroutine omp_tanh_f32

subroutine omp_tanh_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_tanh_f64(double {{.*}})
y = tanh(x)
end subroutine omp_tanh_f64
Loading