diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 17f64f1b65b7c..e18b27702e4ec 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -918,6 +918,7 @@ struct SignOpConversion : public OpConversionPattern { auto type = cast(adaptor.getComplex().getType()); auto elementType = cast(type.getElementType()); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value real = b.create(elementType, adaptor.getComplex()); Value imag = b.create(elementType, adaptor.getComplex()); @@ -928,9 +929,9 @@ struct SignOpConversion : public OpConversionPattern { Value imagIsZero = b.create(arith::CmpFPredicate::OEQ, imag, zero); Value isZero = b.create(realIsZero, imagIsZero); - auto abs = b.create(elementType, adaptor.getComplex()); - Value realSign = b.create(real, abs); - Value imagSign = b.create(imag, abs); + auto abs = b.create(elementType, adaptor.getComplex(), fmf); + Value realSign = b.create(real, abs, fmf); + Value imagSign = b.create(imag, abs, fmf); Value sign = b.create(type, realSign, imagSign); rewriter.replaceOpWithNewOp(op, isZero, adaptor.getComplex(), sign); diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index bac94aae6b746..112918b20d330 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -1880,3 +1880,46 @@ func.func @complex_sin_with_fmf(%arg: complex) -> complex { // CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_DIFF]], %[[COS]] fastmath // CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] + +// ----- + +// CHECK-LABEL: func @complex_sign_with_fmf +// CHECK-SAME: %[[ARG:.*]]: complex +func.func @complex_sign_with_fmf(%arg: complex) -> complex { + %sign = complex.sign %arg fastmath : complex + return %sign : complex +} + +// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex +// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 +// CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 +// CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1 +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex +// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex +// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL2]], %[[ZERO]] : f32 +// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG2]], %[[ZERO]] : f32 +// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG2]], %[[REAL2]] fastmath : f32 +// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath : f32 +// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath : f32 +// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath : f32 +// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL2]] fastmath : f32 +// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath : f32 +// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL2]], %[[IMAG2]] fastmath : f32 +// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath : f32 +// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath : f32 +// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath : f32 +// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG2]] fastmath : f32 +// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] fastmath : f32 +// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL2]], %[[IMAG2]] : f32 +// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 +// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32 +// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32 +// CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[NORM]] fastmath : f32 +// CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[NORM]] fastmath : f32 +// CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex +// CHECK: %[[RESULT:.*]] = arith.select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex +// CHECK: return %[[RESULT]] : complex