Skip to content

Commit 2f58864

Browse files
authored
[flang][OpenMP] Extend do concurrent mapping to device. #50
For simple loops, we can now choose to map `do concurrent` to either the host (i.e. `omp parallel do`) or the device (i.e. `omp target teams distribute parallel do`). In order to use this from `flang-new`, you can pass: `-fdo-concurrent-parallel=[none|host|device]`.
2 parents a92e557 + 3bb1152 commit 2f58864

File tree

20 files changed

+670
-248
lines changed

20 files changed

+670
-248
lines changed

flang/lib/Lower/OpenMP/Utils.h renamed to flang/include/flang/Lower/OpenMP/Utils.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ struct OmpMapMemberIndicesData {
5959
};
6060

6161
mlir::omp::MapInfoOp
62-
createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
62+
createMapInfoOp(mlir::OpBuilder &builder, mlir::Location loc,
6363
mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
6464
mlir::ArrayRef<mlir::Value> bounds,
6565
mlir::ArrayRef<mlir::Value> members,
@@ -102,6 +102,15 @@ void genObjectList(const ObjectList &objects,
102102
Fortran::lower::AbstractConverter &converter,
103103
llvm::SmallVectorImpl<mlir::Value> &operands);
104104

105+
// TODO: consider moving this to the `omp.loop_nest` op. Would be something like
106+
// this:
107+
//
108+
// ```
109+
// mlir::Value LoopNestOp::calculateTripCount(mlir::OpBuilder &builder,
110+
// mlir::OpBuilder::InsertPoint ip)
111+
// ```
112+
mlir::Value calculateTripCount(fir::FirOpBuilder &builder, mlir::Location loc,
113+
const mlir::omp::CollapseClauseOps &ops);
105114
} // namespace omp
106115
} // namespace lower
107116
} // namespace Fortran

flang/include/flang/Optimizer/Transforms/Passes.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ namespace fir {
3838
#define GEN_PASS_DECL_ARRAYVALUECOPY
3939
#define GEN_PASS_DECL_CHARACTERCONVERSION
4040
#define GEN_PASS_DECL_CFGCONVERSION
41+
#define GEN_PASS_DECL_DOCONCURRENTCONVERSIONPASS
4142
#define GEN_PASS_DECL_EXTERNALNAMECONVERSION
4243
#define GEN_PASS_DECL_MEMREFDATAFLOWOPT
4344
#define GEN_PASS_DECL_SIMPLIFYINTRINSICS
@@ -88,7 +89,7 @@ createFunctionAttrPass(FunctionAttrTypes &functionAttr, bool noInfsFPMath,
8889
bool noNaNsFPMath, bool approxFuncFPMath,
8990
bool noSignedZerosFPMath, bool unsafeFPMath);
9091

91-
std::unique_ptr<mlir::Pass> createDoConcurrentConversionPass();
92+
std::unique_ptr<mlir::Pass> createDoConcurrentConversionPass(bool mapToDevice);
9293

9394
void populateCfgConversionRewrites(mlir::RewritePatternSet &patterns,
9495
bool forceLoopToExecuteOnce = false);

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,8 +416,12 @@ def DoConcurrentConversionPass : Pass<"fopenmp-do-concurrent-conversion", "mlir:
416416
target.
417417
}];
418418

419-
let constructor = "::fir::createDoConcurrentConversionPass()";
420419
let dependentDialects = ["mlir::omp::OpenMPDialect"];
420+
421+
let options = [
422+
Option<"mapTo", "map-to", "std::string", "",
423+
"Try to map `do concurrent` loops to OpenMP (on host or device)">,
424+
];
421425
}
422426

423427
#endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES

flang/include/flang/Tools/CLOptions.inc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,9 @@ inline void createHLFIRToFIRPassPipeline(
324324
pm.addPass(hlfir::createConvertHLFIRtoFIRPass());
325325
}
326326

327+
using DoConcurrentMappingKind =
328+
Fortran::frontend::CodeGenOptions::DoConcurrentMappingKind;
329+
327330
/// Create a pass pipeline for handling certain OpenMP transformations needed
328331
/// prior to FIR lowering.
329332
///
@@ -333,8 +336,12 @@ inline void createHLFIRToFIRPassPipeline(
333336
/// \param pm - MLIR pass manager that will hold the pipeline definition.
334337
/// \param isTargetDevice - Whether code is being generated for a target device
335338
/// rather than the host device.
336-
inline void createOpenMPFIRPassPipeline(
337-
mlir::PassManager &pm, bool isTargetDevice) {
339+
inline void createOpenMPFIRPassPipeline(mlir::PassManager &pm,
340+
bool isTargetDevice, DoConcurrentMappingKind doConcurrentMappingKind) {
341+
if (doConcurrentMappingKind != DoConcurrentMappingKind::DCMK_None)
342+
pm.addPass(fir::createDoConcurrentConversionPass(
343+
doConcurrentMappingKind == DoConcurrentMappingKind::DCMK_Device));
344+
338345
pm.addPass(fir::createOMPMapInfoFinalizationPass());
339346
pm.addPass(fir::createOMPMarkDeclareTargetPass());
340347
if (isTargetDevice)

flang/lib/Frontend/FrontendActions.cpp

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -320,41 +320,34 @@ bool CodeGenAction::beginSourceFileAction() {
320320
// Add OpenMP-related passes
321321
// WARNING: These passes must be run immediately after the lowering to ensure
322322
// that the FIR is correct with respect to OpenMP operations/attributes.
323-
bool isOpenMPEnabled = ci.getInvocation().getFrontendOpts().features.IsEnabled(
323+
bool isOpenMPEnabled =
324+
ci.getInvocation().getFrontendOpts().features.IsEnabled(
324325
Fortran::common::LanguageFeature::OpenMP);
326+
327+
using DoConcurrentMappingKind =
328+
Fortran::frontend::CodeGenOptions::DoConcurrentMappingKind;
329+
DoConcurrentMappingKind doConcurrentMappingKind =
330+
ci.getInvocation().getCodeGenOpts().getDoConcurrentMapping();
331+
332+
if (doConcurrentMappingKind != DoConcurrentMappingKind::DCMK_None &&
333+
!isOpenMPEnabled) {
334+
unsigned diagID = ci.getDiagnostics().getCustomDiagID(
335+
clang::DiagnosticsEngine::Warning,
336+
"lowering `do concurrent` loops to OpenMP is only supported if "
337+
"OpenMP is enabled");
338+
ci.getDiagnostics().Report(diagID);
339+
}
340+
325341
if (isOpenMPEnabled) {
326342
bool isDevice = false;
327343
if (auto offloadMod = llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(
328344
mlirModule->getOperation()))
329345
isDevice = offloadMod.getIsTargetDevice();
346+
330347
// WARNING: This pipeline must be run immediately after the lowering to
331348
// ensure that the FIR is correct with respect to OpenMP operations/
332349
// attributes.
333-
fir::createOpenMPFIRPassPipeline(pm, isDevice);
334-
}
335-
336-
using DoConcurrentMappingKind =
337-
Fortran::frontend::CodeGenOptions::DoConcurrentMappingKind;
338-
DoConcurrentMappingKind selectedKind = ci.getInvocation().getCodeGenOpts().getDoConcurrentMapping();
339-
if (selectedKind != DoConcurrentMappingKind::DCMK_None) {
340-
if (!isOpenMPEnabled) {
341-
unsigned diagID = ci.getDiagnostics().getCustomDiagID(
342-
clang::DiagnosticsEngine::Warning,
343-
"lowering `do concurrent` loops to OpenMP is only supported if "
344-
"OpenMP is enabled");
345-
ci.getDiagnostics().Report(diagID);
346-
} else {
347-
bool mapToDevice = selectedKind == DoConcurrentMappingKind::DCMK_Device;
348-
349-
if (mapToDevice) {
350-
unsigned diagID = ci.getDiagnostics().getCustomDiagID(
351-
clang::DiagnosticsEngine::Warning,
352-
"TODO: lowering `do concurrent` loops to OpenMP device is not "
353-
"supported yet");
354-
ci.getDiagnostics().Report(diagID);
355-
} else
356-
pm.addPass(fir::createDoConcurrentConversionPass());
357-
}
350+
fir::createOpenMPFIRPassPipeline(pm, isDevice, doConcurrentMappingKind);
358351
}
359352

360353
pm.enableVerifier(/*verifyPasses=*/true);

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "ClauseProcessor.h"
14-
#include "Clauses.h"
1514

15+
#include "flang/Lower/OpenMP/Clauses.h"
1616
#include "flang/Lower/PFTBuilder.h"
1717
#include "flang/Parser/tools.h"
1818
#include "flang/Semantics/tools.h"

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
#ifndef FORTRAN_LOWER_CLAUASEPROCESSOR_H
1313
#define FORTRAN_LOWER_CLAUASEPROCESSOR_H
1414

15-
#include "Clauses.h"
1615
#include "DirectivesCommon.h"
1716
#include "ReductionProcessor.h"
18-
#include "Utils.h"
1917
#include "flang/Lower/AbstractConverter.h"
2018
#include "flang/Lower/Bridge.h"
19+
#include "flang/Lower/OpenMP/Clauses.h"
20+
#include "flang/Lower/OpenMP/Utils.h"
2121
#include "flang/Optimizer/Builder/Todo.h"
2222
#include "flang/Parser/dump-parse-tree.h"
2323
#include "flang/Parser/parse-tree.h"

flang/lib/Lower/OpenMP/Clauses.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "Clauses.h"
9+
#include "flang/Lower/OpenMP/Clauses.h"
1010

1111
#include "flang/Common/idioms.h"
1212
#include "flang/Evaluate/expression.h"

flang/lib/Lower/OpenMP/DataSharingProcessor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
#include "DataSharingProcessor.h"
1414

15-
#include "Utils.h"
15+
#include "flang/Lower/OpenMP/Utils.h"
1616
#include "flang/Lower/PFTBuilder.h"
1717
#include "flang/Lower/SymbolMap.h"
1818
#include "flang/Optimizer/Builder/HLFIRTools.h"

flang/lib/Lower/OpenMP/DataSharingProcessor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
#ifndef FORTRAN_LOWER_DATASHARINGPROCESSOR_H
1313
#define FORTRAN_LOWER_DATASHARINGPROCESSOR_H
1414

15-
#include "Clauses.h"
1615
#include "flang/Lower/AbstractConverter.h"
1716
#include "flang/Lower/OpenMP.h"
17+
#include "flang/Lower/OpenMP/Clauses.h"
1818
#include "flang/Optimizer/Builder/FIRBuilder.h"
1919
#include "flang/Parser/parse-tree.h"
2020
#include "flang/Semantics/symbol.h"

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 4 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313
#include "flang/Lower/OpenMP.h"
1414

1515
#include "ClauseProcessor.h"
16-
#include "Clauses.h"
1716
#include "DataSharingProcessor.h"
1817
#include "DirectivesCommon.h"
1918
#include "ReductionProcessor.h"
20-
#include "Utils.h"
2119
#include "flang/Common/idioms.h"
2220
#include "flang/Lower/Bridge.h"
2321
#include "flang/Lower/ConvertExpr.h"
2422
#include "flang/Lower/ConvertVariable.h"
23+
#include "flang/Lower/OpenMP/Clauses.h"
24+
#include "flang/Lower/OpenMP/Utils.h"
2525
#include "flang/Lower/StatementContext.h"
2626
#include "flang/Lower/SymbolMap.h"
2727
#include "flang/Optimizer/Builder/BoxValue.h"
@@ -280,84 +280,6 @@ static void threadPrivatizeVars(Fortran::lower::AbstractConverter &converter,
280280
}
281281
}
282282

283-
static mlir::Value
284-
calculateTripCount(Fortran::lower::AbstractConverter &converter,
285-
mlir::Location loc,
286-
const mlir::omp::CollapseClauseOps &ops) {
287-
using namespace mlir::arith;
288-
assert(ops.loopLBVar.size() == ops.loopUBVar.size() &&
289-
ops.loopLBVar.size() == ops.loopStepVar.size() &&
290-
!ops.loopLBVar.empty() && "Invalid bounds or step");
291-
292-
fir::FirOpBuilder &b = converter.getFirOpBuilder();
293-
294-
// Get the bit width of an integer-like type.
295-
auto widthOf = [](mlir::Type ty) -> unsigned {
296-
if (mlir::isa<mlir::IndexType>(ty)) {
297-
return mlir::IndexType::kInternalStorageBitWidth;
298-
}
299-
if (auto tyInt = mlir::dyn_cast<mlir::IntegerType>(ty)) {
300-
return tyInt.getWidth();
301-
}
302-
llvm_unreachable("Unexpected type");
303-
};
304-
305-
// For a type that is either IntegerType or IndexType, return the
306-
// equivalent IntegerType. In the former case this is a no-op.
307-
auto asIntTy = [&](mlir::Type ty) -> mlir::IntegerType {
308-
if (ty.isIndex()) {
309-
return mlir::IntegerType::get(ty.getContext(), widthOf(ty));
310-
}
311-
assert(ty.isIntOrIndex() && "Unexpected type");
312-
return mlir::cast<mlir::IntegerType>(ty);
313-
};
314-
315-
// For two given values, establish a common signless IntegerType
316-
// that can represent any value of type of x and of type of y,
317-
// and return the pair of x, y converted to the new type.
318-
auto unifyToSignless =
319-
[&](fir::FirOpBuilder &b, mlir::Value x,
320-
mlir::Value y) -> std::pair<mlir::Value, mlir::Value> {
321-
auto tyX = asIntTy(x.getType()), tyY = asIntTy(y.getType());
322-
unsigned width = std::max(widthOf(tyX), widthOf(tyY));
323-
auto wideTy = mlir::IntegerType::get(b.getContext(), width,
324-
mlir::IntegerType::Signless);
325-
return std::make_pair(b.createConvert(loc, wideTy, x),
326-
b.createConvert(loc, wideTy, y));
327-
};
328-
329-
// Start with signless i32 by default.
330-
auto tripCount = b.createIntegerConstant(loc, b.getI32Type(), 1);
331-
332-
for (auto [origLb, origUb, origStep] :
333-
llvm::zip(ops.loopLBVar, ops.loopUBVar, ops.loopStepVar)) {
334-
auto tmpS0 = b.createIntegerConstant(loc, origStep.getType(), 0);
335-
auto [step, step0] = unifyToSignless(b, origStep, tmpS0);
336-
auto reverseCond = b.create<CmpIOp>(loc, CmpIPredicate::slt, step, step0);
337-
auto negStep = b.create<SubIOp>(loc, step0, step);
338-
mlir::Value absStep = b.create<SelectOp>(loc, reverseCond, negStep, step);
339-
340-
auto [lb, ub] = unifyToSignless(b, origLb, origUb);
341-
auto start = b.create<SelectOp>(loc, reverseCond, ub, lb);
342-
auto end = b.create<SelectOp>(loc, reverseCond, lb, ub);
343-
344-
mlir::Value range = b.create<SubIOp>(loc, end, start);
345-
auto rangeCond = b.create<CmpIOp>(loc, CmpIPredicate::slt, end, start);
346-
std::tie(range, absStep) = unifyToSignless(b, range, absStep);
347-
// numSteps = (range /u absStep) + 1
348-
auto numSteps =
349-
b.create<AddIOp>(loc, b.create<DivUIOp>(loc, range, absStep),
350-
b.createIntegerConstant(loc, range.getType(), 1));
351-
352-
auto trip0 = b.createIntegerConstant(loc, numSteps.getType(), 0);
353-
auto loopTripCount = b.create<SelectOp>(loc, rangeCond, trip0, numSteps);
354-
auto [totalTC, thisTC] = unifyToSignless(b, tripCount, loopTripCount);
355-
tripCount = b.create<MulIOp>(loc, totalTC, thisTC);
356-
}
357-
358-
return tripCount;
359-
}
360-
361283
static mlir::Operation *
362284
createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
363285
mlir::Location loc, mlir::Value indexVal,
@@ -1574,8 +1496,8 @@ genLoopNestOp(Fortran::lower::AbstractConverter &converter,
15741496
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
15751497
ClauseProcessor cp(converter, semaCtx, clauses);
15761498
cp.processCollapse(loc, eval, collapseClauseOps, iv);
1577-
targetOp.getTripCountMutable().assign(
1578-
calculateTripCount(converter, loc, collapseClauseOps));
1499+
targetOp.getTripCountMutable().assign(calculateTripCount(
1500+
converter.getFirOpBuilder(), loc, collapseClauseOps));
15791501
}
15801502
return loopNestOp;
15811503
}

flang/lib/Lower/OpenMP/ReductionProcessor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#ifndef FORTRAN_LOWER_REDUCTIONPROCESSOR_H
1414
#define FORTRAN_LOWER_REDUCTIONPROCESSOR_H
1515

16-
#include "Clauses.h"
16+
#include "flang/Lower/OpenMP/Clauses.h"
1717
#include "flang/Optimizer/Builder/FIRBuilder.h"
1818
#include "flang/Optimizer/Dialect/FIRType.h"
1919
#include "flang/Semantics/symbol.h"

0 commit comments

Comments
 (0)