diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h index e2c513047c77a..24e6d9a8d98e0 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -31,6 +31,9 @@ void populateExpandTanPattern(RewritePatternSet &patterns); void populateExpandSinhPattern(RewritePatternSet &patterns); void populateExpandCoshPattern(RewritePatternSet &patterns); void populateExpandTanhPattern(RewritePatternSet &patterns); +void populateExpandAsinhPattern(RewritePatternSet &patterns); +void populateExpandAcoshPattern(RewritePatternSet &patterns); +void populateExpandAtanhPattern(RewritePatternSet &patterns); void populateExpandFmaFPattern(RewritePatternSet &patterns); void populateExpandFloorFPattern(RewritePatternSet &patterns); void populateExpandCeilFPattern(RewritePatternSet &patterns); diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index 42629e149e9ff..5ccf3b6d72a2c 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -73,14 +73,14 @@ static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) { ImplicitLocOpBuilder b(op->getLoc(), rewriter); Value operand = op.getOperand(); Type opType = operand.getType(); - Value exp = b.create(operand); - Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter); - Value nexp = b.create(one, exp); + Value exp = b.create(operand); + Value neg = b.create(operand); + Value nexp = b.create(neg); Value sub = b.create(exp, nexp); - Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter); - Value div = b.create(sub, two); - rewriter.replaceOp(op, div); + Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter); + Value res = b.create(sub, half); + rewriter.replaceOp(op, res); return success(); } @@ -89,14 +89,14 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) { ImplicitLocOpBuilder b(op->getLoc(), rewriter); Value operand = op.getOperand(); Type opType = operand.getType(); - Value exp = b.create(operand); - Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter); - Value nexp = b.create(one, exp); + Value exp = b.create(operand); + Value neg = b.create(operand); + Value nexp = b.create(neg); Value add = b.create(exp, nexp); - Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter); - Value div = b.create(add, two); - rewriter.replaceOp(op, div); + Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter); + Value res = b.create(add, half); + rewriter.replaceOp(op, res); return success(); } @@ -152,6 +152,57 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) { return success(); } +// asinh(float x) -> log(x + sqrt(x**2 + 1)) +static LogicalResult convertAsinhOp(math::AsinhOp op, + PatternRewriter &rewriter) { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + Value operand = op.getOperand(); + Type opType = operand.getType(); + + Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter); + Value fma = b.create(operand, operand, one); + Value sqrt = b.create(fma); + Value add = b.create(operand, sqrt); + Value res = b.create(add); + rewriter.replaceOp(op, res); + return success(); +} + +// acosh(float x) -> log(x + sqrt(x**2 - 1)) +static LogicalResult convertAcoshOp(math::AcoshOp op, + PatternRewriter &rewriter) { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + Value operand = op.getOperand(); + Type opType = operand.getType(); + + Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter); + Value fma = b.create(operand, operand, negOne); + Value sqrt = b.create(fma); + Value add = b.create(operand, sqrt); + Value res = b.create(add); + rewriter.replaceOp(op, res); + return success(); +} + +// atanh(float x) -> log((1 + x) / (1 - x)) / 2 +static LogicalResult convertAtanhOp(math::AtanhOp op, + PatternRewriter &rewriter) { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + Value operand = op.getOperand(); + Type opType = operand.getType(); + + Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter); + Value add = b.create(operand, one); + Value neg = b.create(operand); + Value sub = b.create(neg, one); + Value div = b.create(add, sub); + Value log = b.create(div); + Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter); + Value res = b.create(log, half); + rewriter.replaceOp(op, res); + return success(); +} + static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) { ImplicitLocOpBuilder b(op->getLoc(), rewriter); Value operandA = op.getOperand(0); @@ -584,6 +635,18 @@ void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) { patterns.add(convertTanhOp); } +void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) { + patterns.add(convertAsinhOp); +} + +void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) { + patterns.add(convertAcoshOp); +} + +void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) { + patterns.add(convertAtanhOp); +} + void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) { patterns.add(convertFmaFOp); } diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp index 97600ad1ebe7a..da48ccb6e5e08 100644 --- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp +++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp @@ -42,6 +42,9 @@ void TestExpandMathPass::runOnOperation() { populateExpandSinhPattern(patterns); populateExpandCoshPattern(patterns); populateExpandTanhPattern(patterns); + populateExpandAsinhPattern(patterns); + populateExpandAcoshPattern(patterns); + populateExpandAtanhPattern(patterns); populateExpandFmaFPattern(patterns); populateExpandFloorFPattern(patterns); populateExpandCeilFPattern(patterns); diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir index 340ef30bf59c2..2b72acde6a3bb 100644 --- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir +++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir @@ -717,6 +717,122 @@ func.func @tanh() { return } +// -------------------------------------------------------------------------- // +// Asinh. +// -------------------------------------------------------------------------- // + +func.func @asinh_f32(%a : f32) { + %r = math.asinh %a : f32 + vector.print %r : f32 + return +} + +func.func @asinh_3xf32(%a : vector<3xf32>) { + %r = math.asinh %a : vector<3xf32> + vector.print %r : vector<3xf32> + return +} + +func.func @asinh() { + // CHECK: 0 + %zero = arith.constant 0.0 : f32 + call @asinh_f32(%zero) : (f32) -> () + + // CHECK: 0.881374 + %cst1 = arith.constant 1.0 : f32 + call @asinh_f32(%cst1) : (f32) -> () + + // CHECK: -0.881374 + %cst2 = arith.constant -1.0 : f32 + call @asinh_f32(%cst2) : (f32) -> () + + // CHECK: 1.81845 + %cst3 = arith.constant 3.0 : f32 + call @asinh_f32(%cst3) : (f32) -> () + + // CHECK: 0.247466, 0.790169, 1.44364 + %vec_x = arith.constant dense<[0.25, 0.875, 2.0]> : vector<3xf32> + call @asinh_3xf32(%vec_x) : (vector<3xf32>) -> () + + return +} + +// -------------------------------------------------------------------------- // +// Acosh. +// -------------------------------------------------------------------------- // + +func.func @acosh_f32(%a : f32) { + %r = math.acosh %a : f32 + vector.print %r : f32 + return +} + +func.func @acosh_3xf32(%a : vector<3xf32>) { + %r = math.acosh %a : vector<3xf32> + vector.print %r : vector<3xf32> + return +} + +func.func @acosh() { + // CHECK: 0 + %zero = arith.constant 1.0 : f32 + call @acosh_f32(%zero) : (f32) -> () + + // CHECK: 1.31696 + %cst1 = arith.constant 2.0 : f32 + call @acosh_f32(%cst1) : (f32) -> () + + // CHECK: 2.99322 + %cst2 = arith.constant 10.0 : f32 + call @acosh_f32(%cst2) : (f32) -> () + + // CHECK: 0.962424, 1.76275, 2.47789 + %vec_x = arith.constant dense<[1.5, 3.0, 6.0]> : vector<3xf32> + call @acosh_3xf32(%vec_x) : (vector<3xf32>) -> () + + return +} + +// -------------------------------------------------------------------------- // +// Atanh. +// -------------------------------------------------------------------------- // + +func.func @atanh_f32(%a : f32) { + %r = math.atanh %a : f32 + vector.print %r : f32 + return +} + +func.func @atanh_3xf32(%a : vector<3xf32>) { + %r = math.atanh %a : vector<3xf32> + vector.print %r : vector<3xf32> + return +} + +func.func @atanh() { + // CHECK: 0 + %zero = arith.constant 0.0 : f32 + call @atanh_f32(%zero) : (f32) -> () + + // CHECK: 0.549306 + %cst1 = arith.constant 0.5 : f32 + call @atanh_f32(%cst1) : (f32) -> () + + // CHECK: -0.549306 + %cst2 = arith.constant -0.5 : f32 + call @atanh_f32(%cst2) : (f32) -> () + + // CHECK: inf + %cst3 = arith.constant 1.0 : f32 + call @atanh_f32(%cst3) : (f32) -> () + + // CHECK: 0.255413, 0.394229, 2.99448 + %vec_x = arith.constant dense<[0.25, 0.375, 0.995]> : vector<3xf32> + call @atanh_3xf32(%vec_x) : (vector<3xf32>) -> () + + return +} + func.func @main() { call @exp2f() : () -> () call @roundf() : () -> () @@ -725,5 +841,8 @@ func.func @main() { call @sinh() : () -> () call @cosh() : () -> () call @tanh() : () -> () + call @asinh() : () -> () + call @acosh() : () -> () + call @atanh() : () -> () return }