Skip to content

Commit 221f438

Browse files
[flang][OpenMP] Add support for complex reductions (#87488)
This adds support for complex types to OpenMP reductions. Note that more work is needed to give decent error messages when a complex variable is used with reduction operators that would need client-supplied functions (e.g. MAX or MIN). At present these cases do fail, but with a message that is not very user-friendly.
1 parent 364028a commit 221f438

File tree

4 files changed

+137
-6
lines changed

4 files changed

+137
-6
lines changed

flang/lib/Lower/OpenMP/ReductionProcessor.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
#include "ReductionProcessor.h"
1414

1515
#include "flang/Lower/AbstractConverter.h"
16+
#include "flang/Lower/ConvertType.h"
1617
#include "flang/Lower/SymbolMap.h"
18+
#include "flang/Optimizer/Builder/Complex.h"
1719
#include "flang/Optimizer/Builder/HLFIRTools.h"
1820
#include "flang/Optimizer/Builder/Todo.h"
1921
#include "flang/Optimizer/Dialect/FIRType.h"
@@ -131,7 +133,7 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
131133
fir::FirOpBuilder &builder) {
132134
type = fir::unwrapRefType(type);
133135
if (!fir::isa_integer(type) && !fir::isa_real(type) &&
134-
!mlir::isa<fir::LogicalType>(type))
136+
!fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type))
135137
TODO(loc, "Reduction of some types is not supported");
136138
switch (redId) {
137139
case ReductionIdentifier::MAX: {
@@ -175,6 +177,16 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
175177
case ReductionIdentifier::OR:
176178
case ReductionIdentifier::EQV:
177179
case ReductionIdentifier::NEQV:
180+
if (auto cplxTy = mlir::dyn_cast<fir::ComplexType>(type)) {
181+
mlir::Type realTy =
182+
Fortran::lower::convertReal(builder.getContext(), cplxTy.getFKind());
183+
mlir::Value initRe = builder.createRealConstant(
184+
loc, realTy, getOperationIdentity(redId, loc));
185+
mlir::Value initIm = builder.createRealConstant(loc, realTy, 0);
186+
187+
return fir::factory::Complex{builder, loc}.createComplex(type, initRe,
188+
initIm);
189+
}
178190
if (type.isa<mlir::FloatType>())
179191
return builder.create<mlir::arith::ConstantOp>(
180192
loc, type,
@@ -229,13 +241,13 @@ mlir::Value ReductionProcessor::createScalarCombiner(
229241
break;
230242
case ReductionIdentifier::ADD:
231243
reductionOp =
232-
getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
233-
builder, type, loc, op1, op2);
244+
getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
245+
fir::AddcOp>(builder, type, loc, op1, op2);
234246
break;
235247
case ReductionIdentifier::MULTIPLY:
236248
reductionOp =
237-
getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
238-
builder, type, loc, op1, op2);
249+
getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
250+
fir::MulcOp>(builder, type, loc, op1, op2);
239251
break;
240252
case ReductionIdentifier::AND: {
241253
mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);

flang/lib/Lower/OpenMP/ReductionProcessor.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ class ReductionProcessor {
9797
fir::FirOpBuilder &builder);
9898

9999
template <typename FloatOp, typename IntegerOp>
100+
static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
101+
mlir::Type type, mlir::Location loc,
102+
mlir::Value op1, mlir::Value op2);
103+
template <typename FloatOp, typename IntegerOp, typename ComplexOp>
100104
static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
101105
mlir::Type type, mlir::Location loc,
102106
mlir::Value op1, mlir::Value op2);
@@ -136,12 +140,27 @@ ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
136140
mlir::Value op1, mlir::Value op2) {
137141
type = fir::unwrapRefType(type);
138142
assert(type.isIntOrIndexOrFloat() &&
139-
"only integer and float types are currently supported");
143+
"only integer, float and complex types are currently supported");
140144
if (type.isIntOrIndex())
141145
return builder.create<IntegerOp>(loc, op1, op2);
142146
return builder.create<FloatOp>(loc, op1, op2);
143147
}
144148

149+
template <typename FloatOp, typename IntegerOp, typename ComplexOp>
150+
mlir::Value
151+
ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
152+
mlir::Type type, mlir::Location loc,
153+
mlir::Value op1, mlir::Value op2) {
154+
assert(type.isIntOrIndexOrFloat() ||
155+
fir::isa_complex(type) &&
156+
"only integer, float and complex types are currently supported");
157+
if (type.isIntOrIndex())
158+
return builder.create<IntegerOp>(loc, op1, op2);
159+
if (fir::isa_real(type))
160+
return builder.create<FloatOp>(loc, op1, op2);
161+
return builder.create<ComplexOp>(loc, op1, op2);
162+
}
163+
145164
} // namespace omp
146165
} // namespace lower
147166
} // namespace Fortran
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
2+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
3+
4+
!CHECK-LABEL: omp.declare_reduction
5+
!CHECK-SAME: @[[RED_NAME:.*]] : !fir.complex<8> init {
6+
!CHECK: ^bb0(%{{.*}}: !fir.complex<8>):
7+
!CHECK: %[[C0_1:.*]] = arith.constant 1.000000e+00 : f64
8+
!CHECK: %[[C0_2:.*]] = arith.constant 0.000000e+00 : f64
9+
!CHECK: %[[UNDEF:.*]] = fir.undefined !fir.complex<8>
10+
!CHECK: %[[RES_1:.*]] = fir.insert_value %[[UNDEF]], %[[C0_1]], [0 : index]
11+
!CHECK: %[[RES_2:.*]] = fir.insert_value %[[RES_1]], %[[C0_2]], [1 : index]
12+
!CHECK: omp.yield(%[[RES_2]] : !fir.complex<8>)
13+
!CHECK: } combiner {
14+
!CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<8>, %[[ARG1:.*]]: !fir.complex<8>):
15+
!CHECK: %[[RES:.*]] = fir.mulc %[[ARG0]], %[[ARG1]] {{.*}}: !fir.complex<8>
16+
!CHECK: omp.yield(%[[RES]] : !fir.complex<8>)
17+
!CHECK: }
18+
19+
!CHECK-LABEL: func.func @_QPsimple_complex_mul
20+
!CHECK: %[[CREF:.*]] = fir.alloca !fir.complex<8> {bindc_name = "c", {{.*}}}
21+
!CHECK: %[[C_DECL:.*]]:2 = hlfir.declare %[[CREF]] {uniq_name = "_QFsimple_complex_mulEc"} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
22+
!CHECK: %[[C_START_RE:.*]] = arith.constant 0.000000e+00 : f64
23+
!CHECK: %[[C_START_IM:.*]] = arith.constant 0.000000e+00 : f64
24+
!CHECK: %[[UNDEF_1:.*]] = fir.undefined !fir.complex<8>
25+
!CHECK: %[[VAL_1:.*]] = fir.insert_value %[[UNDEF_1]], %[[C_START_RE]], [0 : index]
26+
!CHECK: %[[VAL_2:.*]] = fir.insert_value %[[VAL_1]], %[[C_START_IM]], [1 : index]
27+
!CHECK: hlfir.assign %[[VAL_2]] to %[[C_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
28+
!CHECK: omp.parallel reduction(@[[RED_NAME]] %[[C_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<!fir.complex<8>>) {
29+
!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
30+
!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<!fir.complex<8>>
31+
!CHECK: %[[C_INCR_RE:.*]] = arith.constant 1.000000e+00 : f64
32+
!CHECK: %[[C_INCR_IM:.*]] = arith.constant -2.000000e+00 : f64
33+
!CHECK: %[[UNDEF_2:.*]] = fir.undefined !fir.complex<8>
34+
!CHECK: %[[INCR_1:.*]] = fir.insert_value %[[UNDEF_2]], %[[C_INCR_RE]], [0 : index]
35+
!CHECK: %[[INCR_2:.*]] = fir.insert_value %[[INCR_1]], %[[C_INCR_IM]], [1 : index]
36+
!CHECK: %[[RES:.+]] = fir.mulc %[[LPRV]], %[[INCR_2]] {{.*}} : !fir.complex<8>
37+
!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
38+
!CHECK: omp.terminator
39+
!CHECK: }
40+
!CHECK: return
41+
! Test input: a complex(8) variable used in an OpenMP parallel region with a
! multiplicative (*) reduction. The expectations earlier in this file look for
! a reduction declaration initialised to (1, 0) and combined with fir.mulc.
subroutine simple_complex_mul
42+
complex(8) :: c
43+
c = 0
44+
45+
!$omp parallel reduction(*:c)
46+
! Multiply the reduction variable by the complex constant (1, -2).
c = c * cmplx(1, -2)
47+
!$omp end parallel
48+
49+
print *, c
50+
end subroutine
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
2+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
3+
4+
!CHECK-LABEL: omp.declare_reduction
5+
!CHECK-SAME: @[[RED_NAME:.*]] : !fir.complex<8> init {
6+
!CHECK: ^bb0(%{{.*}}: !fir.complex<8>):
7+
!CHECK: %[[C0_1:.*]] = arith.constant 0.000000e+00 : f64
8+
!CHECK: %[[C0_2:.*]] = arith.constant 0.000000e+00 : f64
9+
!CHECK: %[[UNDEF:.*]] = fir.undefined !fir.complex<8>
10+
!CHECK: %[[RES_1:.*]] = fir.insert_value %[[UNDEF]], %[[C0_1]], [0 : index]
11+
!CHECK: %[[RES_2:.*]] = fir.insert_value %[[RES_1]], %[[C0_2]], [1 : index]
12+
!CHECK: omp.yield(%[[RES_2]] : !fir.complex<8>)
13+
!CHECK: } combiner {
14+
!CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<8>, %[[ARG1:.*]]: !fir.complex<8>):
15+
!CHECK: %[[RES:.*]] = fir.addc %[[ARG0]], %[[ARG1]] {{.*}}: !fir.complex<8>
16+
!CHECK: omp.yield(%[[RES]] : !fir.complex<8>)
17+
!CHECK: }
18+
19+
!CHECK-LABEL: func.func @_QPsimple_complex_add
20+
!CHECK: %[[CREF:.*]] = fir.alloca !fir.complex<8> {bindc_name = "c", {{.*}}}
21+
!CHECK: %[[C_DECL:.*]]:2 = hlfir.declare %[[CREF]] {uniq_name = "_QFsimple_complex_addEc"} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
22+
!CHECK: %[[C_START_RE:.*]] = arith.constant 0.000000e+00 : f64
23+
!CHECK: %[[C_START_IM:.*]] = arith.constant 0.000000e+00 : f64
24+
!CHECK: %[[UNDEF_1:.*]] = fir.undefined !fir.complex<8>
25+
!CHECK: %[[VAL_1:.*]] = fir.insert_value %[[UNDEF_1]], %[[C_START_RE]], [0 : index]
26+
!CHECK: %[[VAL_2:.*]] = fir.insert_value %[[VAL_1]], %[[C_START_IM]], [1 : index]
27+
!CHECK: hlfir.assign %[[VAL_2]] to %[[C_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
28+
!CHECK: omp.parallel reduction(@[[RED_NAME]] %[[C_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<!fir.complex<8>>) {
29+
!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
30+
!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<!fir.complex<8>>
31+
!CHECK: %[[C_INCR_RE:.*]] = arith.constant 1.000000e+00 : f64
32+
!CHECK: %[[C_INCR_IM:.*]] = arith.constant 0.000000e+00 : f64
33+
!CHECK: %[[UNDEF_2:.*]] = fir.undefined !fir.complex<8>
34+
!CHECK: %[[INCR_1:.*]] = fir.insert_value %[[UNDEF_2]], %[[C_INCR_RE]], [0 : index]
35+
!CHECK: %[[INCR_2:.*]] = fir.insert_value %[[INCR_1]], %[[C_INCR_IM]], [1 : index]
36+
!CHECK: %[[RES:.+]] = fir.addc %[[LPRV]], %[[INCR_2]] {{.*}} : !fir.complex<8>
37+
!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
38+
!CHECK: omp.terminator
39+
!CHECK: }
40+
!CHECK: return
41+
! Test input: a complex(8) variable used in an OpenMP parallel region with an
! additive (+) reduction. The expectations earlier in this file look for a
! reduction declaration initialised to (0, 0) and combined with fir.addc.
subroutine simple_complex_add
42+
complex(8) :: c
43+
c = 0
44+
45+
!$omp parallel reduction(+:c)
46+
! Add 1 to the reduction variable; the expected IR shows it as (1, 0).
c = c + 1
47+
!$omp end parallel
48+
49+
print *, c
50+
end subroutine

0 commit comments

Comments
 (0)