From 4025daca65ac2ee4527cd4c4925486145693ed9e Mon Sep 17 00:00:00 2001 From: Frederik Harwath Date: Mon, 4 Dec 2023 23:48:31 -0800 Subject: [PATCH] Implement acos operator in MLIR Math Dialect Required for torch-mlir. Cf. llvm/torch-mlir#2604 "Implement torch.aten.acos". --- mlir/include/mlir/Dialect/Math/IR/MathOps.td | 29 ++++++++++++++ mlir/lib/Conversion/MathToLibm/MathToLibm.cpp | 1 + mlir/lib/Dialect/Math/IR/MathOps.cpp | 18 +++++++++ .../MathToLibm/convert-to-libm.mlir | 39 +++++++++++++++++++ 4 files changed, 87 insertions(+) diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td index f8e9fd601304b..9742d3d936dff 100644 --- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -300,6 +300,35 @@ def Math_CosOp : Math_FloatUnaryOp<"cos"> { let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// AcosOp +//===----------------------------------------------------------------------===// + +def Math_AcosOp : Math_FloatUnaryOp<"acos"> { + let summary = "arcus cosine of the specified value"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `math.acos` ssa-use `:` type + ``` + + The `acos` operation computes the arcus cosine of a given value. It takes one + operand of floating point type (i.e., scalar, tensor or vector) and returns one + result of the same type. It has no standard attributes. + + Example: + + ```mlir + // Scalar arcus cosine value. + %a = math.acos %b : f64 + ``` + }]; + let hasFolder = 1; +} + + + //===----------------------------------------------------------------------===// // SinOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp index 103c1fb8c3822..27c2cb9352071 100644 --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -162,6 +162,7 @@ ScalarOpToLibmCall::matchAndRewrite(Op op, void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) { MLIRContext *ctx = patterns.getContext(); + populatePatternsForOp(patterns, ctx, "acosf", "acos"); populatePatternsForOp(patterns, ctx, "atan2f", "atan2"); populatePatternsForOp(patterns, ctx, "atanf", "atan"); populatePatternsForOp(patterns, ctx, "cbrtf", "cbrt"); diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp index 28d1c062f235e..066a21c76f7d1 100644 --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -41,6 +41,24 @@ OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) { [](const APInt &a) { return a.abs(); }); } +//===----------------------------------------------------------------------===// +// AcosOp folder +//===----------------------------------------------------------------------===// + +OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) { + return constFoldUnaryOpConditional( + adaptor.getOperands(), [](const APFloat &a) -> std::optional { + switch (a.getSizeInBits(a.getSemantics())) { + case 64: + return APFloat(acos(a.convertToDouble())); + case 32: + return APFloat(acosf(a.convertToFloat())); + default: + return {}; + } + }); +} + //===----------------------------------------------------------------------===// // AtanOp folder //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir index 4837a0cce634a..f0c4512cbfdcc 100644 --- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir +++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir @@ -1,5 +1,7 @@ // RUN: mlir-opt %s -convert-math-to-libm -canonicalize | FileCheck %s +// CHECK-DAG: @acos(f64) -> f64 attributes {llvm.readnone} +// CHECK-DAG: @acosf(f32) -> f32 attributes {llvm.readnone} // CHECK-DAG: @atan(f64) -> f64 attributes {llvm.readnone} // CHECK-DAG: @atanf(f32) -> f32 attributes {llvm.readnone} // CHECK-DAG: @erf(f64) -> f64 attributes {llvm.readnone} @@ -29,6 +31,43 @@ // CHECK-DAG: @ceil(f64) -> f64 attributes {llvm.readnone} // CHECK-DAG: @ceilf(f32) -> f32 attributes {llvm.readnone} +// CHECK-LABEL: func @acos_caller +// CHECK-SAME: %[[FLOAT:.*]]: f32 +// CHECK-SAME: %[[DOUBLE:.*]]: f64 +func.func @acos_caller(%float: f32, %double: f64) -> (f32, f64) { + // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @acosf(%[[FLOAT]]) : (f32) -> f32 + %float_result = math.acos %float : f32 + // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @acos(%[[DOUBLE]]) : (f64) -> f64 + %double_result = math.acos %double : f64 + // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] + return %float_result, %double_result : f32, f64 +} + +// CHECK-LABEL: func @acos_vec_caller( +// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) { +// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32> +// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64> +// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32> +// CHECK: %[[OUT0_F32:.*]] = call @acosf(%[[IN0_F32]]) : (f32) -> f32 +// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32> +// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32> +// CHECK: %[[OUT1_F32:.*]] = call @acosf(%[[IN1_F32]]) : (f32) -> f32 +// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32> +// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64> +// CHECK: %[[OUT0_F64:.*]] = call @acos(%[[IN0_F64]]) : (f64) -> f64 +// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64> +// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64> +// CHECK: %[[OUT1_F64:.*]] = call @acos(%[[IN1_F64]]) : (f64) -> f64 +// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64> +// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64> +// CHECK: } +func.func @acos_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) { + %float_result = math.acos %float : vector<2xf32> + %double_result = math.acos %double : vector<2xf64> + return %float_result, %double_result : vector<2xf32>, vector<2xf64> +} + // CHECK-LABEL: func @atan_caller // CHECK-SAME: %[[FLOAT:.*]]: f32 // CHECK-SAME: %[[DOUBLE:.*]]: f64