Skip to content

[MLIR][OpenMP] Normalize handling of entry block arguments #109808

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,17 +472,26 @@ markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter,
/// \param [in] infoAccessor - for a private variable, this returns the
/// data we want to merge: type or location.
/// \param [out] allRegionArgsInfo - the merged list of region info.
/// \param [in] addBeforePrivate - `true` if the passed information goes before
/// private information.
template <typename OMPOp, typename InfoTy>
static void
mergePrivateVarsInfo(OMPOp op, llvm::ArrayRef<InfoTy> currentList,
llvm::function_ref<InfoTy(mlir::Value)> infoAccessor,
llvm::SmallVectorImpl<InfoTy> &allRegionArgsInfo) {
llvm::SmallVectorImpl<InfoTy> &allRegionArgsInfo,
bool addBeforePrivate) {
mlir::OperandRange privateVars = op.getPrivateVars();

llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
[](InfoTy i) { return i; });
if (addBeforePrivate)
llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
[](InfoTy i) { return i; });

llvm::transform(privateVars, std::back_inserter(allRegionArgsInfo),
infoAccessor);

if (!addBeforePrivate)
llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
[](InfoTy i) { return i; });
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -868,12 +877,12 @@ static void genBodyOfTargetOp(
mergePrivateVarsInfo(targetOp, mapSymTypes,
llvm::function_ref<mlir::Type(mlir::Value)>{
[](mlir::Value v) { return v.getType(); }},
allRegionArgTypes);
allRegionArgTypes, /*addBeforePrivate=*/true);

mergePrivateVarsInfo(targetOp, mapSymLocs,
llvm::function_ref<mlir::Location(mlir::Value)>{
[](mlir::Value v) { return v.getLoc(); }},
allRegionArgLocs);
allRegionArgLocs, /*addBeforePrivate=*/true);

mlir::Block *regionBlock = firOpBuilder.createBlock(
&region, {}, allRegionArgTypes, allRegionArgLocs);
Expand Down Expand Up @@ -1478,21 +1487,21 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
mergePrivateVarsInfo(parallelOp, reductionTypes,
llvm::function_ref<mlir::Type(mlir::Value)>{
[](mlir::Value v) { return v.getType(); }},
allRegionArgTypes);
allRegionArgTypes, /*addBeforePrivate=*/false);

llvm::SmallVector<mlir::Location> allRegionArgLocs;
mergePrivateVarsInfo(parallelOp, llvm::ArrayRef(reductionLocs),
llvm::function_ref<mlir::Location(mlir::Value)>{
[](mlir::Value v) { return v.getLoc(); }},
allRegionArgLocs);
allRegionArgLocs, /*addBeforePrivate=*/false);

mlir::Region &region = parallelOp.getRegion();
firOpBuilder.createBlock(&region, /*insertPt=*/{}, allRegionArgTypes,
allRegionArgLocs);

llvm::SmallVector<const semantics::Symbol *> allSymbols(reductionSyms);
allSymbols.append(dsp->getDelayedPrivSymbols().begin(),
dsp->getDelayedPrivSymbols().end());
llvm::SmallVector<const semantics::Symbol *> allSymbols(
dsp->getDelayedPrivSymbols());
allSymbols.append(reductionSyms.begin(), reductionSyms.end());

unsigned argIdx = 0;
for (const semantics::Symbol *arg : allSymbols) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ subroutine red_and_delayed_private

! CHECK-LABEL: _QPred_and_delayed_private
! CHECK: omp.parallel
! CHECK-SAME: reduction(byref @[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
! CHECK-SAME: reduction(byref @[[REDUCTION_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenMP/delayed-privatization-reduction.f90
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ subroutine red_and_delayed_private

! CHECK-LABEL: _QPred_and_delayed_private
! CHECK: omp.parallel
! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
39 changes: 28 additions & 11 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ class OpenMP_InReductionClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let traits = [
ReductionClauseInterface
BlockArgOpenMPOpInterface, ReductionClauseInterface
];

let arguments = (ins
Expand All @@ -472,6 +472,8 @@ class OpenMP_InReductionClauseSkip<
return SmallVector<Value>(getInReductionVars().begin(),
getInReductionVars().end());
}

unsigned numInReductionBlockArgs() { return getInReductionVars().size(); }
}];

// Description varies depending on the operation.
Expand Down Expand Up @@ -575,6 +577,8 @@ class OpenMP_MapClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let traits = [
// Not adding the BlockArgOpenMPOpInterface here because omp.target is the
// only operation defining block arguments for `map` clauses.
MapClauseOwningOpInterface
];

Expand Down Expand Up @@ -923,6 +927,10 @@ class OpenMP_PrivateClauseSkip<
bit description = false, bit extraClassDeclaration = false
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let traits = [
BlockArgOpenMPOpInterface
];

let arguments = (ins
Variadic<AnyType>:$private_vars,
OptionalAttr<SymbolRefArrayAttr>:$private_syms
Expand All @@ -933,6 +941,10 @@ class OpenMP_PrivateClauseSkip<
custom<PrivateList>($private_vars, type($private_vars), $private_syms) `)`
}];

let extraClassDeclaration = [{
unsigned numPrivateBlockArgs() { return getPrivateVars().size(); }
}];

// TODO: Add description.
}

Expand Down Expand Up @@ -973,7 +985,7 @@ class OpenMP_ReductionClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let traits = [
ReductionClauseInterface
BlockArgOpenMPOpInterface, ReductionClauseInterface
];

let arguments = (ins
Expand All @@ -991,6 +1003,7 @@ class OpenMP_ReductionClauseSkip<
let extraClassDeclaration = [{
/// Returns the number of reduction variables.
unsigned getNumReductionVars() { return getReductionVars().size(); }
unsigned numReductionBlockArgs() { return getReductionVars().size(); }
}];

// Description varies depending on the operation.
Expand Down Expand Up @@ -1104,7 +1117,7 @@ class OpenMP_TaskReductionClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let traits = [
ReductionClauseInterface
BlockArgOpenMPOpInterface, ReductionClauseInterface
];

let arguments = (ins
Expand All @@ -1119,6 +1132,18 @@ class OpenMP_TaskReductionClauseSkip<
$task_reduction_byref, $task_reduction_syms) `)`
}];

let extraClassDeclaration = [{
/// Returns the reduction variables.
SmallVector<Value> getReductionVars() {
return SmallVector<Value>(getTaskReductionVars().begin(),
getTaskReductionVars().end());
}

unsigned numTaskReductionBlockArgs() {
return getTaskReductionVars().size();
}
}];

let description = [{
The `task_reduction` clause specifies a reduction among tasks. For each list
item, the number of copies is unspecified. Any copies associated with the
Expand All @@ -1130,14 +1155,6 @@ class OpenMP_TaskReductionClauseSkip<
attribute, and whether the reduction variable should be passed into the
reduction region by value or by reference in `task_reduction_byref`.
}];

let extraClassDeclaration = [{
/// Returns the reduction variables.
SmallVector<Value> getReductionVars() {
return SmallVector<Value>(getTaskReductionVars().begin(),
getTaskReductionVars().end());
}
}];
}

def OpenMP_TaskReductionClause : OpenMP_TaskReductionClauseSkip<>;
Expand Down
7 changes: 6 additions & 1 deletion mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1043,7 +1043,8 @@ def TargetUpdateOp: OpenMP_Op<"target_update", traits = [
//===----------------------------------------------------------------------===//

def TargetOp : OpenMP_Op<"target", traits = [
AttrSizedOperandSegments, IsolatedFromAbove, OutlineableOpenMPOpInterface
AttrSizedOperandSegments, BlockArgOpenMPOpInterface, IsolatedFromAbove,
OutlineableOpenMPOpInterface
], clauses = [
// TODO: Complete clause list (defaultmap, uses_allocators).
OpenMP_AllocateClause, OpenMP_DependClause, OpenMP_DeviceClause,
Expand All @@ -1065,6 +1066,10 @@ def TargetOp : OpenMP_Op<"target", traits = [
OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)>
];

let extraClassDeclaration = [{
unsigned numMapBlockArgs() { return getMapVars().size(); }
}] # clausesExtraClassDeclaration;

let hasVerifier = 1;
}

Expand Down
108 changes: 108 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,114 @@

include "mlir/IR/OpBase.td"

def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
let description = [{
OpenMP operations that define entry block arguments as part of the
representation of its clauses.
}];

let cppNamespace = "::mlir::omp";

let methods = [
// Default-implemented methods to be overriden by the corresponding clauses.
InterfaceMethod<"Get number of block arguments defined by `in_reduction`.",
"unsigned", "numInReductionBlockArgs", (ins), [{}], [{
return 0;
}]>,
InterfaceMethod<"Get number of block arguments defined by `map`.",
"unsigned", "numMapBlockArgs", (ins), [{}], [{
return 0;
}]>,
InterfaceMethod<"Get number of block arguments defined by `private`.",
"unsigned", "numPrivateBlockArgs", (ins), [{}], [{
return 0;
}]>,
InterfaceMethod<"Get number of block arguments defined by `reduction`.",
"unsigned", "numReductionBlockArgs", (ins), [{}], [{
return 0;
}]>,
InterfaceMethod<"Get number of block arguments defined by `task_reduction`.",
"unsigned", "numTaskReductionBlockArgs", (ins), [{}], [{
return 0;
}]>,

// Unified access methods for clause-associated entry block arguments.
InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.",
"unsigned", "getInReductionBlockArgsStart", (ins), [{
return 0;
}]>,
InterfaceMethod<"Get start index of block arguments defined by `map`.",
"unsigned", "getMapBlockArgsStart", (ins), [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return iface.getInReductionBlockArgsStart() +
$_op.numInReductionBlockArgs();
}]>,
InterfaceMethod<"Get start index of block arguments defined by `private`.",
"unsigned", "getPrivateBlockArgsStart", (ins), [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return iface.getMapBlockArgsStart() + $_op.numMapBlockArgs();
}]>,
InterfaceMethod<"Get start index of block arguments defined by `reduction`.",
"unsigned", "getReductionBlockArgsStart", (ins), [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return iface.getPrivateBlockArgsStart() + $_op.numPrivateBlockArgs();
}]>,
InterfaceMethod<"Get start index of block arguments defined by `task_reduction`.",
"unsigned", "getTaskReductionBlockArgsStart", (ins), [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return iface.getReductionBlockArgsStart() + $_op.numReductionBlockArgs();
}]>,

InterfaceMethod<"Get block arguments defined by `in_reduction`.",
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
"getInReductionBlockArgs", (ins), [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return $_op->getRegion(0).getArguments().slice(
iface.getInReductionBlockArgsStart(), $_op.numInReductionBlockArgs());
}]>,
InterfaceMethod<"Get block arguments defined by `map`.",
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
"getMapBlockArgs", (ins), [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return $_op->getRegion(0).getArguments().slice(
iface.getMapBlockArgsStart(), $_op.numMapBlockArgs());
}]>,
InterfaceMethod<"Get block arguments defined by `private`.",
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
"getPrivateBlockArgs", (ins), [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return $_op->getRegion(0).getArguments().slice(
iface.getPrivateBlockArgsStart(), $_op.numPrivateBlockArgs());
}]>,
InterfaceMethod<"Get block arguments defined by `reduction`.",
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
"getReductionBlockArgs", (ins), [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return $_op->getRegion(0).getArguments().slice(
iface.getReductionBlockArgsStart(), $_op.numReductionBlockArgs());
}]>,
InterfaceMethod<"Get block arguments defined by `task_reduction`.",
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
"getTaskReductionBlockArgs", (ins), [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return $_op->getRegion(0).getArguments().slice(
iface.getTaskReductionBlockArgsStart(),
$_op.numTaskReductionBlockArgs());
}]>,
];

let verify = [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
unsigned expectedArgs = iface.numInReductionBlockArgs() +
iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs();
if ($_op->getRegion(0).getNumArguments() < expectedArgs)
return $_op->emitOpError() << "expected at least " << expectedArgs
<< " entry block argument(s)";
return ::mlir::success();
}];
}

def OutlineableOpenMPOpInterface : OpInterface<"OutlineableOpenMPOpInterface"> {
let description = [{
OpenMP operations whose region will be outlined will implement this
Expand Down
34 changes: 17 additions & 17 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -536,13 +536,6 @@ static ParseResult parseParallelRegion(
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;

if (succeeded(parser.parseOptionalKeyword("reduction"))) {
if (failed(parseClauseWithRegionArgs(parser, region, reductionVars,
reductionTypes, reductionByref,
reductionSyms, regionPrivateArgs)))
return failure();
}

if (succeeded(parser.parseOptionalKeyword("private"))) {
auto privateByref = DenseBoolArrayAttr::get(parser.getContext(), {});
if (failed(parseClauseWithRegionArgs(parser, region, privateVars,
Expand All @@ -557,6 +550,13 @@ static ParseResult parseParallelRegion(
}
}

if (succeeded(parser.parseOptionalKeyword("reduction"))) {
if (failed(parseClauseWithRegionArgs(parser, region, reductionVars,
reductionTypes, reductionByref,
reductionSyms, regionPrivateArgs)))
return failure();
}

return parser.parseRegion(region, regionPrivateArgs);
}

Expand All @@ -566,18 +566,9 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
DenseBoolArrayAttr reductionByref,
ArrayAttr reductionSyms, ValueRange privateVars,
TypeRange privateTypes, ArrayAttr privateSyms) {
if (reductionSyms) {
auto *argsBegin = region.front().getArguments().begin();
MutableArrayRef argsSubrange(argsBegin, argsBegin + reductionTypes.size());
printClauseWithRegionArgs(p, op, argsSubrange, "reduction", reductionVars,
reductionTypes, reductionByref, reductionSyms);
}

if (privateSyms) {
auto *argsBegin = region.front().getArguments().begin();
MutableArrayRef argsSubrange(argsBegin + reductionVars.size(),
argsBegin + reductionVars.size() +
privateTypes.size());
MutableArrayRef argsSubrange(argsBegin, argsBegin + privateTypes.size());
mlir::SmallVector<bool> isByRefVec;
isByRefVec.resize(privateTypes.size(), false);
DenseBoolArrayAttr isByRef =
Expand All @@ -587,6 +578,15 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
privateTypes, isByRef, privateSyms);
}

if (reductionSyms) {
auto *argsBegin = region.front().getArguments().begin();
MutableArrayRef argsSubrange(argsBegin + privateVars.size(),
argsBegin + privateVars.size() +
reductionTypes.size());
printClauseWithRegionArgs(p, op, argsSubrange, "reduction", reductionVars,
reductionTypes, reductionByref, reductionSyms);
}

p.printRegion(region, /*printEntryBlockArgs=*/false);
}

Expand Down
Loading
Loading