diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h index 6be5548fdb60e..8ff4d4ec67b9f 100644 --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -17,11 +17,13 @@ namespace mlir { /// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or -/// `f32ApproxFunc` depending on the element type and the fastMathFlag of that -/// Op. The function declaration is added in case it was not added before. +/// `f32ApproxFunc` or `f16Func` depending on the element type and the +/// fastMathFlag of that Op. The function declaration is added in case it was +/// not added before. /// -/// If the input values are of f16 type, the value is first casted to f32, the -/// function called and then the result casted back. +/// If the input values are of bf16 type (or f16 type if f16Func is empty), the +/// value is first casted to f32, the function called and then the result casted +/// back. /// /// Example with NVVM: /// %exp_f32 = math.exp %arg_f32 : f32 @@ -41,9 +43,10 @@ template struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { public: explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func, - StringRef f64Func, StringRef f32ApproxFunc) + StringRef f64Func, StringRef f32ApproxFunc, + StringRef f16Func) : ConvertOpToLLVMPattern(lowering), f32Func(f32Func), - f64Func(f64Func), f32ApproxFunc(f32ApproxFunc) {} + f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {} LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, @@ -89,7 +92,11 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { private: Value maybeCast(Value operand, PatternRewriter &rewriter) const { Type type = operand.getType(); - if (!isa(type)) + if (!isa(type)) + return operand; + + // if there's a f16 function, no need to cast f16 values + if (!f16Func.empty() && isa(type)) return operand; return rewriter.create( @@ -102,6 +109,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { } StringRef getFunctionName(Type type, arith::FastMathFlags flag) const { + if (isa(type)) + return f16Func; if (isa(type)) { if (((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) && !f32ApproxFunc.empty()) @@ -130,6 +139,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { const std::string f32Func; const std::string f64Func; const std::string f32ApproxFunc; + const std::string f16Func; }; } // namespace mlir diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 4be330b0bb26b..2b91a6c28c05e 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -335,11 +335,11 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) { template static void populateOpPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, StringRef f32Func, - StringRef f64Func, - StringRef f32ApproxFunc = "") { + StringRef f64Func, StringRef f32ApproxFunc = "", + StringRef f16Func = "") { patterns.add>(converter); patterns.add>(converter, f32Func, f64Func, - f32ApproxFunc); + f32ApproxFunc, f16Func); } void mlir::populateGpuSubgroupReduceOpLoweringPattern( diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index fc3e1fc4f9d0c..482c9e2c2d001 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -334,10 +334,9 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) { target.addIllegalOp(); - // These ops are legal for f16 and f32 type. + // These ops are legal for f32 type. target.addDynamicallyLegalOp([](Operation *op) { - return any_of(op->getOperandTypes(), - llvm::IsaPred); + return any_of(op->getOperandTypes(), llvm::IsaPred); }); // TODO: Remove once we support replacing non-root ops. target.addLegalOp(); @@ -346,9 +345,11 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) { template static void populateOpPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, StringRef f32Func, - StringRef f64Func) { + StringRef f64Func, StringRef f32ApproxFunc, + StringRef f16Func) { patterns.add>(converter); - patterns.add>(converter, f32Func, f64Func); + patterns.add>(converter, f32Func, f32ApproxFunc, + f16Func); } void mlir::populateGpuToROCDLConversionPatterns( diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index b3b4d81e7ffa5..8330713ea66e5 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -38,17 +38,17 @@ using namespace mlir; template static void populateOpPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, StringRef f32Func, - StringRef f64Func, + StringRef f64Func, StringRef f16Func, StringRef f32ApproxFunc = "") { patterns.add>(converter); patterns.add>(converter, f32Func, f64Func, - f32ApproxFunc); + f32ApproxFunc, f16Func); } void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { // Handled by mathToLLVM: math::AbsIOp - // Handled by mathToLLVM: math::AbsFIOp + // Handled by mathToLLVM: math::AbsFOp // Handled by mathToLLVM: math::CopySignOp // Handled by mathToLLVM: math::CountLeadingZerosOp // Handled by mathToLLVM: math::CountTrailingZerosOp @@ -63,59 +63,61 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter, // Handled by mathToLLVM: math::SqrtOp // Handled by mathToLLVM: math::TruncOp populateOpPatterns(converter, patterns, "__ocml_acos_f32", - "__ocml_acos_f64"); + "__ocml_acos_f64", "__ocml_acos_f16"); populateOpPatterns(converter, patterns, "__ocml_acosh_f32", - "__ocml_acosh_f64"); + "__ocml_acosh_f64", "__ocml_acosh_f16"); populateOpPatterns(converter, patterns, "__ocml_asin_f32", - "__ocml_asin_f64"); + "__ocml_asin_f64", "__ocml_asin_f16"); populateOpPatterns(converter, patterns, "__ocml_asinh_f32", - "__ocml_asinh_f64"); + "__ocml_asinh_f64", "__ocml_asinh_f16"); populateOpPatterns(converter, patterns, "__ocml_atan_f32", - "__ocml_atan_f64"); + "__ocml_atan_f64", "__ocml_atan_f16"); populateOpPatterns(converter, patterns, "__ocml_atanh_f32", - "__ocml_atanh_f64"); + "__ocml_atanh_f64", "__ocml_atanh_f16"); populateOpPatterns(converter, patterns, "__ocml_atan2_f32", - "__ocml_atan2_f64"); + "__ocml_atan2_f64", "__ocml_atan2_f16"); populateOpPatterns(converter, patterns, "__ocml_cbrt_f32", - "__ocml_cbrt_f64"); + "__ocml_cbrt_f64", "__ocml_cbrt_f16"); populateOpPatterns(converter, patterns, "__ocml_ceil_f32", - "__ocml_ceil_f64"); + "__ocml_ceil_f64", "__ocml_ceil_f16"); populateOpPatterns(converter, patterns, "__ocml_cos_f32", - "__ocml_cos_f64"); + "__ocml_cos_f64", "__ocml_cos_f16"); populateOpPatterns(converter, patterns, "__ocml_cosh_f32", - "__ocml_cosh_f64"); + "__ocml_cosh_f64", "__ocml_cosh_f16"); populateOpPatterns(converter, patterns, "__ocml_sinh_f32", - "__ocml_sinh_f64"); - populateOpPatterns(converter, patterns, "", "__ocml_exp_f64"); + "__ocml_sinh_f64", "__ocml_sinh_f16"); + populateOpPatterns(converter, patterns, "", "__ocml_exp_f64", + "__ocml_exp_f16"); populateOpPatterns(converter, patterns, "__ocml_exp2_f32", - "__ocml_exp2_f64"); + "__ocml_exp2_f64", "__ocml_exp2_f16"); populateOpPatterns(converter, patterns, "__ocml_expm1_f32", - "__ocml_expm1_f64"); + "__ocml_expm1_f64", "__ocml_expm1_f16"); populateOpPatterns(converter, patterns, "__ocml_floor_f32", - "__ocml_floor_f64"); - populateOpPatterns(converter, patterns, "", "__ocml_log_f64"); + "__ocml_floor_f64", "__ocml_floor_f16"); + populateOpPatterns(converter, patterns, "", "__ocml_log_f64", + "__ocml_log_f16"); populateOpPatterns(converter, patterns, "__ocml_log10_f32", - "__ocml_log10_f64"); + "__ocml_log10_f64", "__ocml_log10_f16"); populateOpPatterns(converter, patterns, "__ocml_log1p_f32", - "__ocml_log1p_f64"); + "__ocml_log1p_f64", "__ocml_log1p_f16"); populateOpPatterns(converter, patterns, "__ocml_log2_f32", - "__ocml_log2_f64"); + "__ocml_log2_f64", "__ocml_log2_f16"); populateOpPatterns(converter, patterns, "__ocml_pow_f32", - "__ocml_pow_f64"); + "__ocml_pow_f64", "__ocml_pow_f16"); populateOpPatterns(converter, patterns, "__ocml_rsqrt_f32", - "__ocml_rsqrt_f64"); + "__ocml_rsqrt_f64", "__ocml_rsqrt_f16"); populateOpPatterns(converter, patterns, "__ocml_sin_f32", - "__ocml_sin_f64"); + "__ocml_sin_f64", "__ocml_sin_f16"); populateOpPatterns(converter, patterns, "__ocml_tanh_f32", - "__ocml_tanh_f64"); + "__ocml_tanh_f64", "__ocml_tanh_f16"); populateOpPatterns(converter, patterns, "__ocml_tan_f32", - "__ocml_tan_f64"); + "__ocml_tan_f64", "__ocml_tan_f16"); populateOpPatterns(converter, patterns, "__ocml_erf_f32", - "__ocml_erf_f64"); + "__ocml_erf_f64", "__ocml_erf_f16"); // Single arith pattern that needs a ROCDL call, probably not // worth creating a separate pass for it. populateOpPatterns(converter, patterns, "__ocml_fmod_f32", - "__ocml_fmod_f64"); + "__ocml_fmod_f64", "__ocml_fmod_f16"); } namespace { diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir index eb065cbab8678..0d3e9f4ea2bf3 100644 --- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir +++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir @@ -162,11 +162,12 @@ gpu.module @test_module { // ----- gpu.module @test_module { + // CHECK: llvm.func @__ocml_exp_f16(f16) -> f16 // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64 // CHECK-LABEL: func @gpu_exp func.func @gpu_exp(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { %result16 = math.exp %arg_f16 : f16 - // CHECK: llvm.intr.exp(%{{.*}}) : (f16) -> f16 + // CHECK: llvm.call @__ocml_exp_f16(%{{.*}}) : (f16) -> f16 %result32 = math.exp %arg_f32 : f32 // CHECK: llvm.intr.exp(%{{.*}}) : (f32) -> f32 %result64 = math.exp %arg_f64 : f64 @@ -178,11 +179,12 @@ gpu.module @test_module { // ----- gpu.module @test_module { + // CHECK: llvm.func @__ocml_log_f16(f16) -> f16 // CHECK: llvm.func @__ocml_log_f64(f64) -> f64 // CHECK-LABEL: func @gpu_log func.func @gpu_log(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { %result16 = math.log %arg_f16 : f16 - // CHECK: llvm.intr.log(%{{.*}}) : (f16) -> f16 + // CHECK: llvm.call @__ocml_log_f16(%{{.*}}) : (f16) -> f16 %result32 = math.log %arg_f32 : f32 // CHECK: llvm.intr.log(%{{.*}}) : (f32) -> f32 %result64 = math.log %arg_f64 : f64 @@ -194,108 +196,113 @@ gpu.module @test_module { // ----- gpu.module @test_module { + // CHECK: llvm.func @__ocml_cbrt_f16(f16) -> f16 // CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32 // CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64 // CHECK-LABEL: func @gpu_cbrt - func.func @gpu_cbrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @gpu_cbrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.cbrt %arg_f16 : f16 + // CHECK: llvm.call @__ocml_cbrt_f16(%{{.*}}) : (f16) -> f16 %result32 = math.cbrt %arg_f32 : f32 // CHECK: llvm.call @__ocml_cbrt_f32(%{{.*}}) : (f32) -> f32 %result64 = math.cbrt %arg_f64 : f64 // CHECK: llvm.call @__ocml_cbrt_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- gpu.module @test_module { + // CHECK: llvm.func @__ocml_ceil_f16(f16) -> f16 // CHECK: llvm.func @__ocml_ceil_f32(f32) -> f32 // CHECK: llvm.func @__ocml_ceil_f64(f64) -> f64 // CHECK-LABEL: func @gpu_ceil - func.func @gpu_ceil(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @gpu_ceil(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.ceil %arg_f16 : f16 + // CHECK: llvm.call @__ocml_ceil_f16(%{{.*}}) : (f16) -> f16 %result32 = math.ceil %arg_f32 : f32 // CHECK: llvm.call @__ocml_ceil_f32(%{{.*}}) : (f32) -> f32 %result64 = math.ceil %arg_f64 : f64 // CHECK: llvm.call @__ocml_ceil_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- gpu.module @test_module { + // CHECK: llvm.func @__ocml_floor_f16(f16) -> f16 // CHECK: llvm.func @__ocml_floor_f32(f32) -> f32 // CHECK: llvm.func @__ocml_floor_f64(f64) -> f64 // CHECK-LABEL: func @gpu_floor - func.func @gpu_floor(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @gpu_floor(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.floor %arg_f16 : f16 + // CHECK: llvm.call @__ocml_floor_f16(%{{.*}}) : (f16) -> f16 %result32 = math.floor %arg_f32 : f32 // CHECK: llvm.call @__ocml_floor_f32(%{{.*}}) : (f32) -> f32 %result64 = math.floor %arg_f64 : f64 // CHECK: llvm.call @__ocml_floor_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- gpu.module @test_module { + // CHECK: llvm.func @__ocml_cos_f16(f16) -> f16 // CHECK: llvm.func @__ocml_cos_f32(f32) -> f32 // CHECK: llvm.func @__ocml_cos_f64(f64) -> f64 // CHECK-LABEL: func @gpu_cos - func.func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @gpu_cos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.cos %arg_f16 : f16 + // CHECK: llvm.call @__ocml_cos_f16(%{{.*}}) : (f16) -> f16 %result32 = math.cos %arg_f32 : f32 // CHECK: llvm.call @__ocml_cos_f32(%{{.*}}) : (f32) -> f32 %result64 = math.cos %arg_f64 : f64 // CHECK: llvm.call @__ocml_cos_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 - } -} - -// ----- - -gpu.module @test_module { - // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64 - // CHECK-LABEL: func @gpu_exp - func.func @gpu_exp(%arg_f64 : f64) -> (f64) { - %result64 = math.exp %arg_f64 : f64 - // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64 - func.return %result64 : f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- gpu.module @test_module { + // CHECK: llvm.func @__ocml_exp2_f16(f16) -> f16 // CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32 // CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64 // CHECK-LABEL: func @gpu_exp2 - func.func @gpu_exp2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @gpu_exp2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.exp2 %arg_f16 : f16 + // CHECK: llvm.call @__ocml_exp2_f16(%{{.*}}) : (f16) -> f16 %exp2_f32 = math.exp2 %arg_f32 : f32 // CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32 %result32 = math.exp2 %exp2_f32 : f32 // CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32 %result64 = math.exp2 %arg_f64 : f64 // CHECK: llvm.call @__ocml_exp2_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- + // Test that we handled properly operation with SymbolTable other than module op gpu.module @test_module { "test.symbol_scope"() ({ // CHECK: test.symbol_scope + // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16 // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32 // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64 // CHECK-LABEL: func @gpu_sin - func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %sin_f32 = math.sin %arg_f32 : f32 + func.func @gpu_sin(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 + %result16 = math.sin %arg_f16 : f16 // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 - %result32 = math.sin %sin_f32 : f32 - // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.sin %arg_f64 : f64 + %result32 = math.sin %arg_f32 : f32 // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + %result64 = math.sin %arg_f64 : f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } "test.finish" () : () -> () }) : () -> () @@ -304,89 +311,102 @@ gpu.module @test_module { // ----- gpu.module @test_module { + // CHECK: llvm.func @__ocml_expm1_f16(f16) -> f16 // CHECK: llvm.func @__ocml_expm1_f32(f32) -> f32 // CHECK: llvm.func @__ocml_expm1_f64(f64) -> f64 // CHECK-LABEL: func @gpu_expm1 - func.func @gpu_expm1(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @gpu_expm1(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.expm1 %arg_f16 : f16 + // CHECK: llvm.call @__ocml_expm1_f16(%{{.*}}) : (f16) -> f16 %expm1_f32 = math.expm1 %arg_f32 : f32 // CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32 %result32 = math.expm1 %expm1_f32 : f32 // CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32 %result64 = math.expm1 %arg_f64 : f64 // CHECK: llvm.call @__ocml_expm1_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- gpu.module @test_module { + // CHECK: llvm.func @__ocml_log_f16(f16) -> f16 // CHECK: llvm.func @__ocml_log_f64(f64) -> f64 // CHECK-LABEL: func @gpu_log - func.func @gpu_log(%arg_f64 : f64) -> (f64) { + func.func @gpu_log(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) { + %result16 = math.log %arg_f16 : f16 + // CHECK: llvm.call @__ocml_log_f16(%{{.*}}) : (f16) -> f16 %result64 = math.log %arg_f64 : f64 // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64 - func.return %result64 : f64 + func.return %result16, %result64 : f16, f64 } } // ----- gpu.module @test_module { + // CHECK: llvm.func @__ocml_log1p_f16(f16) -> f16 // CHECK: llvm.func @__ocml_log1p_f32(f32) -> f32 // CHECK: llvm.func @__ocml_log1p_f64(f64) -> f64 // CHECK-LABEL: func @gpu_log1p - func.func @gpu_log1p(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @gpu_log1p(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.log1p %arg_f16 : f16 + // CHECK: llvm.call @__ocml_log1p_f16(%{{.*}}) : (f16) -> f16 %result32 = math.log1p %arg_f32 : f32 // CHECK: llvm.call @__ocml_log1p_f32(%{{.*}}) : (f32) -> f32 %result64 = math.log1p %arg_f64 : f64 // CHECK: llvm.call @__ocml_log1p_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- gpu.module @test_module { + // CHECK: llvm.func @__ocml_log10_f16(f16) -> f16 // CHECK: llvm.func @__ocml_log10_f32(f32) -> f32 // CHECK: llvm.func @__ocml_log10_f64(f64) -> f64 // CHECK-LABEL: func @gpu_log10 - func.func @gpu_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @gpu_log10(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.log10 %arg_f16 : f16 + // CHECK: llvm.call @__ocml_log10_f16(%{{.*}}) : (f16) -> f16 %result32 = math.log10 %arg_f32 : f32 // CHECK: llvm.call @__ocml_log10_f32(%{{.*}}) : (f32) -> f32 %result64 = math.log10 %arg_f64 : f64 // CHECK: llvm.call @__ocml_log10_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- gpu.module @test_module { + // CHECK: llvm.func @__ocml_log2_f16(f16) -> f16 // CHECK: llvm.func @__ocml_log2_f32(f32) -> f32 // CHECK: llvm.func @__ocml_log2_f64(f64) -> f64 // CHECK-LABEL: func @gpu_log2 - func.func @gpu_log2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @gpu_log2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.log2 %arg_f16 : f16 + // CHECK: llvm.call @__ocml_log2_f16(%{{.*}}) : (f16) -> f16 %result32 = math.log2 %arg_f32 : f32 // CHECK: llvm.call @__ocml_log2_f32(%{{.*}}) : (f32) -> f32 %result64 = math.log2 %arg_f64 : f64 // CHECK: llvm.call @__ocml_log2_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- gpu.module @test_module { + // CHECK: llvm.func @__ocml_rsqrt_f16(f16) -> f16 // CHECK: llvm.func @__ocml_rsqrt_f32(f32) -> f32 // CHECK: llvm.func @__ocml_rsqrt_f64(f64) -> f64 // CHECK-LABEL: func @gpu_rsqrt - func.func @gpu_rsqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) - -> (f16, f32, f64) { + func.func @gpu_rsqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { %result16 = math.rsqrt %arg_f16 : f16 - // CHECK: llvm.fpext %{{.*}} : f16 to f32 - // CHECK-NEXT: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32 - // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to f16 + // CHECK: llvm.call @__ocml_rsqrt_f16(%{{.*}}) : (f16) -> f16 %result32 = math.rsqrt %arg_f32 : f32 // CHECK: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32 %result64 = math.rsqrt %arg_f64 : f64 @@ -398,90 +418,108 @@ gpu.module @test_module { // ----- gpu.module @test_module { + // CHECK: llvm.func @__ocml_tan_f16(f16) -> f16 // CHECK: llvm.func @__ocml_tan_f32(f32) -> f32 // CHECK: llvm.func @__ocml_tan_f64(f64) -> f64 // CHECK-LABEL: func @gpu_tan - func.func @gpu_tan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @gpu_tan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.tan %arg_f16 : f16 + // CHECK: llvm.call @__ocml_tan_f16(%{{.*}}) : (f16) -> f16 %result32 = math.tan %arg_f32 : f32 // CHECK: llvm.call @__ocml_tan_f32(%{{.*}}) : (f32) -> f32 %result64 = math.tan %arg_f64 : f64 // CHECK: llvm.call @__ocml_tan_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- gpu.module @test_module { + // CHECK: llvm.func @__ocml_tanh_f16(f16) -> f16 // CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32 // CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64 // CHECK-LABEL: func @gpu_tanh - func.func @gpu_tanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @gpu_tanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.tanh %arg_f16 : f16 + // CHECK: llvm.call @__ocml_tanh_f16(%{{.*}}) : (f16) -> f16 %result32 = math.tanh %arg_f32 : f32 // CHECK: llvm.call @__ocml_tanh_f32(%{{.*}}) : (f32) -> f32 %result64 = math.tanh %arg_f64 : f64 // CHECK: llvm.call @__ocml_tanh_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- gpu.module @test_module { + // CHECK: llvm.func @__ocml_atan_f16(f16) -> f16 // CHECK: llvm.func @__ocml_atan_f32(f32) -> f32 // CHECK: llvm.func @__ocml_atan_f64(f64) -> f64 // CHECK-LABEL: func @gpu_atan - func.func @gpu_atan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @gpu_atan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.atan %arg_f16 : f16 + // CHECK: llvm.call @__ocml_atan_f16(%{{.*}}) : (f16) -> f16 %result32 = math.atan %arg_f32 : f32 // CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32 %result64 = math.atan %arg_f64 : f64 // CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- gpu.module @test_module { + // CHECK: llvm.func @__ocml_atan2_f16(f16, f16) -> f16 // CHECK: llvm.func @__ocml_atan2_f32(f32, f32) -> f32 // CHECK: llvm.func @__ocml_atan2_f64(f64, f64) -> f64 // CHECK-LABEL: func @gpu_atan2 - func.func @gpu_atan2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @gpu_atan2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.atan2 %arg_f16, %arg_f16 : f16 + // CHECK: llvm.call @__ocml_atan2_f16(%{{.*}}) : (f16, f16) -> f16 %result32 = math.atan2 %arg_f32, %arg_f32 : f32 // CHECK: llvm.call @__ocml_atan2_f32(%{{.*}}) : (f32, f32) -> f32 %result64 = math.atan2 %arg_f64, %arg_f64 : f64 // CHECK: llvm.call @__ocml_atan2_f64(%{{.*}}) : (f64, f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- gpu.module @test_module { + // CHECK: llvm.func @__ocml_pow_f16(f16, f16) -> f16 // CHECK: llvm.func @__ocml_pow_f32(f32, f32) -> f32 // CHECK: llvm.func @__ocml_pow_f64(f64, f64) -> f64 // CHECK-LABEL: func @gpu_pow - func.func @gpu_pow(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @gpu_pow(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.powf %arg_f16, %arg_f16 : f16 + // CHECK: llvm.call @__ocml_pow_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16 %result32 = math.powf %arg_f32, %arg_f32 : f32 // CHECK: llvm.call @__ocml_pow_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32 %result64 = math.powf %arg_f64, %arg_f64 : f64 // CHECK: llvm.call @__ocml_pow_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- gpu.module @test_module { + // CHECK: llvm.func @__ocml_erf_f16(f16) -> f16 // CHECK: llvm.func @__ocml_erf_f32(f32) -> f32 // CHECK: llvm.func @__ocml_erf_f64(f64) -> f64 // CHECK-LABEL: func @gpu_erf - func.func @gpu_erf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @gpu_erf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.erf %arg_f16 : f16 + // CHECK: llvm.call @__ocml_erf_f16(%{{.*}}) : (f16) -> f16 %result32 = math.erf %arg_f32 : f32 // CHECK: llvm.call @__ocml_erf_f32(%{{.*}}) : (f32) -> f32 %result64 = math.erf %arg_f64 : f64 // CHECK: llvm.call @__ocml_erf_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } @@ -543,9 +581,9 @@ gpu.module @test_module { // ----- gpu.module @module { -// CHECK-LABEL: @spirv_exp +// CHECK-LABEL: @spirv_sin // CHECK: llvm.call @__ocml_sin_f32 - spirv.func @spirv_exp(%arg0: vector<4xf32>) -> vector<4xf32> "None" { + spirv.func @spirv_sin(%arg0: vector<4xf32>) -> vector<4xf32> "None" { %0 = math.sin %arg0 : vector<4xf32> spirv.ReturnValue %0 : vector<4xf32> } @@ -602,15 +640,18 @@ gpu.module @test_module { // ----- gpu.module @test_module { + // CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16 // CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32 // CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64 // CHECK-LABEL: func @gpu_fmod - func.func @gpu_fmod(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @gpu_fmod(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = arith.remf %arg_f16, %arg_f16 : f16 + // CHECK: llvm.call @__ocml_fmod_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16 %result32 = arith.remf %arg_f32, %arg_f32 : f32 // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32 %result64 = arith.remf %arg_f64, %arg_f64 : f64 // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir index 19d89e03a7f48..ddd96bf797e6e 100644 --- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir +++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir @@ -1,399 +1,483 @@ // RUN: mlir-opt %s -convert-math-to-rocdl -split-input-file | FileCheck %s module @test_module { + // CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16 // CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32 // CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64 // CHECK-LABEL: func @arith_remf - func.func @arith_remf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @arith_remf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = arith.remf %arg_f16, %arg_f16 : f16 + // CHECK: llvm.call @__ocml_fmod_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16 %result32 = arith.remf %arg_f32, %arg_f32 : f32 // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32 %result64 = arith.remf %arg_f64, %arg_f64 : f64 // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_acos_f16(f16) -> f16 // CHECK: llvm.func @__ocml_acos_f32(f32) -> f32 // CHECK: llvm.func @__ocml_acos_f64(f64) -> f64 // CHECK-LABEL: func @math_acos - func.func @math_acos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_acos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.acos %arg_f16 : f16 + // CHECK: llvm.call @__ocml_acos_f16(%{{.*}}) : (f16) -> f16 %result32 = math.acos %arg_f32 : f32 // CHECK: llvm.call @__ocml_acos_f32(%{{.*}}) : (f32) -> f32 %result64 = math.acos %arg_f64 : f64 // CHECK: llvm.call @__ocml_acos_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_acosh_f16(f16) -> f16 // CHECK: llvm.func @__ocml_acosh_f32(f32) -> f32 // CHECK: llvm.func @__ocml_acosh_f64(f64) -> f64 // CHECK-LABEL: func @math_acosh - func.func @math_acosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_acosh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.acosh %arg_f16 : f16 + // CHECK: llvm.call @__ocml_acosh_f16(%{{.*}}) : (f16) -> f16 %result32 = math.acosh %arg_f32 : f32 // CHECK: llvm.call @__ocml_acosh_f32(%{{.*}}) : (f32) -> f32 %result64 = math.acosh %arg_f64 : f64 // CHECK: llvm.call @__ocml_acosh_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_asin_f16(f16) -> f16 // CHECK: llvm.func @__ocml_asin_f32(f32) -> f32 // CHECK: llvm.func @__ocml_asin_f64(f64) -> f64 // CHECK-LABEL: func @math_asin - func.func @math_asin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_asin(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.asin %arg_f16 : f16 + // CHECK: llvm.call @__ocml_asin_f16(%{{.*}}) : (f16) -> f16 %result32 = math.asin %arg_f32 : f32 // CHECK: llvm.call @__ocml_asin_f32(%{{.*}}) : (f32) -> f32 %result64 = math.asin %arg_f64 : f64 // CHECK: llvm.call @__ocml_asin_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_asinh_f16(f16) -> f16 // CHECK: llvm.func @__ocml_asinh_f32(f32) -> f32 // CHECK: llvm.func @__ocml_asinh_f64(f64) -> f64 // CHECK-LABEL: func @math_asinh - func.func @math_asinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_asinh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.asinh %arg_f16 : f16 + // CHECK: llvm.call @__ocml_asinh_f16(%{{.*}}) : (f16) -> f16 %result32 = math.asinh %arg_f32 : f32 // CHECK: llvm.call @__ocml_asinh_f32(%{{.*}}) : (f32) -> f32 %result64 = math.asinh %arg_f64 : f64 // CHECK: llvm.call @__ocml_asinh_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_atan_f16(f16) -> f16 // CHECK: llvm.func @__ocml_atan_f32(f32) -> f32 // CHECK: llvm.func @__ocml_atan_f64(f64) -> f64 // CHECK-LABEL: func @math_atan - func.func @math_atan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_atan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.atan %arg_f16 : f16 + // CHECK: llvm.call @__ocml_atan_f16(%{{.*}}) : (f16) -> f16 %result32 = math.atan %arg_f32 : f32 // CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32 %result64 = math.atan %arg_f64 : f64 // CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_atanh_f16(f16) -> f16 // CHECK: llvm.func @__ocml_atanh_f32(f32) -> f32 // CHECK: llvm.func @__ocml_atanh_f64(f64) -> f64 // CHECK-LABEL: func @math_atanh - func.func @math_atanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_atanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.atanh %arg_f16 : f16 + // CHECK: llvm.call @__ocml_atanh_f16(%{{.*}}) : (f16) -> f16 %result32 = math.atanh %arg_f32 : f32 // CHECK: llvm.call @__ocml_atanh_f32(%{{.*}}) : (f32) -> f32 %result64 = math.atanh %arg_f64 : f64 // CHECK: llvm.call @__ocml_atanh_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_atan2_f16(f16, f16) -> f16 // CHECK: llvm.func @__ocml_atan2_f32(f32, f32) -> f32 // CHECK: llvm.func @__ocml_atan2_f64(f64, f64) -> f64 // CHECK-LABEL: func @math_atan2 - func.func @math_atan2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_atan2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.atan2 %arg_f16, %arg_f16 : f16 + // CHECK: llvm.call @__ocml_atan2_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16 %result32 = math.atan2 %arg_f32, %arg_f32 : f32 // CHECK: llvm.call @__ocml_atan2_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32 %result64 = math.atan2 %arg_f64, %arg_f64 : f64 // CHECK: llvm.call @__ocml_atan2_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_cbrt_f16(f16) -> f16 // CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32 // CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64 // CHECK-LABEL: func @math_cbrt - func.func @math_cbrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_cbrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.cbrt %arg_f16 : f16 + // CHECK: llvm.call @__ocml_cbrt_f16(%{{.*}}) : (f16) -> f16 %result32 = math.cbrt %arg_f32 : f32 // CHECK: llvm.call @__ocml_cbrt_f32(%{{.*}}) : (f32) -> f32 %result64 = math.cbrt %arg_f64 : f64 // CHECK: llvm.call @__ocml_cbrt_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_ceil_f16(f16) -> f16 // CHECK: llvm.func @__ocml_ceil_f32(f32) -> f32 // CHECK: llvm.func @__ocml_ceil_f64(f64) -> f64 // CHECK-LABEL: func @math_ceil - func.func @math_ceil(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_ceil(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.ceil %arg_f16 : f16 + // CHECK: llvm.call @__ocml_ceil_f16(%{{.*}}) : (f16) -> f16 %result32 = math.ceil %arg_f32 : f32 // CHECK: llvm.call @__ocml_ceil_f32(%{{.*}}) : (f32) -> f32 %result64 = math.ceil %arg_f64 : f64 // CHECK: llvm.call @__ocml_ceil_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_cos_f16(f16) -> f16 // CHECK: llvm.func @__ocml_cos_f32(f32) -> f32 // CHECK: llvm.func @__ocml_cos_f64(f64) -> f64 // CHECK-LABEL: func @math_cos - func.func @math_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_cos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.cos %arg_f16 : f16 + // CHECK: llvm.call @__ocml_cos_f16(%{{.*}}) : (f16) -> f16 %result32 = math.cos %arg_f32 : f32 // CHECK: llvm.call @__ocml_cos_f32(%{{.*}}) : (f32) -> f32 %result64 = math.cos %arg_f64 : f64 // CHECK: llvm.call @__ocml_cos_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_cosh_f16(f16) -> f16 // CHECK: llvm.func @__ocml_cosh_f32(f32) -> f32 // CHECK: llvm.func @__ocml_cosh_f64(f64) -> f64 // CHECK-LABEL: func @math_cosh - func.func @math_cosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_cosh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.cosh %arg_f16 : f16 + // CHECK: llvm.call @__ocml_cosh_f16(%{{.*}}) : (f16) -> f16 %result32 = math.cosh %arg_f32 : f32 // CHECK: llvm.call @__ocml_cosh_f32(%{{.*}}) : (f32) -> f32 %result64 = math.cosh %arg_f64 : f64 // CHECK: llvm.call @__ocml_cosh_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_sinh_f16(f16) -> f16 // CHECK: llvm.func @__ocml_sinh_f32(f32) -> f32 // CHECK: llvm.func @__ocml_sinh_f64(f64) -> f64 // CHECK-LABEL: func @math_sinh - func.func @math_sinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_sinh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.sinh %arg_f16 : f16 + // CHECK: llvm.call @__ocml_sinh_f16(%{{.*}}) : (f16) -> f16 %result32 = math.sinh %arg_f32 : f32 // CHECK: llvm.call @__ocml_sinh_f32(%{{.*}}) : (f32) -> f32 %result64 = math.sinh %arg_f64 : f64 // CHECK: llvm.call @__ocml_sinh_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_exp_f16(f16) -> f16 // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64 // CHECK-LABEL: func @math_exp - func.func @math_exp(%arg_f64 : f64) -> (f64) { + func.func @math_exp(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) { + %result16 = math.exp %arg_f16 : f16 + // CHECK: llvm.call @__ocml_exp_f16(%{{.*}}) : (f16) -> f16 %result64 = math.exp %arg_f64 : f64 // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64 - func.return %result64 : f64 + func.return %result16, %result64 : f16, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_exp2_f16(f16) -> f16 // CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32 // CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64 // CHECK-LABEL: func @math_exp2 - func.func @math_exp2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_exp2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.exp2 %arg_f16 : f16 + // CHECK: llvm.call @__ocml_exp2_f16(%{{.*}}) : (f16) -> f16 %result32 = math.exp2 %arg_f32 : f32 // CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32 %result64 = math.exp2 %arg_f64 : f64 // CHECK: llvm.call @__ocml_exp2_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_expm1_f16(f16) -> f16 // CHECK: llvm.func @__ocml_expm1_f32(f32) -> f32 // CHECK: llvm.func @__ocml_expm1_f64(f64) -> f64 // CHECK-LABEL: func @math_expm1 - func.func @math_expm1(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_expm1(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.expm1 %arg_f16 : f16 + // CHECK: llvm.call @__ocml_expm1_f16(%{{.*}}) : (f16) -> f16 %result32 = math.expm1 %arg_f32 : f32 // CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32 %result64 = math.expm1 %arg_f64 : f64 // CHECK: llvm.call @__ocml_expm1_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_floor_f16(f16) -> f16 // CHECK: llvm.func @__ocml_floor_f32(f32) -> f32 // CHECK: llvm.func @__ocml_floor_f64(f64) -> f64 // CHECK-LABEL: func @math_floor - func.func @math_floor(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_floor(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.floor %arg_f16 : f16 + // CHECK: llvm.call @__ocml_floor_f16(%{{.*}}) : (f16) -> f16 %result32 = math.floor %arg_f32 : f32 // CHECK: llvm.call @__ocml_floor_f32(%{{.*}}) : (f32) -> f32 %result64 = math.floor %arg_f64 : f64 // CHECK: llvm.call @__ocml_floor_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_log_f16(f16) -> f16 // CHECK: llvm.func @__ocml_log_f64(f64) -> f64 // CHECK-LABEL: func @math_log - func.func @math_log(%arg_f64 : f64) -> (f64) { + func.func @math_log(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) { + %result16 = math.log %arg_f16 : f16 + // CHECK: llvm.call @__ocml_log_f16(%{{.*}}) : (f16) -> f16 %result64 = math.log %arg_f64 : f64 // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64 - func.return %result64 : f64 + func.return %result16, %result64 : f16, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_log10_f16(f16) -> f16 // CHECK: llvm.func @__ocml_log10_f32(f32) -> f32 // CHECK: llvm.func @__ocml_log10_f64(f64) -> f64 // CHECK-LABEL: func @math_log10 - func.func @math_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_log10(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.log10 %arg_f16 : f16 + // CHECK: llvm.call @__ocml_log10_f16(%{{.*}}) : (f16) -> f16 %result32 = math.log10 %arg_f32 : f32 // CHECK: llvm.call @__ocml_log10_f32(%{{.*}}) : (f32) -> f32 %result64 = math.log10 %arg_f64 : f64 // CHECK: llvm.call @__ocml_log10_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_log1p_f16(f16) -> f16 // CHECK: llvm.func @__ocml_log1p_f32(f32) -> f32 // CHECK: llvm.func @__ocml_log1p_f64(f64) -> f64 // CHECK-LABEL: func @math_log1p - func.func @math_log1p(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_log1p(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.log1p %arg_f16 : f16 + // CHECK: llvm.call @__ocml_log1p_f16(%{{.*}}) : (f16) -> f16 %result32 = math.log1p %arg_f32 : f32 // CHECK: llvm.call @__ocml_log1p_f32(%{{.*}}) : (f32) -> f32 %result64 = math.log1p %arg_f64 : f64 // CHECK: llvm.call @__ocml_log1p_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_pow_f16(f16, f16) -> f16 // CHECK: llvm.func @__ocml_pow_f32(f32, f32) -> f32 // CHECK: llvm.func @__ocml_pow_f64(f64, f64) -> f64 // CHECK-LABEL: func @math_powf - func.func @math_powf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_powf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.powf %arg_f16, %arg_f16 : f16 + // CHECK: llvm.call @__ocml_pow_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16 %result32 = math.powf %arg_f32, %arg_f32 : f32 // CHECK: llvm.call @__ocml_pow_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32 %result64 = math.powf %arg_f64, %arg_f64 : f64 // CHECK: llvm.call @__ocml_pow_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_rsqrt_f16(f16) -> f16 // CHECK: llvm.func @__ocml_rsqrt_f32(f32) -> f32 // CHECK: llvm.func @__ocml_rsqrt_f64(f64) -> f64 // CHECK-LABEL: func @math_rsqrt - func.func @math_rsqrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_rsqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.rsqrt %arg_f16 : f16 + // CHECK: llvm.call @__ocml_rsqrt_f16(%{{.*}}) : (f16) -> f16 %result32 = math.rsqrt %arg_f32 : f32 // CHECK: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32 %result64 = math.rsqrt %arg_f64 : f64 // CHECK: llvm.call @__ocml_rsqrt_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16 // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32 // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64 // CHECK-LABEL: func @math_sin - func.func @math_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_sin(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.sin %arg_f16 : f16 + // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 %result32 = math.sin %arg_f32 : f32 // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 %result64 = math.sin %arg_f64 : f64 // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_tanh_f16(f16) -> f16 // CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32 // CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64 // CHECK-LABEL: func @math_tanh - func.func @math_tanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_tanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.tanh %arg_f16 : f16 + // CHECK: llvm.call @__ocml_tanh_f16(%{{.*}}) : (f16) -> f16 %result32 = math.tanh %arg_f32 : f32 // CHECK: llvm.call @__ocml_tanh_f32(%{{.*}}) : (f32) -> f32 %result64 = math.tanh %arg_f64 : f64 // CHECK: llvm.call @__ocml_tanh_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_tan_f16(f16) -> f16 // CHECK: llvm.func @__ocml_tan_f32(f32) -> f32 // CHECK: llvm.func @__ocml_tan_f64(f64) -> f64 // CHECK-LABEL: func @math_tan - func.func @math_tan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_tan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.tan %arg_f16 : f16 + // CHECK: llvm.call @__ocml_tan_f16(%{{.*}}) : (f16) -> f16 %result32 = math.tan %arg_f32 : f32 // CHECK: llvm.call @__ocml_tan_f32(%{{.*}}) : (f32) -> f32 %result64 = math.tan %arg_f64 : f64 // CHECK: llvm.call @__ocml_tan_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { + // CHECK: llvm.func @__ocml_erf_f16(f16) -> f16 // CHECK: llvm.func @__ocml_erf_f32(f32) -> f32 // CHECK: llvm.func @__ocml_erf_f64(f64) -> f64 // CHECK-LABEL: func @math_erf - func.func @math_erf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + func.func @math_erf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) { + %result16 = math.erf %arg_f16 : f16 + // CHECK: llvm.call @__ocml_erf_f16(%{{.*}}) : (f16) -> f16 %result32 = math.erf %arg_f32 : f32 // CHECK: llvm.call @__ocml_erf_f32(%{{.*}}) : (f32) -> f32 %result64 = math.erf %arg_f64 : f64 // CHECK: llvm.call @__ocml_erf_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result16, %result32, %result64 : f16, f32, f64 } } // ----- module @test_module { - // CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32 - // CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64 - // CHECK-LABEL: func @arith_remf - func.func @arith_remf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = arith.remf %arg_f32, %arg_f32 : f32 - // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32 - %result64 = arith.remf %arg_f64, %arg_f64 : f64 - // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64 - func.return %result32, %result64 : f32, f64 + // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16 + // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64 + // CHECK-LABEL: func @math_casting + func.func @math_casting(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64, %arg_bf16 : bf16) -> (f16, f32, f64, bf16) { + %resultf16 = math.sin %arg_f16 : f16 + // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 + %resultf32 = math.sin %arg_f32 : f32 + // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 + %resultf64 = math.sin %arg_f64 : f64 + // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64 + %resultbf16 = math.sin %arg_bf16 : bf16 + // CHECK: llvm.fpext %{{.*}} : bf16 to f32 + // CHECK-NEXT: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 + // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to bf16 + func.return %resultf16, %resultf32, %resultf64, %resultbf16 : f16, f32, f64, bf16 } } -