diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 3ebee9baff31b..49eb575212ffc 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -27,11 +27,9 @@ using namespace mlir; namespace { -enum class AbsFn { abs, sqrt, rsqrt }; - -// Returns the absolute value, its square root or its reciprocal square root. +// Returns the absolute value or its square root. Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf, - ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) { + ImplicitLocOpBuilder &b, bool returnSqrt = false) { Value one = b.create(real.getType(), b.getFloatAttr(real.getType(), 1.0)); @@ -45,13 +43,7 @@ Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf, Value ratioSqPlusOne = b.create(ratioSq, one, fmf); Value result; - if (fn == AbsFn::rsqrt) { - ratioSqPlusOne = b.create(ratioSqPlusOne, fmf); - min = b.create(min, fmf); - max = b.create(max, fmf); - } - - if (fn == AbsFn::sqrt) { + if (returnSqrt) { Value quarter = b.create( real.getType(), b.getFloatAttr(real.getType(), 0.25)); // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily. @@ -871,7 +863,7 @@ struct SqrtOpConversion : public OpConversionPattern { Value real = b.create(elementType, adaptor.getComplex()); Value imag = b.create(elementType, adaptor.getComplex()); - Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt); + Value absSqrt = computeAbs(real, imag, fmf, b, /*returnSqrt=*/true); Value argArg = b.create(imag, real, fmf); Value sqrtArg = b.create(argArg, half, fmf); Value cos = b.create(sqrtArg, fmf); @@ -1155,74 +1147,18 @@ struct RsqrtOpConversion : public OpConversionPattern { LogicalResult matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter); auto type = cast(adaptor.getComplex().getType()); auto elementType = cast(type.getElementType()); - arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); - - auto cst = [&](APFloat v) { - return b.create(elementType, - b.getFloatAttr(elementType, v)); - }; - const auto &floatSemantics = elementType.getFloatSemantics(); - Value zero = cst(APFloat::getZero(floatSemantics)); - Value inf = cst(APFloat::getInf(floatSemantics)); - Value negHalf = b.create( - elementType, b.getFloatAttr(elementType, -0.5)); - Value nan = cst(APFloat::getNaN(floatSemantics)); - - Value real = b.create(elementType, adaptor.getComplex()); - Value imag = b.create(elementType, adaptor.getComplex()); - Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt); - Value argArg = b.create(imag, real, fmf); - Value rsqrtArg = b.create(argArg, negHalf, fmf); - Value cos = b.create(rsqrtArg, fmf); - Value sin = b.create(rsqrtArg, fmf); - - Value resultReal = b.create(absRsqrt, cos, fmf); - Value resultImag = b.create(absRsqrt, sin, fmf); - - if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | - arith::FastMathFlags::ninf)) { - Value negOne = b.create( - elementType, b.getFloatAttr(elementType, -1)); - - Value realSignedZero = b.create(zero, real, fmf); - Value imagSignedZero = b.create(zero, imag, fmf); - Value negImagSignedZero = - b.create(negOne, imagSignedZero, fmf); + Value c = builder.create( + elementType, builder.getFloatAttr(elementType, -0.5)); + Value d = builder.create( + elementType, builder.getFloatAttr(elementType, 0)); - Value absReal = b.create(real, fmf); - Value absImag = b.create(imag, fmf); - - Value absImagIsInf = - b.create(arith::CmpFPredicate::OEQ, absImag, inf, fmf); - Value realIsNan = - b.create(arith::CmpFPredicate::UNO, real, real, fmf); - Value realIsInf = - b.create(arith::CmpFPredicate::OEQ, absReal, inf, fmf); - Value inIsNanInf = b.create(absImagIsInf, realIsNan); - - Value resultIsZero = b.create(inIsNanInf, realIsInf); - - resultReal = - b.create(resultIsZero, realSignedZero, resultReal); - resultImag = b.create(resultIsZero, negImagSignedZero, - resultImag); - } - - Value isRealZero = - b.create(arith::CmpFPredicate::OEQ, real, zero, fmf); - Value isImagZero = - b.create(arith::CmpFPredicate::OEQ, imag, zero, fmf); - Value isZero = b.create(isRealZero, isImagZero); - - resultReal = b.create(isZero, inf, resultReal); - resultImag = b.create(isZero, nan, resultImag); - - rewriter.replaceOpWithNewOp(op, type, resultReal, - resultImag); + rewriter.replaceOp(op, + {powOpConversionImpl(builder, type, adaptor.getComplex(), + c, d, op.getFastmath())}); return success(); } }; diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index 8b4ea9777f797..e0e7cdadd317d 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -837,21 +837,6 @@ func.func @complex_rsqrt(%arg: complex) -> complex { return %rsqrt : complex } -// CHECK-COUNT-5: arith.select -// CHECK-NOT: arith.select - -// ----- - -// CHECK-LABEL: func @complex_rsqrt_nnan_ninf -// CHECK-SAME: %[[ARG:.*]]: complex -func.func @complex_rsqrt_nnan_ninf(%arg: complex) -> complex { - %sqrt = complex.rsqrt %arg fastmath : complex - return %sqrt : complex -} - -// CHECK-COUNT-3: arith.select -// CHECK-NOT: arith.select - // ----- // CHECK-LABEL: func.func @complex_angle @@ -2118,4 +2103,4 @@ func.func @complex_tanh_with_fmf(%arg: complex) -> complex { // CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] fastmath : f32 -// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex +// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex \ No newline at end of file