Skip to content

Commit db09014

Browse files
ergawyagozillon
andauthored
[flang][OpenMP] Implicitly map allocatable record fields (#117867)
This is a starting PR to implicitly map allocatable record fields. This PR contains the following changes: 1. Re-purposes some of the utils used in `Lower/OpenMP.cpp` so that these utils work on the `mlir::Value` level rather than the `semantics::Symbol` level. This takes one step towards to enabling MLIR passes to more easily do some lowering themselves (e.g. creating `omp.map.bounds` ops for implicitely caputured data like this PR does). 2. Adds support for implicitely capturing and mapping allocatable fields in record types. There is quite some distant to still cover to have full support for this. I added a number of todos to guide further development. Co-authored-by: Andrew Gozillon <[email protected]> Co-authored-by: Andrew Gozillon <[email protected]>
1 parent 1a70420 commit db09014

File tree

12 files changed

+412
-36
lines changed

12 files changed

+412
-36
lines changed

flang/lib/Lower/DirectivesCommon.h renamed to flang/include/flang/Lower/DirectivesCommon.h

+33-17
Original file line numberDiff line numberDiff line change
@@ -609,32 +609,22 @@ void createEmptyRegionBlocks(
609609
}
610610
}
611611

612-
inline AddrAndBoundsInfo
613-
getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
614-
fir::FirOpBuilder &builder,
615-
Fortran::lower::SymbolRef sym, mlir::Location loc) {
616-
mlir::Value symAddr = converter.getSymbolAddress(sym);
612+
inline AddrAndBoundsInfo getDataOperandBaseAddr(fir::FirOpBuilder &builder,
613+
mlir::Value symAddr,
614+
bool isOptional,
615+
mlir::Location loc) {
617616
mlir::Value rawInput = symAddr;
618617
if (auto declareOp =
619618
mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
620619
symAddr = declareOp.getResults()[0];
621620
rawInput = declareOp.getResults()[1];
622621
}
623622

624-
// TODO: Might need revisiting to handle for non-shared clauses
625-
if (!symAddr) {
626-
if (const auto *details =
627-
sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
628-
symAddr = converter.getSymbolAddress(details->symbol());
629-
rawInput = symAddr;
630-
}
631-
}
632-
633623
if (!symAddr)
634624
llvm::report_fatal_error("could not retrieve symbol address");
635625

636626
mlir::Value isPresent;
637-
if (Fortran::semantics::IsOptional(sym))
627+
if (isOptional)
638628
isPresent =
639629
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
640630

@@ -648,8 +638,7 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
648638
// all address/dimension retrievals. For Fortran optional though, leave
649639
// the load generation for later so it can be done in the appropriate
650640
// if branches.
651-
if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
652-
!Fortran::semantics::IsOptional(sym)) {
641+
if (mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
653642
mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
654643
return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
655644
}
@@ -659,6 +648,14 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
659648
return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
660649
}
661650

651+
inline AddrAndBoundsInfo
652+
getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
653+
fir::FirOpBuilder &builder,
654+
Fortran::lower::SymbolRef sym, mlir::Location loc) {
655+
return getDataOperandBaseAddr(builder, converter.getSymbolAddress(sym),
656+
Fortran::semantics::IsOptional(sym), loc);
657+
}
658+
662659
template <typename BoundsOp, typename BoundsType>
663660
llvm::SmallVector<mlir::Value>
664661
gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
@@ -1224,6 +1221,25 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
12241221

12251222
return info;
12261223
}
1224+
1225+
template <typename BoundsOp, typename BoundsType>
1226+
llvm::SmallVector<mlir::Value>
1227+
genImplicitBoundsOps(fir::FirOpBuilder &builder, lower::AddrAndBoundsInfo &info,
1228+
fir::ExtendedValue dataExv, bool dataExvIsAssumedSize,
1229+
mlir::Location loc) {
1230+
llvm::SmallVector<mlir::Value> bounds;
1231+
1232+
mlir::Value baseOp = info.rawInput;
1233+
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
1234+
bounds = lower::genBoundsOpsFromBox<BoundsOp, BoundsType>(builder, loc,
1235+
dataExv, info);
1236+
if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
1237+
bounds = lower::genBaseBoundsOps<BoundsOp, BoundsType>(
1238+
builder, loc, dataExv, dataExvIsAssumedSize);
1239+
}
1240+
1241+
return bounds;
1242+
}
12271243
} // namespace lower
12281244
} // namespace Fortran
12291245

flang/lib/Lower/Bridge.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "flang/Lower/Bridge.h"
14-
#include "DirectivesCommon.h"
14+
1515
#include "flang/Common/Version.h"
1616
#include "flang/Lower/Allocatable.h"
1717
#include "flang/Lower/CallInterface.h"
@@ -22,6 +22,7 @@
2222
#include "flang/Lower/ConvertType.h"
2323
#include "flang/Lower/ConvertVariable.h"
2424
#include "flang/Lower/Cuda.h"
25+
#include "flang/Lower/DirectivesCommon.h"
2526
#include "flang/Lower/HostAssociations.h"
2627
#include "flang/Lower/IO.h"
2728
#include "flang/Lower/IterationSpace.h"

flang/lib/Lower/OpenACC.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "flang/Lower/OpenACC.h"
14-
#include "DirectivesCommon.h"
14+
1515
#include "flang/Common/idioms.h"
1616
#include "flang/Lower/Bridge.h"
1717
#include "flang/Lower/ConvertType.h"
18+
#include "flang/Lower/DirectivesCommon.h"
1819
#include "flang/Lower/Mangler.h"
1920
#include "flang/Lower/PFTBuilder.h"
2021
#include "flang/Lower/StatementContext.h"

flang/lib/Lower/OpenMP/ClauseProcessor.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
#define FORTRAN_LOWER_CLAUSEPROCESSOR_H
1414

1515
#include "Clauses.h"
16-
#include "DirectivesCommon.h"
1716
#include "ReductionProcessor.h"
1817
#include "Utils.h"
1918
#include "flang/Lower/AbstractConverter.h"
2019
#include "flang/Lower/Bridge.h"
20+
#include "flang/Lower/DirectivesCommon.h"
2121
#include "flang/Optimizer/Builder/Todo.h"
2222
#include "flang/Parser/dump-parse-tree.h"
2323
#include "flang/Parser/parse-tree.h"

flang/lib/Lower/OpenMP/OpenMP.cpp

+8-15
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616
#include "Clauses.h"
1717
#include "DataSharingProcessor.h"
1818
#include "Decomposer.h"
19-
#include "DirectivesCommon.h"
2019
#include "ReductionProcessor.h"
2120
#include "Utils.h"
2221
#include "flang/Common/OpenMP-utils.h"
2322
#include "flang/Common/idioms.h"
2423
#include "flang/Lower/Bridge.h"
2524
#include "flang/Lower/ConvertExpr.h"
2625
#include "flang/Lower/ConvertVariable.h"
26+
#include "flang/Lower/DirectivesCommon.h"
2727
#include "flang/Lower/StatementContext.h"
2828
#include "flang/Lower/SymbolMap.h"
2929
#include "flang/Optimizer/Builder/BoxValue.h"
@@ -1735,32 +1735,25 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
17351735
if (const auto *details =
17361736
sym.template detailsIf<semantics::HostAssocDetails>())
17371737
converter.copySymbolBinding(details->symbol(), sym);
1738-
llvm::SmallVector<mlir::Value> bounds;
17391738
std::stringstream name;
17401739
fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
17411740
name << sym.name().ToString();
17421741

17431742
lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
17441743
converter, firOpBuilder, sym, converter.getCurrentLocation());
1745-
mlir::Value baseOp = info.rawInput;
1746-
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
1747-
bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
1748-
mlir::omp::MapBoundsType>(
1749-
firOpBuilder, converter.getCurrentLocation(), dataExv, info);
1750-
if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
1751-
bool dataExvIsAssumedSize =
1752-
semantics::IsAssumedSizeArray(sym.GetUltimate());
1753-
bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
1754-
mlir::omp::MapBoundsType>(
1755-
firOpBuilder, converter.getCurrentLocation(), dataExv,
1756-
dataExvIsAssumedSize);
1757-
}
1744+
llvm::SmallVector<mlir::Value> bounds =
1745+
lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
1746+
mlir::omp::MapBoundsType>(
1747+
firOpBuilder, info, dataExv,
1748+
semantics::IsAssumedSizeArray(sym.GetUltimate()),
1749+
converter.getCurrentLocation());
17581750

17591751
llvm::omp::OpenMPOffloadMappingFlags mapFlag =
17601752
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
17611753
mlir::omp::VariableCaptureKind captureKind =
17621754
mlir::omp::VariableCaptureKind::ByRef;
17631755

1756+
mlir::Value baseOp = info.rawInput;
17641757
mlir::Type eleType = baseOp.getType();
17651758
if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
17661759
eleType = refType.getElementType();

flang/lib/Lower/OpenMP/Utils.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
#include "Utils.h"
1414

1515
#include "Clauses.h"
16-
#include <DirectivesCommon.h>
1716

1817
#include <flang/Lower/AbstractConverter.h>
1918
#include <flang/Lower/ConvertType.h>
19+
#include <flang/Lower/DirectivesCommon.h>
2020
#include <flang/Lower/PFTBuilder.h>
2121
#include <flang/Optimizer/Builder/FIRBuilder.h>
2222
#include <flang/Optimizer/Builder/Todo.h>

flang/lib/Optimizer/OpenMP/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ add_flang_library(FlangOpenMPTransforms
1212
FIRDialect
1313
HLFIROpsIncGen
1414
FlangOpenMPPassesIncGen
15+
${dialect_libs}
1516

1617
LINK_LIBS
1718
FIRAnalysis
@@ -27,4 +28,5 @@ add_flang_library(FlangOpenMPTransforms
2728
MLIRIR
2829
MLIRPass
2930
MLIRTransformUtils
31+
${dialect_libs}
3032
)

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

+158
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,14 @@
2424
/// indirectly via a parent object.
2525
//===----------------------------------------------------------------------===//
2626

27+
#include "flang/Lower/DirectivesCommon.h"
2728
#include "flang/Optimizer/Builder/FIRBuilder.h"
29+
#include "flang/Optimizer/Builder/HLFIRTools.h"
2830
#include "flang/Optimizer/Dialect/FIRType.h"
2931
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
32+
#include "flang/Optimizer/HLFIR/HLFIROps.h"
3033
#include "flang/Optimizer/OpenMP/Passes.h"
34+
#include "mlir/Analysis/SliceAnalysis.h"
3135
#include "mlir/Dialect/Func/IR/FuncOps.h"
3236
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
3337
#include "mlir/IR/BuiltinDialect.h"
@@ -486,6 +490,160 @@ class MapInfoFinalizationPass
486490
// iterations from previous function scopes.
487491
localBoxAllocas.clear();
488492

493+
// First, walk `omp.map.info` ops to see if any record members should be
494+
// implicitly mapped.
495+
func->walk([&](mlir::omp::MapInfoOp op) {
496+
mlir::Type underlyingType =
497+
fir::unwrapRefType(op.getVarPtr().getType());
498+
499+
// TODO Test with and support more complicated cases; like arrays for
500+
// records, for example.
501+
if (!fir::isRecordWithAllocatableMember(underlyingType))
502+
return mlir::WalkResult::advance();
503+
504+
// TODO For now, only consider `omp.target` ops. Other ops that support
505+
// `map` clauses will follow later.
506+
mlir::omp::TargetOp target =
507+
mlir::dyn_cast_if_present<mlir::omp::TargetOp>(
508+
getFirstTargetUser(op));
509+
510+
if (!target)
511+
return mlir::WalkResult::advance();
512+
513+
auto mapClauseOwner =
514+
llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(*target);
515+
516+
int64_t mapVarIdx = mapClauseOwner.getOperandIndexForMap(op);
517+
assert(mapVarIdx >= 0 &&
518+
mapVarIdx <
519+
static_cast<int64_t>(mapClauseOwner.getMapVars().size()));
520+
521+
auto argIface =
522+
llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(*target);
523+
// TODO How should `map` block argument that correspond to: `private`,
524+
// `use_device_addr`, `use_device_ptr`, be handled?
525+
mlir::BlockArgument opBlockArg = argIface.getMapBlockArgs()[mapVarIdx];
526+
llvm::SetVector<mlir::Operation *> mapVarForwardSlice;
527+
mlir::getForwardSlice(opBlockArg, &mapVarForwardSlice);
528+
529+
mapVarForwardSlice.remove_if([&](mlir::Operation *sliceOp) {
530+
// TODO Support coordinate_of ops.
531+
//
532+
// TODO Support call ops by recursively examining the forward slice of
533+
// the corresponding parameter to the field in the called function.
534+
return !mlir::isa<hlfir::DesignateOp>(sliceOp);
535+
});
536+
537+
auto recordType = mlir::cast<fir::RecordType>(underlyingType);
538+
llvm::SmallVector<mlir::Value> newMapOpsForFields;
539+
llvm::SmallVector<int64_t> fieldIndicies;
540+
541+
for (auto fieldMemTyPair : recordType.getTypeList()) {
542+
auto &field = fieldMemTyPair.first;
543+
auto memTy = fieldMemTyPair.second;
544+
545+
bool shouldMapField =
546+
llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
547+
if (!fir::isAllocatableType(memTy))
548+
return false;
549+
550+
auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
551+
if (!designateOp)
552+
return false;
553+
554+
return designateOp.getComponent() &&
555+
designateOp.getComponent()->strref() == field;
556+
}) != mapVarForwardSlice.end();
557+
558+
// TODO Handle recursive record types. Adapting
559+
// `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
560+
// entities might be helpful here.
561+
562+
if (!shouldMapField)
563+
continue;
564+
565+
int64_t fieldIdx = recordType.getFieldIndex(field);
566+
bool alreadyMapped = [&]() {
567+
if (op.getMembersIndexAttr())
568+
for (auto indexList : op.getMembersIndexAttr()) {
569+
auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
570+
if (indexListAttr.size() == 1 &&
571+
mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
572+
fieldIdx)
573+
return true;
574+
}
575+
576+
return false;
577+
}();
578+
579+
if (alreadyMapped)
580+
continue;
581+
582+
builder.setInsertionPoint(op);
583+
mlir::Value fieldIdxVal = builder.createIntegerConstant(
584+
op.getLoc(), mlir::IndexType::get(builder.getContext()),
585+
fieldIdx);
586+
auto fieldCoord = builder.create<fir::CoordinateOp>(
587+
op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
588+
fieldIdxVal);
589+
Fortran::lower::AddrAndBoundsInfo info =
590+
Fortran::lower::getDataOperandBaseAddr(
591+
builder, fieldCoord, /*isOptional=*/false, op.getLoc());
592+
llvm::SmallVector<mlir::Value> bounds =
593+
Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
594+
mlir::omp::MapBoundsType>(
595+
builder, info,
596+
hlfir::translateToExtendedValue(op.getLoc(), builder,
597+
hlfir::Entity{fieldCoord})
598+
.first,
599+
/*dataExvIsAssumedSize=*/false, op.getLoc());
600+
601+
mlir::omp::MapInfoOp fieldMapOp =
602+
builder.create<mlir::omp::MapInfoOp>(
603+
op.getLoc(), fieldCoord.getResult().getType(),
604+
fieldCoord.getResult(),
605+
mlir::TypeAttr::get(
606+
fir::unwrapRefType(fieldCoord.getResult().getType())),
607+
/*varPtrPtr=*/mlir::Value{},
608+
/*members=*/mlir::ValueRange{},
609+
/*members_index=*/mlir::ArrayAttr{},
610+
/*bounds=*/bounds, op.getMapTypeAttr(),
611+
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
612+
mlir::omp::VariableCaptureKind::ByRef),
613+
builder.getStringAttr(op.getNameAttr().strref() + "." +
614+
field + ".implicit_map"),
615+
/*partial_map=*/builder.getBoolAttr(false));
616+
newMapOpsForFields.emplace_back(fieldMapOp);
617+
fieldIndicies.emplace_back(fieldIdx);
618+
}
619+
620+
if (newMapOpsForFields.empty())
621+
return mlir::WalkResult::advance();
622+
623+
op.getMembersMutable().append(newMapOpsForFields);
624+
llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
625+
mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
626+
627+
if (oldMembersIdxAttr)
628+
for (mlir::Attribute indexList : oldMembersIdxAttr) {
629+
llvm::SmallVector<int64_t> listVec;
630+
631+
for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
632+
listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt());
633+
634+
newMemberIndices.emplace_back(std::move(listVec));
635+
}
636+
637+
for (int64_t newFieldIdx : fieldIndicies)
638+
newMemberIndices.emplace_back(
639+
llvm::SmallVector<int64_t>(1, newFieldIdx));
640+
641+
op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
642+
op.setPartialMap(true);
643+
644+
return mlir::WalkResult::advance();
645+
});
646+
489647
func->walk([&](mlir::omp::MapInfoOp op) {
490648
// TODO: Currently only supports a single user for the MapInfoOp. This
491649
// is fine for the moment, as the Fortran frontend will generate a

0 commit comments

Comments
 (0)