Skip to content

Commit a636458

Browse files
authored
Merge pull request #793 from flang-compiler/ml-sum
Add lowering of product and sum intrinsics
2 parents 7b7e86f + bfe0424 commit a636458

File tree

5 files changed

+649
-59
lines changed

5 files changed

+649
-59
lines changed

flang/include/flang/Lower/ReductionRuntime.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,30 @@ void genMinvalDim(Fortran::lower::FirOpBuilder &builder, mlir::Location loc,
102102
mlir::Value resultBox, mlir::Value arrayBox, mlir::Value dim,
103103
mlir::Value maskBox);
104104

105+
/// Generate call to Product intrinsic runtime routine. This is the version
106+
/// that does not take a dim argument.
107+
mlir::Value genProduct(Fortran::lower::FirOpBuilder &builder,
108+
mlir::Location loc, mlir::Value arrayBox,
109+
mlir::Value maskBox, mlir::Value resultBox);
110+
111+
/// Generate call to Product intrinsic runtime routine. This is the version
112+
/// that takes arrays of any rank with a dim argument specified.
113+
void genProductDim(Fortran::lower::FirOpBuilder &builder, mlir::Location loc,
114+
mlir::Value resultBox, mlir::Value arrayBox, mlir::Value dim,
115+
mlir::Value maskBox);
116+
117+
/// Generate call to Sum intrinsic runtime routine. This is the version
118+
/// that does not take a dim argument.
119+
mlir::Value genSum(Fortran::lower::FirOpBuilder &builder, mlir::Location loc,
120+
mlir::Value arrayBox, mlir::Value maskBox,
121+
mlir::Value resultBox);
122+
123+
/// Generate call to Sum intrinsic runtime routine. This is the version
124+
/// that takes arrays of any rank with a dim argument specified.
125+
void genSumDim(Fortran::lower::FirOpBuilder &builder, mlir::Location loc,
126+
mlir::Value resultBox, mlir::Value arrayBox, mlir::Value dim,
127+
mlir::Value maskBox);
128+
105129
} // namespace Fortran::lower
106130

107131
#endif // FORTRAN_LOWER_REDUCTIONRUNTIME_H

flang/lib/Lower/IntrinsicCall.cpp

Lines changed: 118 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,92 @@ static bool isAbsent(const fir::ExtendedValue &exv) {
106106
return !fir::getBase(exv);
107107
}
108108

109-
/// Process calls to Maxval, Minval intrinsic functions
109+
/// Process calls to Maxval, Minval, Product, Sum intrinsic functions that
110+
/// take a DIM argument.
111+
template <typename FD>
112+
static fir::ExtendedValue
113+
genFuncDim(FD funcDim, mlir::Type resultType,
114+
Fortran::lower::FirOpBuilder &builder, mlir::Location loc,
115+
Fortran::lower::StatementContext *stmtCtx, llvm::StringRef errMsg,
116+
mlir::Value array, fir::ExtendedValue dimArg, mlir::Value mask,
117+
int rank) {
118+
119+
// Create mutable fir.box to be passed to the runtime for the result.
120+
auto resultArrayType = builder.getVarLenSeqTy(resultType, rank - 1);
121+
auto resultMutableBox =
122+
Fortran::lower::createTempMutableBox(builder, loc, resultArrayType);
123+
auto resultIrBox =
124+
Fortran::lower::getMutableIRBox(builder, loc, resultMutableBox);
125+
126+
auto dim = isAbsent(dimArg)
127+
? builder.createIntegerConstant(loc, builder.getIndexType(), 0)
128+
: fir::getBase(dimArg);
129+
funcDim(builder, loc, resultIrBox, array, dim, mask);
130+
131+
auto res = Fortran::lower::genMutableBoxRead(builder, loc, resultMutableBox);
132+
return res.match(
133+
[&](const fir::ArrayBoxValue &box) -> fir::ExtendedValue {
134+
// Add cleanup code
135+
assert(stmtCtx);
136+
auto *bldr = &builder;
137+
auto temp = box.getAddr();
138+
stmtCtx->attachCleanup(
139+
[=]() { bldr->create<fir::FreeMemOp>(loc, temp); });
140+
return box;
141+
},
142+
[&](const auto &) -> fir::ExtendedValue {
143+
fir::emitFatalError(loc, errMsg);
144+
});
145+
}
146+
147+
/// Process calls to Product, Sum intrinsic functions
148+
template <typename FN, typename FD>
149+
static fir::ExtendedValue
150+
genProdOrSum(FN func, FD funcDim, mlir::Type resultType,
151+
Fortran::lower::FirOpBuilder &builder, mlir::Location loc,
152+
Fortran::lower::StatementContext *stmtCtx, llvm::StringRef errMsg,
153+
llvm::ArrayRef<fir::ExtendedValue> args) {
154+
155+
assert(args.size() == 3);
156+
157+
// Handle required array argument
158+
fir::BoxValue arryTmp = builder.createBox(loc, args[0]);
159+
mlir::Value array = fir::getBase(arryTmp);
160+
int rank = arryTmp.rank();
161+
assert(rank >= 1);
162+
163+
// Handle optional mask argument
164+
auto mask = isAbsent(args[2])
165+
? builder.create<fir::AbsentOp>(
166+
loc, fir::BoxType::get(builder.getI1Type()))
167+
: builder.createBox(loc, args[2]);
168+
169+
bool absentDim = isAbsent(args[1]);
170+
171+
// We call the type specific versions because the result is scalar
172+
// in the case below.
173+
if (absentDim || rank == 1) {
174+
auto ty = array.getType();
175+
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
176+
auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
177+
if (fir::isa_complex(eleTy)) {
178+
auto result = builder.createTemporary(loc, eleTy);
179+
func(builder, loc, array, mask, result);
180+
return builder.create<fir::LoadOp>(loc, result);
181+
}
182+
auto resultBox = builder.create<fir::AbsentOp>(
183+
loc, fir::BoxType::get(builder.getI1Type()));
184+
return func(builder, loc, array, mask, resultBox);
185+
}
186+
// Handle Product/Sum cases that have an array result.
187+
return genFuncDim(funcDim, resultType, builder, loc, stmtCtx, errMsg, array,
188+
args[1], mask, rank);
189+
}
190+
191+
/// Process calls to Maxval, Minval, Product, Sum intrinsic functions
110192
template <typename FN, typename FD, typename FC>
111193
static fir::ExtendedValue
112-
genExtremumval(FN func, FD funcDim, FC funcChar, mlir::Type resultType,
194+
genExtremumVal(FN func, FD funcDim, FC funcChar, mlir::Type resultType,
113195
Fortran::lower::FirOpBuilder &builder, mlir::Location loc,
114196
Fortran::lower::StatementContext *stmtCtx,
115197
llvm::StringRef errMsg,
@@ -164,43 +246,9 @@ genExtremumval(FN func, FD funcDim, FC funcChar, mlir::Type resultType,
164246
});
165247
}
166248

167-
// Note: The Min/Maxval cases below have an array result.
168-
// Create mutable fir.box to be passed to the runtime for the result.
169-
auto resultArrayType = builder.getVarLenSeqTy(resultType, rank - 1);
170-
auto resultMutableBox =
171-
Fortran::lower::createTempMutableBox(builder, loc, resultArrayType);
172-
auto resultIrBox =
173-
Fortran::lower::getMutableIRBox(builder, loc, resultMutableBox);
174-
175-
auto dim = absentDim
176-
? builder.createIntegerConstant(loc, builder.getIndexType(), 0)
177-
: fir::getBase(args[1]);
178-
funcDim(builder, loc, resultIrBox, array, dim, mask);
179-
180-
auto res = Fortran::lower::genMutableBoxRead(builder, loc, resultMutableBox);
181-
182-
return res.match(
183-
[&](const fir::ArrayBoxValue &box) -> fir::ExtendedValue {
184-
// Add cleanup code
185-
assert(stmtCtx);
186-
auto *bldr = &builder;
187-
auto temp = box.getAddr();
188-
stmtCtx->attachCleanup(
189-
[=]() { bldr->create<fir::FreeMemOp>(loc, temp); });
190-
return box;
191-
},
192-
[&](const fir::CharArrayBoxValue &box) -> fir::ExtendedValue {
193-
// Add cleanup code
194-
assert(stmtCtx);
195-
auto *bldr = &builder;
196-
auto temp = box.getAddr();
197-
stmtCtx->attachCleanup(
198-
[=]() { bldr->create<fir::FreeMemOp>(loc, temp); });
199-
return box;
200-
},
201-
[&](const auto &) -> fir::ExtendedValue {
202-
fir::emitFatalError(loc, errMsg);
203-
});
249+
// Handle Min/Maxval cases that have an array result.
250+
return genFuncDim(funcDim, resultType, builder, loc, stmtCtx, errMsg, array,
251+
args[1], mask, rank);
204252
}
205253

206254
/// Process calls to Minloc, Maxloc intrinsic functions
@@ -376,12 +424,14 @@ struct IntrinsicLibrary {
376424
mlir::Value genModulo(mlir::Type, llvm::ArrayRef<mlir::Value>);
377425
mlir::Value genNint(mlir::Type, llvm::ArrayRef<mlir::Value>);
378426
fir::ExtendedValue genPresent(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
427+
fir::ExtendedValue genProduct(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
379428
mlir::Value genRRSpacing(mlir::Type resultType,
380429
llvm::ArrayRef<mlir::Value> args);
381430
fir::ExtendedValue genScan(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
382431
mlir::Value genSign(mlir::Type, llvm::ArrayRef<mlir::Value>);
383432
mlir::Value genSpacing(mlir::Type resultType,
384433
llvm::ArrayRef<mlir::Value> args);
434+
fir::ExtendedValue genSum(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
385435
fir::ExtendedValue genTrim(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
386436
fir::ExtendedValue genVerify(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
387437
/// Implement all conversion functions like DBLE, the first argument is
@@ -567,6 +617,10 @@ static constexpr IntrinsicHandler handlers[]{
567617
&I::genPresent,
568618
{{{"a", asInquired}}},
569619
/*isElemental=*/false},
620+
{"product",
621+
&I::genProduct,
622+
{{{"array", asAddr}, {"dim", asValue}, {"mask", asAddr}}},
623+
/*isElemental=*/false},
570624
{"rrspacing", &I::genRRSpacing},
571625
{"scan",
572626
&I::genScan,
@@ -577,6 +631,10 @@ static constexpr IntrinsicHandler handlers[]{
577631
/*isElemental=*/true},
578632
{"sign", &I::genSign},
579633
{"spacing", &I::genSpacing},
634+
{"sum",
635+
&I::genSum,
636+
{{{"array", asAddr}, {"dim", asValue}, {"mask", asAddr}}},
637+
/*isElemental=*/false},
580638
{"trim", &I::genTrim, {{{"string", asAddr}}}, /*isElemental=*/false},
581639
{"verify",
582640
&I::genVerify,
@@ -1908,6 +1966,14 @@ IntrinsicLibrary::genPresent(mlir::Type,
19081966
fir::getBase(args[0]));
19091967
}
19101968

1969+
// PRODUCT
1970+
fir::ExtendedValue
1971+
IntrinsicLibrary::genProduct(mlir::Type resultType,
1972+
llvm::ArrayRef<fir::ExtendedValue> args) {
1973+
return genProdOrSum(Fortran::lower::genProduct, Fortran::lower::genProductDim,
1974+
resultType, builder, loc, stmtCtx,
1975+
"unexpected result for Product", args);
1976+
}
19111977
// RRSPACING
19121978
mlir::Value IntrinsicLibrary::genRRSpacing(mlir::Type resultType,
19131979
llvm::ArrayRef<mlir::Value> args) {
@@ -2019,8 +2085,8 @@ mlir::Value IntrinsicLibrary::genSign(mlir::Type resultType,
20192085
// TODO: Requirements when second argument is +0./0.
20202086
auto zero = builder.createRealZeroConstant(loc, resultType);
20212087
auto neg = builder.create<fir::NegfOp>(loc, abs);
2022-
auto cmp =
2023-
builder.create<mlir::CmpFOp>(loc, mlir::CmpFPredicate::OLT, args[1], zero);
2088+
auto cmp = builder.create<mlir::CmpFOp>(loc, mlir::CmpFPredicate::OLT,
2089+
args[1], zero);
20242090
return builder.create<mlir::SelectOp>(loc, cmp, neg, abs);
20252091
}
20262092

@@ -2034,6 +2100,15 @@ mlir::Value IntrinsicLibrary::genSpacing(mlir::Type resultType,
20342100
Fortran::lower::genSpacing(builder, loc, fir::getBase(args[0])));
20352101
}
20362102

2103+
// SUM
2104+
fir::ExtendedValue
2105+
IntrinsicLibrary::genSum(mlir::Type resultType,
2106+
llvm::ArrayRef<fir::ExtendedValue> args) {
2107+
return genProdOrSum(Fortran::lower::genSum, Fortran::lower::genSumDim,
2108+
resultType, builder, loc, stmtCtx,
2109+
"unexpected result for Sum", args);
2110+
}
2111+
20372112
// TRIM
20382113
fir::ExtendedValue
20392114
IntrinsicLibrary::genTrim(mlir::Type resultType,
@@ -2216,7 +2291,7 @@ IntrinsicLibrary::genMaxloc(mlir::Type resultType,
22162291
fir::ExtendedValue
22172292
IntrinsicLibrary::genMaxval(mlir::Type resultType,
22182293
llvm::ArrayRef<fir::ExtendedValue> args) {
2219-
return genExtremumval(Fortran::lower::genMaxval, Fortran::lower::genMaxvalDim,
2294+
return genExtremumVal(Fortran::lower::genMaxval, Fortran::lower::genMaxvalDim,
22202295
Fortran::lower::genMaxvalChar, resultType, builder, loc,
22212296
stmtCtx, "unexpected result for Maxval", args);
22222297
}
@@ -2234,7 +2309,7 @@ IntrinsicLibrary::genMinloc(mlir::Type resultType,
22342309
fir::ExtendedValue
22352310
IntrinsicLibrary::genMinval(mlir::Type resultType,
22362311
llvm::ArrayRef<fir::ExtendedValue> args) {
2237-
return genExtremumval(Fortran::lower::genMinval, Fortran::lower::genMinvalDim,
2312+
return genExtremumVal(Fortran::lower::genMinval, Fortran::lower::genMinvalDim,
22382313
Fortran::lower::genMinvalChar, resultType, builder, loc,
22392314
stmtCtx, "unexpected result for Minval", args);
22402315
}

flang/lib/Lower/RTBuilder.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,20 @@ constexpr TypeBuilderFunc getModel<bool &>() {
207207
};
208208
}
209209
template <>
210+
constexpr TypeBuilderFunc getModel<std::complex<float> &>() {
211+
return [](mlir::MLIRContext *context) -> mlir::Type {
212+
auto ty = mlir::ComplexType::get(mlir::FloatType::getF32(context));
213+
return fir::ReferenceType::get(ty);
214+
};
215+
}
216+
template <>
217+
constexpr TypeBuilderFunc getModel<std::complex<double> &>() {
218+
return [](mlir::MLIRContext *context) -> mlir::Type {
219+
auto ty = mlir::ComplexType::get(mlir::FloatType::getF64(context));
220+
return fir::ReferenceType::get(ty);
221+
};
222+
}
223+
template <>
210224
constexpr TypeBuilderFunc getModel<c_float_complex_t>() {
211225
return [](mlir::MLIRContext *context) -> mlir::Type {
212226
return fir::ComplexType::get(context, sizeof(float));

0 commit comments

Comments
 (0)