Skip to content

Commit 279a659

Browse files
authored
[mlir][math] lower rsqrt to sqrt + fdiv (#91344)
This commit creates an expansion pattern to lower math.rsqrt(x) into fdiv(1, sqrt(x)).
1 parent ae2a18d commit 279a659

File tree

5 files changed

+134
-0
lines changed

5 files changed

+134
-0
lines changed

mlir/include/mlir/Dialect/Math/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ void populateExpandPowFPattern(RewritePatternSet &patterns);
4242
void populateExpandFPowIPattern(RewritePatternSet &patterns);
4343
void populateExpandRoundFPattern(RewritePatternSet &patterns);
4444
void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
45+
void populateExpandRsqrtPattern(RewritePatternSet &patterns);
4546
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
4647

4748
struct MathPolynomialApproximationOptions {

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,23 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
615615
return success();
616616
}
617617

618+
// Convert `math.rsqrt` into `arith.divf` + `math.sqrt`
619+
static LogicalResult convertRsqrtOp(math::RsqrtOp op,
620+
PatternRewriter &rewriter) {
621+
622+
auto operand = op.getOperand();
623+
auto operandTy = operand.getType();
624+
auto eTy = getElementTypeOrSelf(operandTy);
625+
if (!isa<FloatType>(eTy))
626+
return failure();
627+
628+
Location loc = op->getLoc();
629+
auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
630+
auto sqrtOp = rewriter.create<math::SqrtOp>(loc, operand);
631+
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, constOneFloat, sqrtOp);
632+
return success();
633+
}
634+
618635
// Adds the `convertCtlzOp` rewrite to `patterns`.
void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
  patterns.add(convertCtlzOp);
}
@@ -678,3 +695,7 @@ void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
678695
// Adds the `convertRoundEvenOp` rewrite to `patterns`.
void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
  patterns.add(convertRoundEvenOp);
}
698+
699+
// Adds the `convertRsqrtOp` rewrite (`math.rsqrt` -> `arith.divf` +
// `math.sqrt`) to `patterns`.
void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) {
  patterns.add(convertRsqrtOp);
}

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,3 +658,73 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
658658
// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
659659
// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
660660
// CHECK: return %[[SEL]] : f32
661+
662+
// -----
663+
664+
// Scalar f16: `math.rsqrt` expands to `1.0 / math.sqrt`.
// CHECK-LABEL: func.func @rsqrt16
// CHECK-SAME: (%[[ARG:.*]]: f16)
// CHECK-SAME: -> f16
// CHECK-DAG: %[[CST:.*]] = arith.constant 1.000000e+00 : f16
// CHECK-DAG: %[[SQRT:.*]] = math.sqrt %[[ARG]] : f16
// CHECK-DAG: %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : f16
// CHECK: return %[[DIV]] : f16
func.func @rsqrt16(%float: f16) -> (f16) {
  %float_result = math.rsqrt %float : f16
  return %float_result : f16
}
675+
676+
// -----
677+
678+
// Scalar f32: `math.rsqrt` expands to `1.0 / math.sqrt`.
// CHECK-LABEL: func.func @rsqrt32
// CHECK-SAME: (%[[ARG:.*]]: f32)
// CHECK-SAME: -> f32
// CHECK-DAG: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[SQRT:.*]] = math.sqrt %[[ARG]] : f32
// CHECK-DAG: %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : f32
// CHECK: return %[[DIV]] : f32
func.func @rsqrt32(%float: f32) -> (f32) {
  %float_result = math.rsqrt %float : f32
  return %float_result : f32
}
689+
690+
// -----
691+
692+
// Scalar f64: `math.rsqrt` expands to `1.0 / math.sqrt`.
// CHECK-LABEL: func.func @rsqrt64
// CHECK-SAME: (%[[ARG:.*]]: f64)
// CHECK-SAME: -> f64
// CHECK-DAG: %[[CST:.*]] = arith.constant 1.000000e+00 : f64
// CHECK-DAG: %[[SQRT:.*]] = math.sqrt %[[ARG]] : f64
// CHECK-DAG: %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : f64
// CHECK: return %[[DIV]] : f64
func.func @rsqrt64(%float: f64) -> (f64) {
  %float_result = math.rsqrt %float : f64
  return %float_result : f64
}
703+
704+
// -----
705+
706+
// Vector operand: the constant `1.0` becomes a dense splat of the same type.
// CHECK-LABEL: func.func @rsqrt_vec
// CHECK-SAME: (%[[ARG:.*]]: vector<5xf32>)
// CHECK-SAME: -> vector<5xf32>
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<5xf32>
// CHECK-DAG: %[[SQRT:.*]] = math.sqrt %[[ARG]] : vector<5xf32>
// CHECK-DAG: %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : vector<5xf32>
// CHECK: return %[[DIV]] : vector<5xf32>
func.func @rsqrt_vec(%arg: vector<5xf32>) -> (vector<5xf32>) {
  %res = math.rsqrt %arg : vector<5xf32>
  return %res : vector<5xf32>
}
717+
718+
// -----
719+
720+
// Statically shaped tensor operand: the constant `1.0` becomes a dense splat
// of the same type.
// CHECK-LABEL: func.func @rsqrt_tns
// CHECK-SAME: (%[[ARG:.*]]: tensor<5x8xf32>)
// CHECK-SAME: -> tensor<5x8xf32>
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<5x8xf32>
// CHECK-DAG: %[[SQRT:.*]] = math.sqrt %[[ARG]] : tensor<5x8xf32>
// CHECK-DAG: %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : tensor<5x8xf32>
// CHECK: return %[[DIV]] : tensor<5x8xf32>
func.func @rsqrt_tns(%arg: tensor<5x8xf32>) -> (tensor<5x8xf32>) {
  %res = math.rsqrt %arg : tensor<5x8xf32>
  return %res : tensor<5x8xf32>
}

mlir/test/lib/Dialect/Math/TestExpandMath.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ void TestExpandMathPass::runOnOperation() {
5252
populateExpandFPowIPattern(patterns);
5353
populateExpandRoundFPattern(patterns);
5454
populateExpandRoundEvenPattern(patterns);
55+
populateExpandRsqrtPattern(patterns);
5556
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
5657
}
5758

mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,46 @@ func.func @atanh() {
833833
return
834834
}
835835

836+
// -------------------------------------------------------------------------- //
837+
// Rsqrt.
838+
// -------------------------------------------------------------------------- //
839+
840+
// Computes `math.rsqrt` of a scalar f32 and prints the result.
func.func @rsqrt_f32(%x : f32) {
  %y = math.rsqrt %x : f32
  vector.print %y : f32
  return
}
845+
846+
// Computes `math.rsqrt` of a 3-element f32 vector and prints the result.
func.func @rsqrt_3xf32(%x : vector<3xf32>) {
  %y = math.rsqrt %x : vector<3xf32>
  vector.print %y : vector<3xf32>
  return
}
851+
852+
// Runs the rsqrt helpers on a few scalar inputs and one vector input; the
// printed values are matched by the CHECK lines below.
func.func @rsqrt() {
  // CHECK: 1
  %one = arith.constant 1.0 : f32
  call @rsqrt_f32(%one) : (f32) -> ()

  // CHECK: 0.707107
  %cst1 = arith.constant 2.0 : f32
  call @rsqrt_f32(%cst1) : (f32) -> ()

  // CHECK: inf
  %cst2 = arith.constant 0.0 : f32
  call @rsqrt_f32(%cst2) : (f32) -> ()

  // The sign of the NaN produced for a negative input is target dependent,
  // so accept both spellings.
  // CHECK: {{-?}}nan
  %cst3 = arith.constant -1.0 : f32
  call @rsqrt_f32(%cst3) : (f32) -> ()

  // CHECK: 0.5, 1.41421, 0.57735
  %vec_x = arith.constant dense<[4.0, 0.5, 3.0]> : vector<3xf32>
  call @rsqrt_3xf32(%vec_x) : (vector<3xf32>) -> ()

  return
}
875+
836876
func.func @main() {
837877
call @exp2f() : () -> ()
838878
call @roundf() : () -> ()
@@ -844,5 +884,6 @@ func.func @main() {
844884
call @asinh() : () -> ()
845885
call @acosh() : () -> ()
846886
call @atanh() : () -> ()
887+
call @rsqrt() : () -> ()
847888
return
848889
}

0 commit comments

Comments
 (0)