Skip to content

Commit be5b666

Browse files
authored
[mlir][complex] Support fastmath in the binary op conversion. (#65702)
Complex dialect arithmetic operations are now able to recognize the given fastmath flags. This PR lets the conversion from complex to standard keep the fastmath flag passed to arith dialect ops. See: https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981
1 parent 8f2ffb1 commit be5b666

File tree

2 files changed

+39
-4
lines changed

2 files changed

+39
-4
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,15 +137,16 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
137137
auto type = cast<ComplexType>(adaptor.getLhs().getType());
138138
auto elementType = cast<FloatType>(type.getElementType());
139139
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
140+
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
140141

141142
Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
142143
Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
143-
Value resultReal =
144-
b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
144+
Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs,
145+
fmf.getValue());
145146
Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
146147
Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
147-
Value resultImag =
148-
b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
148+
Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
149+
fmf.getValue());
149150
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
150151
resultImag);
151152
return success();

mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,3 +723,37 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
723723
// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] fastmath<nnan,contract> : f32
724724
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
725725
// CHECK: return %[[NORM]] : f32
726+
727+
// -----
728+
729+
// CHECK-LABEL: func @complex_add_with_fmf
730+
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
731+
func.func @complex_add_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
732+
%add = complex.add %lhs, %rhs fastmath<nnan,contract> : complex<f32>
733+
return %add : complex<f32>
734+
}
735+
// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex<f32>
736+
// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex<f32>
737+
// CHECK: %[[RESULT_REAL:.*]] = arith.addf %[[REAL_LHS]], %[[REAL_RHS]] fastmath<nnan,contract> : f32
738+
// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex<f32>
739+
// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex<f32>
740+
// CHECK: %[[RESULT_IMAG:.*]] = arith.addf %[[IMAG_LHS]], %[[IMAG_RHS]] fastmath<nnan,contract> : f32
741+
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
742+
// CHECK: return %[[RESULT]] : complex<f32>
743+
744+
// -----
745+
746+
// CHECK-LABEL: func @complex_sub_with_fmf
747+
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
748+
func.func @complex_sub_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
749+
%sub = complex.sub %lhs, %rhs fastmath<nnan,contract> : complex<f32>
750+
return %sub : complex<f32>
751+
}
752+
// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex<f32>
753+
// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex<f32>
754+
// CHECK: %[[RESULT_REAL:.*]] = arith.subf %[[REAL_LHS]], %[[REAL_RHS]] fastmath<nnan,contract> : f32
755+
// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex<f32>
756+
// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex<f32>
757+
// CHECK: %[[RESULT_IMAG:.*]] = arith.subf %[[IMAG_LHS]], %[[IMAG_RHS]] fastmath<nnan,contract> : f32
758+
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
759+
// CHECK: return %[[RESULT]] : complex<f32>

0 commit comments

Comments
 (0)