diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 4c9dad9e2c173..6e0eddab0e3aa 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -26,29 +26,59 @@ namespace mlir { using namespace mlir; namespace { +// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780. struct AbsOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto type = op.getType(); + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); - Value real = - rewriter.create(loc, type, adaptor.getComplex()); - Value imag = - rewriter.create(loc, type, adaptor.getComplex()); - Value realSqr = - rewriter.create(loc, real, real, fmf.getValue()); - Value imagSqr = - rewriter.create(loc, imag, imag, fmf.getValue()); - Value sqNorm = - rewriter.create(loc, realSqr, imagSqr, fmf.getValue()); - - rewriter.replaceOpWithNewOp(op, sqNorm); + Type elementType = op.getType(); + Value arg = adaptor.getComplex(); + + Value zero = + b.create(elementType, b.getZeroAttr(elementType)); + Value one = b.create(elementType, + b.getFloatAttr(elementType, 1.0)); + + Value real = b.create(elementType, arg); + Value imag = b.create(elementType, arg); + + Value realIsZero = + b.create(arith::CmpFPredicate::OEQ, real, zero); + Value imagIsZero = + b.create(arith::CmpFPredicate::OEQ, imag, zero); + + // Real > Imag + Value imagDivReal = b.create(imag, real, fmf.getValue()); + Value imagSq = + b.create(imagDivReal, imagDivReal, fmf.getValue()); + Value imagSqPlusOne = b.create(imagSq, one, fmf.getValue()); + Value imagSqrt = b.create(imagSqPlusOne, fmf.getValue()); + Value realAbs = b.create(real, fmf.getValue()); + Value absImag = b.create(imagSqrt, realAbs, fmf.getValue()); + + // Real <= Imag + Value realDivImag = b.create(real, imag, fmf.getValue()); + Value realSq = + b.create(realDivImag, realDivImag, fmf.getValue()); + Value realSqPlusOne = b.create(realSq, one, fmf.getValue()); + Value realSqrt = b.create(realSqPlusOne, fmf.getValue()); + Value imagAbs = b.create(imag, fmf.getValue()); + Value absReal = b.create(realSqrt, imagAbs, fmf.getValue()); + + rewriter.replaceOpWithNewOp( + op, realIsZero, imag, + b.create( + imagIsZero, real, + b.create( + b.create(arith::CmpFPredicate::OGT, real, imag), + absImag, absReal))); + return success(); } }; diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index 8fa29ea43854a..de519f6a706ab 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -7,13 +7,30 @@ func.func @complex_abs(%arg: complex) -> f32 { %abs = complex.abs %arg: complex return %abs : f32 } + +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex -// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32 -// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32 -// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] : f32 -// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 -// CHECK: return %[[NORM]] : f32 +// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 +// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 +// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32 +// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32 +// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32 +// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32 +// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] : f32 +// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32 +// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32 +// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32 +// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32 +// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] : f32 +// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32 +// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 +// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32 +// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32 +// CHECK: return %[[ABS3]] : f32 // ----- @@ -241,12 +258,28 @@ func.func @complex_log(%arg: complex) -> complex { %log = complex.log %arg: complex return %log : complex } +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex -// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32 -// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32 -// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32 -// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 +// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 +// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 +// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32 +// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32 +// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32 +// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32 +// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] : f32 +// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32 +// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32 +// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32 +// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32 +// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] : f32 +// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32 +// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 +// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32 +// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32 // CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex @@ -469,12 +502,28 @@ func.func @complex_sign(%arg: complex) -> complex { // CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 // CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 // CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1 +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex -// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL2]], %[[REAL2]] : f32 -// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG2]], %[[IMAG2]] : f32 -// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32 -// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 +// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL2]], %[[ZERO]] : f32 +// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG2]], %[[ZERO]] : f32 +// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG2]], %[[REAL2]] : f32 +// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32 +// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32 +// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32 +// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL2]] : f32 +// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32 +// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL2]], %[[IMAG2]] : f32 +// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32 +// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32 +// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32 +// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG2]] : f32 +// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32 +// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL2]], %[[IMAG2]] : f32 +// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 +// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL2]], %[[ABS1]] : f32 +// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG2]], %[[ABS2]] : f32 // CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[NORM]] : f32 // CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[NORM]] : f32 // CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex @@ -716,13 +765,29 @@ func.func @complex_abs_with_fmf(%arg: complex) -> f32 { %abs = complex.abs %arg fastmath : complex return %abs : f32 } +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex -// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath : f32 -// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath : f32 -// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] fastmath : f32 -// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 -// CHECK: return %[[NORM]] : f32 +// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 +// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 +// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath : f32 +// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath : f32 +// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath : f32 +// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath : f32 +// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] fastmath : f32 +// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath : f32 +// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath : f32 +// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath : f32 +// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath : f32 +// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath : f32 +// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] fastmath : f32 +// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] fastmath : f32 +// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 +// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32 +// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32 +// CHECK: return %[[ABS3]] : f32 // ----- @@ -807,12 +872,28 @@ func.func @complex_log_with_fmf(%arg: complex) -> complex { %log = complex.log %arg fastmath : complex return %log : complex } +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex -// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath : f32 -// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath : f32 -// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] fastmath : f32 -// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 +// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 +// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 +// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath : f32 +// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath : f32 +// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath : f32 +// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath : f32 +// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] fastmath : f32 +// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath : f32 +// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath : f32 +// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath : f32 +// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath : f32 +// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath : f32 +// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] fastmath : f32 +// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] fastmath : f32 +// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 +// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32 +// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32 // CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] fastmath : f32 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex diff --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir index 9983dd46f0943..c01f3698c7386 100644 --- a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir +++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir @@ -6,12 +6,31 @@ func.func @complex_abs(%arg: complex) -> f32 { %abs = complex.abs %arg: complex return %abs : f32 } +// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 +// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 // CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]] // CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]] -// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]] : f32 -// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]] : f32 -// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]] : f32 -// CHECK: %[[NORM:.*]] = llvm.intr.sqrt(%[[SQ_NORM]]) : (f32) -> f32 +// CHECK: %[[REAL_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[REAL]], %[[ZERO]] : f32 +// CHECK: %[[IMAG_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[IMAG]], %[[ZERO]] : f32 + +// CHECK: %[[IMAG_DIV_REAL:.*]] = llvm.fdiv %[[IMAG]], %[[REAL]] : f32 +// CHECK: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32 +// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = llvm.fadd %[[IMAG_SQ]], %[[ONE]] : f32 +// CHECK: %[[IMAG_SQRT:.*]] = llvm.intr.sqrt(%[[IMAG_SQ_PLUS_ONE]]) : (f32) -> f32 +// CHECK: %[[REAL_ABS:.*]] = llvm.intr.fabs(%[[REAL]]) : (f32) -> f32 +// CHECK: %[[ABS_IMAG:.*]] = llvm.fmul %[[IMAG_SQRT]], %[[REAL_ABS]] : f32 + +// CHECK: %[[REAL_DIV_IMAG:.*]] = llvm.fdiv %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32 +// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = llvm.fadd %[[REAL_SQ]], %[[ONE]] : f32 +// CHECK: %[[REAL_SQRT:.*]] = llvm.intr.sqrt(%[[REAL_SQ_PLUS_ONE]]) : (f32) -> f32 +// CHECK: %[[IMAG_ABS:.*]] = llvm.intr.fabs(%[[IMAG]]) : (f32) -> f32 +// CHECK: %[[ABS_REAL:.*]] = llvm.fmul %[[REAL_SQRT]], %[[IMAG_ABS]] : f32 + +// CHECK: %[[REAL_GT_IMAG:.*]] = llvm.fcmp "ogt" %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[ABS1:.*]] = llvm.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : i1, f32 +// CHECK: %[[ABS2:.*]] = llvm.select %[[IMAG_IS_ZERO]], %[[REAL]], %[[ABS1]] : i1, f32 +// CHECK: %[[NORM:.*]] = llvm.select %[[REAL_IS_ZERO]], %[[IMAG]], %[[ABS2]] : i1, f32 // CHECK: llvm.return %[[NORM]] : f32 // CHECK-LABEL: llvm.func @complex_eq diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir index 349b92a7aefa2..efdf057f08df2 100644 --- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir +++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir @@ -106,6 +106,27 @@ func.func @angle(%arg: complex) -> f32 { func.return %angle : f32 } +func.func @test_element_f64(%input: tensor>, + %func: (complex) -> f64) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %size = tensor.dim %input, %c0: tensor> + + scf.for %i = %c0 to %size step %c1 { + %elem = tensor.extract %input[%i]: tensor> + + %val = func.call_indirect %func(%elem) : (complex) -> f64 + vector.print %val : f64 + scf.yield + } + func.return +} + +func.func @abs(%arg: complex) -> f64 { + %abs = complex.abs %arg : complex + func.return %abs : f64 +} + func.func @entry() { // complex.sqrt test %sqrt_test = arith.constant dense<[ @@ -300,5 +321,32 @@ func.func @entry() { call @test_element(%angle_test_cast, %angle_func) : (tensor>, (complex) -> f32) -> () + // complex.abs test + %abs_test = arith.constant dense<[ + (1.0, 1.0), + // CHECK: 1.414 + (1.0e300, 1.0e300), + // CHECK-NEXT: 1.41421e+300 + (1.0e-300, 1.0e-300), + // CHECK-NEXT: 1.41421e-300 + (5.0, 0.0), + // CHECK-NEXT: 5 + (0.0, 6.0), + // CHECK-NEXT: 6 + (7.0, 8.0), + // CHECK-NEXT: 10.6301 + (-1.0, -1.0), + // CHECK-NEXT: 1.414 + (-1.0e300, -1.0e300) + // CHECK-NEXT: 1.41421e+300 + ]> : tensor<8xcomplex> + %abs_test_cast = tensor.cast %abs_test + : tensor<8xcomplex> to tensor> + + %abs_func = func.constant @abs : (complex) -> f64 + + call @test_element_f64(%abs_test_cast, %abs_func) + : (tensor>, (complex) -> f64) -> () + func.return }