10
10
// into the calling function.
11
11
// ===----------------------------------------------------------------------===//
12
12
13
+ #include " flang/Optimizer/Builder/Complex.h"
13
14
#include " flang/Optimizer/Builder/FIRBuilder.h"
14
15
#include " flang/Optimizer/Builder/HLFIRTools.h"
15
16
#include " flang/Optimizer/Dialect/FIRDialect.h"
@@ -90,13 +91,248 @@ class TransposeAsElementalConversion
90
91
}
91
92
};
92
93
94
+ // Expand the SUM(DIM=CONSTANT) operation into .
95
+ class SumAsElementalConversion : public mlir ::OpRewritePattern<hlfir::SumOp> {
96
+ public:
97
+ using mlir::OpRewritePattern<hlfir::SumOp>::OpRewritePattern;
98
+
99
+ llvm::LogicalResult
100
+ matchAndRewrite (hlfir::SumOp sum,
101
+ mlir::PatternRewriter &rewriter) const override {
102
+ mlir::Location loc = sum.getLoc ();
103
+ fir::FirOpBuilder builder{rewriter, sum.getOperation ()};
104
+ hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType ());
105
+ assert (expr && " expected an expression type for the result of hlfir.sum" );
106
+ mlir::Type elementType = expr.getElementType ();
107
+ hlfir::Entity array = hlfir::Entity{sum.getArray ()};
108
+ mlir::Value mask = sum.getMask ();
109
+ mlir::Value dim = sum.getDim ();
110
+ int64_t dimVal = fir::getIntIfConstant (dim).value_or (0 );
111
+ assert (dimVal > 0 && " DIM must be present and a positive constant" );
112
+ mlir::Value resultShape, dimExtent;
113
+ std::tie (resultShape, dimExtent) =
114
+ genResultShape (loc, builder, array, dimVal);
115
+
116
+ auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
117
+ mlir::ValueRange inputIndices) -> hlfir::Entity {
118
+ // Loop over all indices in the DIM dimension, and reduce all values.
119
+ // We do not need to create the reduction loop always: if we can
120
+ // slice the input array given the inputIndices, then we can
121
+ // just apply a new SUM operation (total reduction) to the slice.
122
+ // For the time being, generate the explicit loop because the slicing
123
+ // requires generating an elemental operation for the input array
124
+ // (and the mask, if present).
125
+ // TODO: produce the slices and new SUM after adding a pattern
126
+ // for expanding total reduction SUM case.
127
+ mlir::Type indexType = builder.getIndexType ();
128
+ auto one = builder.createIntegerConstant (loc, indexType, 1 );
129
+ auto ub = builder.createConvert (loc, indexType, dimExtent);
130
+
131
+ // Initial value for the reduction.
132
+ mlir::Value initValue = genInitValue (loc, builder, elementType);
133
+
134
+ // The reduction loop may be unordered if FastMathFlags::reassoc
135
+ // transformations are allowed. The integer reduction is always
136
+ // unordered.
137
+ bool isUnordered = mlir::isa<mlir::IntegerType>(elementType) ||
138
+ static_cast <bool >(sum.getFastmath () &
139
+ mlir::arith::FastMathFlags::reassoc);
140
+
141
+ // If the mask is present and is a scalar, then we'd better load its value
142
+ // outside of the reduction loop making the loop unswitching easier.
143
+ // Maybe it is worth hoisting it from the elemental operation as well.
144
+ mlir::Value isPresentPred, maskValue;
145
+ if (mask) {
146
+ if (mlir::isa<fir::BaseBoxType>(mask.getType ())) {
147
+ // MASK represented by a box might be dynamically optional,
148
+ // so we have to check for its presence before accessing it.
149
+ isPresentPred =
150
+ builder.create <fir::IsPresentOp>(loc, builder.getI1Type (), mask);
151
+ }
152
+
153
+ if (hlfir::Entity{mask}.isScalar ())
154
+ maskValue = genMaskValue (loc, builder, mask, isPresentPred, {});
155
+ }
156
+
157
+ // NOTE: the outer elemental operation may be lowered into
158
+ // omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
159
+ // loop may appear disjoint from the workshare loop nest.
160
+ // Moreover, the inner loop is not strictly nested (due to the reduction
161
+ // starting value initialization), and the above omp dialect operations
162
+ // cannot produce results.
163
+ // It is unclear what we should do about it yet.
164
+ auto doLoop = builder.create <fir::DoLoopOp>(
165
+ loc, one, ub, one, isUnordered, /* finalCountValue=*/ false ,
166
+ mlir::ValueRange{initValue});
167
+
168
+ // Address the input array using the reduction loop's IV
169
+ // for the DIM dimension.
170
+ mlir::Value iv = doLoop.getInductionVar ();
171
+ llvm::SmallVector<mlir::Value> indices{inputIndices};
172
+ indices.insert (indices.begin () + dimVal - 1 , iv);
173
+
174
+ mlir::OpBuilder::InsertionGuard guard (builder);
175
+ builder.setInsertionPointToStart (doLoop.getBody ());
176
+ mlir::Value reductionValue = doLoop.getRegionIterArgs ()[0 ];
177
+ fir::IfOp ifOp;
178
+ if (mask) {
179
+ // Make the reduction value update conditional on the value
180
+ // of the mask.
181
+ if (!maskValue) {
182
+ // If the mask is an array, use the elemental and the loop indices
183
+ // to address the proper mask element.
184
+ maskValue = genMaskValue (loc, builder, mask, isPresentPred, indices);
185
+ }
186
+ mlir::Value isUnmasked =
187
+ builder.create <fir::ConvertOp>(loc, builder.getI1Type (), maskValue);
188
+ ifOp = builder.create <fir::IfOp>(loc, elementType, isUnmasked,
189
+ /* withElseRegion=*/ true );
190
+ // In the 'else' block return the current reduction value.
191
+ builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
192
+ builder.create <fir::ResultOp>(loc, reductionValue);
193
+
194
+ // In the 'then' block do the actual addition.
195
+ builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
196
+ }
197
+
198
+ hlfir::Entity element = hlfir::getElementAt (loc, builder, array, indices);
199
+ hlfir::Entity elementValue =
200
+ hlfir::loadTrivialScalar (loc, builder, element);
201
+ // NOTE: we can use "Kahan summation" same way as the runtime
202
+ // (e.g. when fast-math is not allowed), but let's start with
203
+ // the simple version.
204
+ reductionValue = genScalarAdd (loc, builder, reductionValue, elementValue);
205
+ builder.create <fir::ResultOp>(loc, reductionValue);
206
+
207
+ if (ifOp) {
208
+ builder.setInsertionPointAfter (ifOp);
209
+ builder.create <fir::ResultOp>(loc, ifOp.getResult (0 ));
210
+ }
211
+
212
+ return hlfir::Entity{doLoop.getResult (0 )};
213
+ };
214
+ hlfir::ElementalOp elementalOp = hlfir::genElementalOp (
215
+ loc, builder, elementType, resultShape, {}, genKernel,
216
+ /* isUnordered=*/ true , /* polymorphicMold=*/ nullptr ,
217
+ sum.getResult ().getType ());
218
+
219
+ // it wouldn't be safe to replace block arguments with a different
220
+ // hlfir.expr type. Types can differ due to differing amounts of shape
221
+ // information
222
+ assert (elementalOp.getResult ().getType () == sum.getResult ().getType ());
223
+
224
+ rewriter.replaceOp (sum, elementalOp);
225
+ return mlir::success ();
226
+ }
227
+
228
+ private:
229
+ // Return fir.shape specifying the shape of the result
230
+ // of a SUM reduction with DIM=dimVal. The second return value
231
+ // is the extent of the DIM dimension.
232
+ static std::tuple<mlir::Value, mlir::Value>
233
+ genResultShape (mlir::Location loc, fir::FirOpBuilder &builder,
234
+ hlfir::Entity array, int64_t dimVal) {
235
+ mlir::Value inShape = hlfir::genShape (loc, builder, array);
236
+ llvm::SmallVector<mlir::Value> inExtents =
237
+ hlfir::getExplicitExtentsFromShape (inShape, builder);
238
+ if (inShape.getUses ().empty ())
239
+ inShape.getDefiningOp ()->erase ();
240
+
241
+ mlir::Value dimExtent = inExtents[dimVal - 1 ];
242
+ inExtents.erase (inExtents.begin () + dimVal - 1 );
243
+ return {builder.create <fir::ShapeOp>(loc, inExtents), dimExtent};
244
+ }
245
+
246
+ // Generate the initial value for a SUM reduction with the given
247
+ // data type.
248
+ static mlir::Value genInitValue (mlir::Location loc,
249
+ fir::FirOpBuilder &builder,
250
+ mlir::Type elementType) {
251
+ if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
252
+ const llvm::fltSemantics &sem = ty.getFloatSemantics ();
253
+ return builder.createRealConstant (loc, elementType,
254
+ llvm::APFloat::getZero (sem));
255
+ } else if (auto ty = mlir::dyn_cast<mlir::ComplexType>(elementType)) {
256
+ mlir::Value initValue = genInitValue (loc, builder, ty.getElementType ());
257
+ return fir::factory::Complex{builder, loc}.createComplex (ty, initValue,
258
+ initValue);
259
+ } else if (mlir::isa<mlir::IntegerType>(elementType)) {
260
+ return builder.createIntegerConstant (loc, elementType, 0 );
261
+ }
262
+
263
+ llvm_unreachable (" unsupported SUM reduction type" );
264
+ }
265
+
266
+ // Generate scalar addition of the two values (of the same data type).
267
+ static mlir::Value genScalarAdd (mlir::Location loc,
268
+ fir::FirOpBuilder &builder,
269
+ mlir::Value value1, mlir::Value value2) {
270
+ mlir::Type ty = value1.getType ();
271
+ assert (ty == value2.getType () && " reduction values' types do not match" );
272
+ if (mlir::isa<mlir::FloatType>(ty))
273
+ return builder.create <mlir::arith::AddFOp>(loc, value1, value2);
274
+ else if (mlir::isa<mlir::ComplexType>(ty))
275
+ return builder.create <fir::AddcOp>(loc, value1, value2);
276
+ else if (mlir::isa<mlir::IntegerType>(ty))
277
+ return builder.create <mlir::arith::AddIOp>(loc, value1, value2);
278
+
279
+ llvm_unreachable (" unsupported SUM reduction type" );
280
+ }
281
+
282
+ static mlir::Value genMaskValue (mlir::Location loc,
283
+ fir::FirOpBuilder &builder, mlir::Value mask,
284
+ mlir::Value isPresentPred,
285
+ mlir::ValueRange indices) {
286
+ mlir::OpBuilder::InsertionGuard guard (builder);
287
+ fir::IfOp ifOp;
288
+ mlir::Type maskType =
289
+ hlfir::getFortranElementType (fir::unwrapPassByRefType (mask.getType ()));
290
+ if (isPresentPred) {
291
+ ifOp = builder.create <fir::IfOp>(loc, maskType, isPresentPred,
292
+ /* withElseRegion=*/ true );
293
+
294
+ // Use 'true', if the mask is not present.
295
+ builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
296
+ mlir::Value trueValue = builder.createBool (loc, true );
297
+ trueValue = builder.createConvert (loc, maskType, trueValue);
298
+ builder.create <fir::ResultOp>(loc, trueValue);
299
+
300
+ // Load the mask value, if the mask is present.
301
+ builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
302
+ }
303
+
304
+ hlfir::Entity maskVar{mask};
305
+ if (maskVar.isScalar ()) {
306
+ if (mlir::isa<fir::BaseBoxType>(mask.getType ())) {
307
+ // MASK may be a boxed scalar.
308
+ mlir::Value addr = hlfir::genVariableRawAddress (loc, builder, maskVar);
309
+ mask = builder.create <fir::LoadOp>(loc, hlfir::Entity{addr});
310
+ } else {
311
+ mask = hlfir::loadTrivialScalar (loc, builder, maskVar);
312
+ }
313
+ } else {
314
+ // Load from the mask array.
315
+ assert (!indices.empty () && " no indices for addressing the mask array" );
316
+ maskVar = hlfir::getElementAt (loc, builder, maskVar, indices);
317
+ mask = hlfir::loadTrivialScalar (loc, builder, maskVar);
318
+ }
319
+
320
+ if (!isPresentPred)
321
+ return mask;
322
+
323
+ builder.create <fir::ResultOp>(loc, mask);
324
+ return ifOp.getResult (0 );
325
+ }
326
+ };
327
+
93
328
class SimplifyHLFIRIntrinsics
94
329
: public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
95
330
public:
96
331
void runOnOperation () override {
97
332
mlir::MLIRContext *context = &getContext ();
98
333
mlir::RewritePatternSet patterns (context);
99
334
patterns.insert <TransposeAsElementalConversion>(context);
335
+ patterns.insert <SumAsElementalConversion>(context);
100
336
mlir::ConversionTarget target (*context);
101
337
// don't transform transpose of polymorphic arrays (not currently supported
102
338
// by hlfir.elemental)
@@ -105,6 +341,24 @@ class SimplifyHLFIRIntrinsics
105
341
return mlir::cast<hlfir::ExprType>(transpose.getType ())
106
342
.isPolymorphic ();
107
343
});
344
+ // Handle only SUM(DIM=CONSTANT) case for now.
345
+ // It may be beneficial to expand the non-DIM case as well.
346
+ // E.g. when the input array is an elemental array expression,
347
+ // expanding the SUM into a total reduction loop nest
348
+ // would avoid creating a temporary for the elemental array expression.
349
+ target.addDynamicallyLegalOp <hlfir::SumOp>([](hlfir::SumOp sum) {
350
+ if (mlir::Value dim = sum.getDim ()) {
351
+ if (fir::getIntIfConstant (dim)) {
352
+ if (!fir::isa_trivial (sum.getType ())) {
353
+ // Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
354
+ // It is only legal when X is 1, and it should probably be
355
+ // canonicalized into SUM(a).
356
+ return false ;
357
+ }
358
+ }
359
+ }
360
+ return true ;
361
+ });
108
362
target.markUnknownOpDynamicallyLegal (
109
363
[](mlir::Operation *) { return true ; });
110
364
if (mlir::failed (mlir::applyFullConversion (getOperation (), target,
0 commit comments