-
Notifications
You must be signed in to change notification settings - Fork 15k
[flang][OpenMP] General utility to get directive id from AST node #150121
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Fortran::parser::omp::GetOmpDirectiveName(t) will get the OmpDirectiveName object that corresponds to construct t. That object (an AST node) contains the enum id and the source information of the directive. Replace uses of extractOmpDirective and getOpenMPDirectiveEnum with the new function.
@llvm/pr-subscribers-flang-fir-hlfir Author: Krzysztof Parzyszek (kparzysz) ChangesFortran::parser::omp::GetOmpDirectiveName(t) will get the OmpDirectiveName object that corresponds to construct t. That object (an AST node) contains the enum id and the source information of the directive. Replace uses of extractOmpDirective and getOpenMPDirectiveEnum with the new function. Full diff: https://github.com/llvm/llvm-project/pull/150121.diff 5 Files Affected:
diff --git a/flang/include/flang/Parser/openmp-utils.h b/flang/include/flang/Parser/openmp-utils.h
new file mode 100644
index 0000000000000..363db852416f6
--- /dev/null
+++ b/flang/include/flang/Parser/openmp-utils.h
@@ -0,0 +1,175 @@
+//===-- flang/Parser/openmp-utils.h ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Common OpenMP utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_PARSER_OPENMP_UTILS_H
+#define FORTRAN_PARSER_OPENMP_UTILS_H
+
+#include "flang/Common/indirection.h"
+#include "flang/Parser/parse-tree.h"
+#include "llvm/Frontend/OpenMP/OMP.h"
+
+#include <cassert>
+#include <tuple>
+#include <type_traits>
+#include <utility>
+#include <variant>
+
+namespace Fortran::parser::omp {
+
+namespace detail {
+using D = llvm::omp::Directive;
+
+template <typename Construct> //
+struct ConstructId {
+ static constexpr llvm::omp::Directive id{D::OMPD_unknown};
+};
+
+#define MAKE_CONSTR_ID(Construct, Id) \
+ template <> struct ConstructId<Construct> { \
+ static constexpr llvm::omp::Directive id{Id}; \
+ }
+
+MAKE_CONSTR_ID(OmpAssumeDirective, D::OMPD_assume);
+MAKE_CONSTR_ID(OmpCriticalDirective, D::OMPD_critical);
+MAKE_CONSTR_ID(OmpDeclareVariantDirective, D::OMPD_declare_variant);
+MAKE_CONSTR_ID(OmpErrorDirective, D::OMPD_error);
+MAKE_CONSTR_ID(OmpMetadirectiveDirective, D::OMPD_metadirective);
+MAKE_CONSTR_ID(OpenMPDeclarativeAllocate, D::OMPD_allocate);
+MAKE_CONSTR_ID(OpenMPDeclarativeAssumes, D::OMPD_assumes);
+MAKE_CONSTR_ID(OpenMPDeclareMapperConstruct, D::OMPD_declare_mapper);
+MAKE_CONSTR_ID(OpenMPDeclareReductionConstruct, D::OMPD_declare_reduction);
+MAKE_CONSTR_ID(OpenMPDeclareSimdConstruct, D::OMPD_declare_simd);
+MAKE_CONSTR_ID(OpenMPDeclareTargetConstruct, D::OMPD_declare_target);
+MAKE_CONSTR_ID(OpenMPExecutableAllocate, D::OMPD_allocate);
+MAKE_CONSTR_ID(OpenMPRequiresConstruct, D::OMPD_requires);
+MAKE_CONSTR_ID(OpenMPThreadprivate, D::OMPD_threadprivate);
+
+#undef MAKE_CONSTR_ID
+
+struct DirectiveNameScope {
+ // Helper types to make overloaded function signatures different.
+ struct TagA {};
+ struct TagB {};
+ struct TagC {};
+ struct TagD {};
+
+ static OmpDirectiveName MakeName(CharBlock source = {},
+ llvm::omp::Directive id = llvm::omp::Directive::OMPD_unknown) {
+ OmpDirectiveName name;
+ name.source = source;
+ name.v = id;
+ return name;
+ }
+
+ static OmpDirectiveName GetOmpDirectiveName(const OmpNothingDirective &x) {
+ return MakeName(x.source, llvm::omp::Directive::OMPD_nothing);
+ }
+
+ static OmpDirectiveName GetOmpDirectiveName(const OmpBeginBlockDirective &x) {
+ auto &dir{std::get<OmpBlockDirective>(x.t)};
+ return MakeName(dir.source, dir.v);
+ }
+
+ static OmpDirectiveName GetOmpDirectiveName(const OmpBeginLoopDirective &x) {
+ auto &dir{std::get<OmpLoopDirective>(x.t)};
+ return MakeName(dir.source, dir.v);
+ }
+
+ static OmpDirectiveName GetOmpDirectiveName(
+ const OmpBeginSectionsDirective &x) {
+ auto &dir{std::get<OmpSectionsDirective>(x.t)};
+ return MakeName(dir.source, dir.v);
+ }
+
+ template <typename T, typename = std::enable_if_t<WrapperTrait<T>>>
+ static OmpDirectiveName GetOmpDirectiveName(const T &x, TagA = {}) {
+ if constexpr (std::is_same_v<T, OpenMPCancelConstruct> ||
+ std::is_same_v<T, OpenMPCancellationPointConstruct> ||
+ std::is_same_v<T, OpenMPDepobjConstruct> ||
+ std::is_same_v<T, OpenMPFlushConstruct> ||
+ std::is_same_v<T, OpenMPInteropConstruct> ||
+ std::is_same_v<T, OpenMPSimpleStandaloneConstruct>) {
+ return x.v.DirName();
+ } else {
+ return GetOmpDirectiveName(x.v);
+ }
+ }
+
+ template <typename T, typename = std::enable_if_t<TupleTrait<T>>>
+ static OmpDirectiveName GetOmpDirectiveName(const T &x, TagB = {}) {
+ if constexpr (std::is_same_v<T, OpenMPAllocatorsConstruct> ||
+ std::is_same_v<T, OpenMPAtomicConstruct> ||
+ std::is_same_v<T, OpenMPDispatchConstruct>) {
+ return std::get<OmpDirectiveSpecification>(x.t).DirName();
+ } else if constexpr (std::is_same_v<T, OmpAssumeDirective> ||
+ std::is_same_v<T, OmpCriticalDirective> ||
+ std::is_same_v<T, OmpDeclareVariantDirective> ||
+ std::is_same_v<T, OmpErrorDirective> ||
+ std::is_same_v<T, OmpMetadirectiveDirective> ||
+ std::is_same_v<T, OpenMPDeclarativeAllocate> ||
+ std::is_same_v<T, OpenMPDeclarativeAssumes> ||
+ std::is_same_v<T, OpenMPDeclareMapperConstruct> ||
+ std::is_same_v<T, OpenMPDeclareReductionConstruct> ||
+ std::is_same_v<T, OpenMPDeclareSimdConstruct> ||
+ std::is_same_v<T, OpenMPDeclareTargetConstruct> ||
+ std::is_same_v<T, OpenMPExecutableAllocate> ||
+ std::is_same_v<T, OpenMPRequiresConstruct> ||
+ std::is_same_v<T, OpenMPThreadprivate>) {
+ return MakeName(std::get<Verbatim>(x.t).source, ConstructId<T>::id);
+ } else {
+ return GetFromTuple(
+ x.t, std::make_index_sequence<std::tuple_size_v<decltype(x.t)>>{});
+ }
+ }
+
+ template <typename T, typename = std::enable_if_t<UnionTrait<T>>>
+ static OmpDirectiveName GetOmpDirectiveName(const T &x, TagC = {}) {
+ return common::visit([](auto &&s) { return GetOmpDirectiveName(s); }, x.u);
+ }
+
+ template <typename... Ts, size_t... Is>
+ static OmpDirectiveName GetFromTuple(
+ const std::tuple<Ts...> &t, std::index_sequence<Is...>) {
+ OmpDirectiveName name = MakeName();
+ auto accumulate = [&](const OmpDirectiveName &n) {
+ if (name.v == llvm::omp::Directive::OMPD_unknown) {
+ name = n;
+ } else {
+ assert(
+ n.v == llvm::omp::Directive::OMPD_unknown && "Conflicting names");
+ }
+ };
+ (accumulate(GetOmpDirectiveName(std::get<Is>(t))), ...);
+ return name;
+ }
+
+ template <typename T>
+ static OmpDirectiveName GetOmpDirectiveName(const common::Indirection<T> &x) {
+ return GetOmpDirectiveName(x.value());
+ }
+
+ template <typename T,
+ typename = std::enable_if_t<!WrapperTrait<T> && !TupleTrait<T> &&
+ !UnionTrait<T>>>
+ static OmpDirectiveName GetOmpDirectiveName(const T &x, TagD = {}) {
+ return MakeName();
+ }
+};
+} // namespace detail
+
+template <typename T> OmpDirectiveName GetOmpDirectiveName(const T &x) {
+ return detail::DirectiveNameScope::GetOmpDirectiveName(x);
+}
+
+} // namespace Fortran::parser::omp
+
+#endif // FORTRAN_PARSER_OPENMP_UTILS_H
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 11e488371b886..2ac4d9548b65b 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -24,6 +24,7 @@
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Parser/openmp-utils.h"
#include "flang/Semantics/attr.h"
#include "flang/Semantics/tools.h"
#include "llvm/ADT/Sequence.h"
@@ -465,7 +466,8 @@ bool DataSharingProcessor::isOpenMPPrivatizingConstruct(
// allow a privatizing clause) are: dispatch, distribute, do, for, loop,
// parallel, scope, sections, simd, single, target, target_data, task,
// taskgroup, taskloop, and teams.
- return llvm::is_contained(privatizing, extractOmpDirective(omp));
+ return llvm::is_contained(privatizing,
+ parser::omp::GetOmpDirectiveName(omp).v);
}
bool DataSharingProcessor::isOpenMPPrivatizingEvaluation(
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index fc5fef9b2c577..4c2d7badef382 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -31,6 +31,7 @@
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Parser/characters.h"
+#include "flang/Parser/openmp-utils.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/openmp-directive-sets.h"
#include "flang/Semantics/tools.h"
@@ -63,28 +64,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
lower::pft::Evaluation &eval,
mlir::Location loc);
-static llvm::omp::Directive
-getOpenMPDirectiveEnum(const parser::OmpLoopDirective &beginStatment) {
- return beginStatment.v;
-}
-
-static llvm::omp::Directive getOpenMPDirectiveEnum(
- const parser::OmpBeginLoopDirective &beginLoopDirective) {
- return getOpenMPDirectiveEnum(
- std::get<parser::OmpLoopDirective>(beginLoopDirective.t));
-}
-
-static llvm::omp::Directive
-getOpenMPDirectiveEnum(const parser::OpenMPLoopConstruct &ompLoopConstruct) {
- return getOpenMPDirectiveEnum(
- std::get<parser::OmpBeginLoopDirective>(ompLoopConstruct.t));
-}
-
-static llvm::omp::Directive getOpenMPDirectiveEnum(
- const common::Indirection<parser::OpenMPLoopConstruct> &ompLoopConstruct) {
- return getOpenMPDirectiveEnum(ompLoopConstruct.value());
-}
-
namespace {
/// Structure holding information that is needed to pass host-evaluated
/// information to later lowering stages.
@@ -468,7 +447,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
llvm::omp::Directive dir;
auto &nested = parent.getFirstNestedEvaluation();
if (const auto *ompEval = nested.getIf<parser::OpenMPConstruct>())
- dir = extractOmpDirective(*ompEval);
+ dir = parser::omp::GetOmpDirectiveName(*ompEval).v;
else
return std::nullopt;
@@ -508,7 +487,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
HostEvalInfo *hostInfo = getHostEvalInfoStackTop(converter);
assert(hostInfo && "expected HOST_EVAL info structure");
- switch (extractOmpDirective(*ompEval)) {
+ switch (parser::omp::GetOmpDirectiveName(*ompEval).v) {
case OMPD_teams_distribute_parallel_do:
case OMPD_teams_distribute_parallel_do_simd:
cp.processThreadLimit(stmtCtx, hostInfo->ops);
@@ -569,7 +548,8 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
assert(ompEval &&
- llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) &&
+ llvm::omp::allTargetSet.test(
+ parser::omp::GetOmpDirectiveName(*ompEval).v) &&
"expected TARGET construct evaluation");
(void)ompEval;
@@ -3872,7 +3852,7 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
&*optLoopCons)}) {
llvm::omp::Directive nestedDirective =
- getOpenMPDirectiveEnum(*ompNestedLoopCons);
+ parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
switch (nestedDirective) {
case llvm::omp::Directive::OMPD_tile:
// Emit the omp.loop_nest with annotation for tiling
@@ -3889,7 +3869,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
}
}
- llvm::omp::Directive directive = getOpenMPDirectiveEnum(beginLoopDirective);
+ llvm::omp::Directive directive =
+ parser::omp::GetOmpDirectiveName(beginLoopDirective).v;
const parser::CharBlock &source =
std::get<parser::OmpLoopDirective>(beginLoopDirective.t).source;
ConstructQueue queue{
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index b1716d6afb200..13fda978c5369 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -20,6 +20,7 @@
#include <flang/Lower/PFTBuilder.h>
#include <flang/Optimizer/Builder/FIRBuilder.h>
#include <flang/Optimizer/Builder/Todo.h>
+#include <flang/Parser/openmp-utils.h>
#include <flang/Parser/parse-tree.h>
#include <flang/Parser/tools.h>
#include <flang/Semantics/tools.h>
@@ -663,89 +664,6 @@ bool collectLoopRelatedInfo(
return found;
}
-/// Get the directive enumeration value corresponding to the given OpenMP
-/// construct PFT node.
-llvm::omp::Directive
-extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) {
- return common::visit(
- common::visitors{
- [](const parser::OpenMPAllocatorsConstruct &c) {
- return llvm::omp::OMPD_allocators;
- },
- [](const parser::OpenMPAssumeConstruct &c) {
- return llvm::omp::OMPD_assume;
- },
- [](const parser::OpenMPAtomicConstruct &c) {
- return llvm::omp::OMPD_atomic;
- },
- [](const parser::OpenMPBlockConstruct &c) {
- return std::get<parser::OmpBlockDirective>(
- std::get<parser::OmpBeginBlockDirective>(c.t).t)
- .v;
- },
- [](const parser::OpenMPCriticalConstruct &c) {
- return llvm::omp::OMPD_critical;
- },
- [](const parser::OpenMPDeclarativeAllocate &c) {
- return llvm::omp::OMPD_allocate;
- },
- [](const parser::OpenMPDispatchConstruct &c) {
- return llvm::omp::OMPD_dispatch;
- },
- [](const parser::OpenMPExecutableAllocate &c) {
- return llvm::omp::OMPD_allocate;
- },
- [](const parser::OpenMPLoopConstruct &c) {
- return std::get<parser::OmpLoopDirective>(
- std::get<parser::OmpBeginLoopDirective>(c.t).t)
- .v;
- },
- [](const parser::OpenMPSectionConstruct &c) {
- return llvm::omp::OMPD_section;
- },
- [](const parser::OpenMPSectionsConstruct &c) {
- return std::get<parser::OmpSectionsDirective>(
- std::get<parser::OmpBeginSectionsDirective>(c.t).t)
- .v;
- },
- [](const parser::OpenMPStandaloneConstruct &c) {
- return common::visit(
- common::visitors{
- [](const parser::OpenMPSimpleStandaloneConstruct &c) {
- return c.v.DirId();
- },
- [](const parser::OpenMPFlushConstruct &c) {
- return llvm::omp::OMPD_flush;
- },
- [](const parser::OpenMPCancelConstruct &c) {
- return llvm::omp::OMPD_cancel;
- },
- [](const parser::OpenMPCancellationPointConstruct &c) {
- return llvm::omp::OMPD_cancellation_point;
- },
- [](const parser::OmpMetadirectiveDirective &c) {
- return llvm::omp::OMPD_metadirective;
- },
- [](const parser::OpenMPDepobjConstruct &c) {
- return llvm::omp::OMPD_depobj;
- },
- [](const parser::OpenMPInteropConstruct &c) {
- return llvm::omp::OMPD_interop;
- }},
- c.u);
- },
- [](const parser::OpenMPUtilityConstruct &c) {
- return common::visit(
- common::visitors{[](const parser::OmpErrorDirective &c) {
- return llvm::omp::OMPD_error;
- },
- [](const parser::OmpNothingDirective &c) {
- return llvm::omp::OMPD_nothing;
- }},
- c.u);
- }},
- ompConstruct.u);
-}
} // namespace omp
} // namespace lower
} // namespace Fortran
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index 8e3ad5c3452e2..11641ba5e8606 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -167,8 +167,6 @@ bool collectLoopRelatedInfo(
mlir::omp::LoopRelatedClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
-llvm::omp::Directive
-extractOmpDirective(const parser::OpenMPConstruct &ompConstruct);
} // namespace omp
} // namespace lower
} // namespace Fortran
|
@llvm/pr-subscribers-flang-parser Author: Krzysztof Parzyszek (kparzysz) ChangesFortran::parser::omp::GetOmpDirectiveName(t) will get the OmpDirectiveName object that corresponds to construct t. That object (an AST node) contains the enum id and the source information of the directive. Replace uses of extractOmpDirective and getOpenMPDirectiveEnum with the new function. Full diff: https://github.com/llvm/llvm-project/pull/150121.diff 5 Files Affected:
diff --git a/flang/include/flang/Parser/openmp-utils.h b/flang/include/flang/Parser/openmp-utils.h
new file mode 100644
index 0000000000000..363db852416f6
--- /dev/null
+++ b/flang/include/flang/Parser/openmp-utils.h
@@ -0,0 +1,175 @@
+//===-- flang/Parser/openmp-utils.h ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Common OpenMP utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_PARSER_OPENMP_UTILS_H
+#define FORTRAN_PARSER_OPENMP_UTILS_H
+
+#include "flang/Common/indirection.h"
+#include "flang/Parser/parse-tree.h"
+#include "llvm/Frontend/OpenMP/OMP.h"
+
+#include <cassert>
+#include <tuple>
+#include <type_traits>
+#include <utility>
+#include <variant>
+
+namespace Fortran::parser::omp {
+
+namespace detail {
+using D = llvm::omp::Directive;
+
+template <typename Construct> //
+struct ConstructId {
+ static constexpr llvm::omp::Directive id{D::OMPD_unknown};
+};
+
+#define MAKE_CONSTR_ID(Construct, Id) \
+ template <> struct ConstructId<Construct> { \
+ static constexpr llvm::omp::Directive id{Id}; \
+ }
+
+MAKE_CONSTR_ID(OmpAssumeDirective, D::OMPD_assume);
+MAKE_CONSTR_ID(OmpCriticalDirective, D::OMPD_critical);
+MAKE_CONSTR_ID(OmpDeclareVariantDirective, D::OMPD_declare_variant);
+MAKE_CONSTR_ID(OmpErrorDirective, D::OMPD_error);
+MAKE_CONSTR_ID(OmpMetadirectiveDirective, D::OMPD_metadirective);
+MAKE_CONSTR_ID(OpenMPDeclarativeAllocate, D::OMPD_allocate);
+MAKE_CONSTR_ID(OpenMPDeclarativeAssumes, D::OMPD_assumes);
+MAKE_CONSTR_ID(OpenMPDeclareMapperConstruct, D::OMPD_declare_mapper);
+MAKE_CONSTR_ID(OpenMPDeclareReductionConstruct, D::OMPD_declare_reduction);
+MAKE_CONSTR_ID(OpenMPDeclareSimdConstruct, D::OMPD_declare_simd);
+MAKE_CONSTR_ID(OpenMPDeclareTargetConstruct, D::OMPD_declare_target);
+MAKE_CONSTR_ID(OpenMPExecutableAllocate, D::OMPD_allocate);
+MAKE_CONSTR_ID(OpenMPRequiresConstruct, D::OMPD_requires);
+MAKE_CONSTR_ID(OpenMPThreadprivate, D::OMPD_threadprivate);
+
+#undef MAKE_CONSTR_ID
+
+struct DirectiveNameScope {
+ // Helper types to make overloaded function signatures different.
+ struct TagA {};
+ struct TagB {};
+ struct TagC {};
+ struct TagD {};
+
+ static OmpDirectiveName MakeName(CharBlock source = {},
+ llvm::omp::Directive id = llvm::omp::Directive::OMPD_unknown) {
+ OmpDirectiveName name;
+ name.source = source;
+ name.v = id;
+ return name;
+ }
+
+ static OmpDirectiveName GetOmpDirectiveName(const OmpNothingDirective &x) {
+ return MakeName(x.source, llvm::omp::Directive::OMPD_nothing);
+ }
+
+ static OmpDirectiveName GetOmpDirectiveName(const OmpBeginBlockDirective &x) {
+ auto &dir{std::get<OmpBlockDirective>(x.t)};
+ return MakeName(dir.source, dir.v);
+ }
+
+ static OmpDirectiveName GetOmpDirectiveName(const OmpBeginLoopDirective &x) {
+ auto &dir{std::get<OmpLoopDirective>(x.t)};
+ return MakeName(dir.source, dir.v);
+ }
+
+ static OmpDirectiveName GetOmpDirectiveName(
+ const OmpBeginSectionsDirective &x) {
+ auto &dir{std::get<OmpSectionsDirective>(x.t)};
+ return MakeName(dir.source, dir.v);
+ }
+
+ template <typename T, typename = std::enable_if_t<WrapperTrait<T>>>
+ static OmpDirectiveName GetOmpDirectiveName(const T &x, TagA = {}) {
+ if constexpr (std::is_same_v<T, OpenMPCancelConstruct> ||
+ std::is_same_v<T, OpenMPCancellationPointConstruct> ||
+ std::is_same_v<T, OpenMPDepobjConstruct> ||
+ std::is_same_v<T, OpenMPFlushConstruct> ||
+ std::is_same_v<T, OpenMPInteropConstruct> ||
+ std::is_same_v<T, OpenMPSimpleStandaloneConstruct>) {
+ return x.v.DirName();
+ } else {
+ return GetOmpDirectiveName(x.v);
+ }
+ }
+
+ template <typename T, typename = std::enable_if_t<TupleTrait<T>>>
+ static OmpDirectiveName GetOmpDirectiveName(const T &x, TagB = {}) {
+ if constexpr (std::is_same_v<T, OpenMPAllocatorsConstruct> ||
+ std::is_same_v<T, OpenMPAtomicConstruct> ||
+ std::is_same_v<T, OpenMPDispatchConstruct>) {
+ return std::get<OmpDirectiveSpecification>(x.t).DirName();
+ } else if constexpr (std::is_same_v<T, OmpAssumeDirective> ||
+ std::is_same_v<T, OmpCriticalDirective> ||
+ std::is_same_v<T, OmpDeclareVariantDirective> ||
+ std::is_same_v<T, OmpErrorDirective> ||
+ std::is_same_v<T, OmpMetadirectiveDirective> ||
+ std::is_same_v<T, OpenMPDeclarativeAllocate> ||
+ std::is_same_v<T, OpenMPDeclarativeAssumes> ||
+ std::is_same_v<T, OpenMPDeclareMapperConstruct> ||
+ std::is_same_v<T, OpenMPDeclareReductionConstruct> ||
+ std::is_same_v<T, OpenMPDeclareSimdConstruct> ||
+ std::is_same_v<T, OpenMPDeclareTargetConstruct> ||
+ std::is_same_v<T, OpenMPExecutableAllocate> ||
+ std::is_same_v<T, OpenMPRequiresConstruct> ||
+ std::is_same_v<T, OpenMPThreadprivate>) {
+ return MakeName(std::get<Verbatim>(x.t).source, ConstructId<T>::id);
+ } else {
+ return GetFromTuple(
+ x.t, std::make_index_sequence<std::tuple_size_v<decltype(x.t)>>{});
+ }
+ }
+
+ template <typename T, typename = std::enable_if_t<UnionTrait<T>>>
+ static OmpDirectiveName GetOmpDirectiveName(const T &x, TagC = {}) {
+ return common::visit([](auto &&s) { return GetOmpDirectiveName(s); }, x.u);
+ }
+
+ template <typename... Ts, size_t... Is>
+ static OmpDirectiveName GetFromTuple(
+ const std::tuple<Ts...> &t, std::index_sequence<Is...>) {
+ OmpDirectiveName name = MakeName();
+ auto accumulate = [&](const OmpDirectiveName &n) {
+ if (name.v == llvm::omp::Directive::OMPD_unknown) {
+ name = n;
+ } else {
+ assert(
+ n.v == llvm::omp::Directive::OMPD_unknown && "Conflicting names");
+ }
+ };
+ (accumulate(GetOmpDirectiveName(std::get<Is>(t))), ...);
+ return name;
+ }
+
+ template <typename T>
+ static OmpDirectiveName GetOmpDirectiveName(const common::Indirection<T> &x) {
+ return GetOmpDirectiveName(x.value());
+ }
+
+ template <typename T,
+ typename = std::enable_if_t<!WrapperTrait<T> && !TupleTrait<T> &&
+ !UnionTrait<T>>>
+ static OmpDirectiveName GetOmpDirectiveName(const T &x, TagD = {}) {
+ return MakeName();
+ }
+};
+} // namespace detail
+
+template <typename T> OmpDirectiveName GetOmpDirectiveName(const T &x) {
+ return detail::DirectiveNameScope::GetOmpDirectiveName(x);
+}
+
+} // namespace Fortran::parser::omp
+
+#endif // FORTRAN_PARSER_OPENMP_UTILS_H
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 11e488371b886..2ac4d9548b65b 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -24,6 +24,7 @@
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Parser/openmp-utils.h"
#include "flang/Semantics/attr.h"
#include "flang/Semantics/tools.h"
#include "llvm/ADT/Sequence.h"
@@ -465,7 +466,8 @@ bool DataSharingProcessor::isOpenMPPrivatizingConstruct(
// allow a privatizing clause) are: dispatch, distribute, do, for, loop,
// parallel, scope, sections, simd, single, target, target_data, task,
// taskgroup, taskloop, and teams.
- return llvm::is_contained(privatizing, extractOmpDirective(omp));
+ return llvm::is_contained(privatizing,
+ parser::omp::GetOmpDirectiveName(omp).v);
}
bool DataSharingProcessor::isOpenMPPrivatizingEvaluation(
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index fc5fef9b2c577..4c2d7badef382 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -31,6 +31,7 @@
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Parser/characters.h"
+#include "flang/Parser/openmp-utils.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/openmp-directive-sets.h"
#include "flang/Semantics/tools.h"
@@ -63,28 +64,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
lower::pft::Evaluation &eval,
mlir::Location loc);
-static llvm::omp::Directive
-getOpenMPDirectiveEnum(const parser::OmpLoopDirective &beginStatment) {
- return beginStatment.v;
-}
-
-static llvm::omp::Directive getOpenMPDirectiveEnum(
- const parser::OmpBeginLoopDirective &beginLoopDirective) {
- return getOpenMPDirectiveEnum(
- std::get<parser::OmpLoopDirective>(beginLoopDirective.t));
-}
-
-static llvm::omp::Directive
-getOpenMPDirectiveEnum(const parser::OpenMPLoopConstruct &ompLoopConstruct) {
- return getOpenMPDirectiveEnum(
- std::get<parser::OmpBeginLoopDirective>(ompLoopConstruct.t));
-}
-
-static llvm::omp::Directive getOpenMPDirectiveEnum(
- const common::Indirection<parser::OpenMPLoopConstruct> &ompLoopConstruct) {
- return getOpenMPDirectiveEnum(ompLoopConstruct.value());
-}
-
namespace {
/// Structure holding information that is needed to pass host-evaluated
/// information to later lowering stages.
@@ -468,7 +447,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
llvm::omp::Directive dir;
auto &nested = parent.getFirstNestedEvaluation();
if (const auto *ompEval = nested.getIf<parser::OpenMPConstruct>())
- dir = extractOmpDirective(*ompEval);
+ dir = parser::omp::GetOmpDirectiveName(*ompEval).v;
else
return std::nullopt;
@@ -508,7 +487,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
HostEvalInfo *hostInfo = getHostEvalInfoStackTop(converter);
assert(hostInfo && "expected HOST_EVAL info structure");
- switch (extractOmpDirective(*ompEval)) {
+ switch (parser::omp::GetOmpDirectiveName(*ompEval).v) {
case OMPD_teams_distribute_parallel_do:
case OMPD_teams_distribute_parallel_do_simd:
cp.processThreadLimit(stmtCtx, hostInfo->ops);
@@ -569,7 +548,8 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
assert(ompEval &&
- llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) &&
+ llvm::omp::allTargetSet.test(
+ parser::omp::GetOmpDirectiveName(*ompEval).v) &&
"expected TARGET construct evaluation");
(void)ompEval;
@@ -3872,7 +3852,7 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
&*optLoopCons)}) {
llvm::omp::Directive nestedDirective =
- getOpenMPDirectiveEnum(*ompNestedLoopCons);
+ parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
switch (nestedDirective) {
case llvm::omp::Directive::OMPD_tile:
// Emit the omp.loop_nest with annotation for tiling
@@ -3889,7 +3869,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
}
}
- llvm::omp::Directive directive = getOpenMPDirectiveEnum(beginLoopDirective);
+ llvm::omp::Directive directive =
+ parser::omp::GetOmpDirectiveName(beginLoopDirective).v;
const parser::CharBlock &source =
std::get<parser::OmpLoopDirective>(beginLoopDirective.t).source;
ConstructQueue queue{
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index b1716d6afb200..13fda978c5369 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -20,6 +20,7 @@
#include <flang/Lower/PFTBuilder.h>
#include <flang/Optimizer/Builder/FIRBuilder.h>
#include <flang/Optimizer/Builder/Todo.h>
+#include <flang/Parser/openmp-utils.h>
#include <flang/Parser/parse-tree.h>
#include <flang/Parser/tools.h>
#include <flang/Semantics/tools.h>
@@ -663,89 +664,6 @@ bool collectLoopRelatedInfo(
return found;
}
-/// Get the directive enumeration value corresponding to the given OpenMP
-/// construct PFT node.
-llvm::omp::Directive
-extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) {
- return common::visit(
- common::visitors{
- [](const parser::OpenMPAllocatorsConstruct &c) {
- return llvm::omp::OMPD_allocators;
- },
- [](const parser::OpenMPAssumeConstruct &c) {
- return llvm::omp::OMPD_assume;
- },
- [](const parser::OpenMPAtomicConstruct &c) {
- return llvm::omp::OMPD_atomic;
- },
- [](const parser::OpenMPBlockConstruct &c) {
- return std::get<parser::OmpBlockDirective>(
- std::get<parser::OmpBeginBlockDirective>(c.t).t)
- .v;
- },
- [](const parser::OpenMPCriticalConstruct &c) {
- return llvm::omp::OMPD_critical;
- },
- [](const parser::OpenMPDeclarativeAllocate &c) {
- return llvm::omp::OMPD_allocate;
- },
- [](const parser::OpenMPDispatchConstruct &c) {
- return llvm::omp::OMPD_dispatch;
- },
- [](const parser::OpenMPExecutableAllocate &c) {
- return llvm::omp::OMPD_allocate;
- },
- [](const parser::OpenMPLoopConstruct &c) {
- return std::get<parser::OmpLoopDirective>(
- std::get<parser::OmpBeginLoopDirective>(c.t).t)
- .v;
- },
- [](const parser::OpenMPSectionConstruct &c) {
- return llvm::omp::OMPD_section;
- },
- [](const parser::OpenMPSectionsConstruct &c) {
- return std::get<parser::OmpSectionsDirective>(
- std::get<parser::OmpBeginSectionsDirective>(c.t).t)
- .v;
- },
- [](const parser::OpenMPStandaloneConstruct &c) {
- return common::visit(
- common::visitors{
- [](const parser::OpenMPSimpleStandaloneConstruct &c) {
- return c.v.DirId();
- },
- [](const parser::OpenMPFlushConstruct &c) {
- return llvm::omp::OMPD_flush;
- },
- [](const parser::OpenMPCancelConstruct &c) {
- return llvm::omp::OMPD_cancel;
- },
- [](const parser::OpenMPCancellationPointConstruct &c) {
- return llvm::omp::OMPD_cancellation_point;
- },
- [](const parser::OmpMetadirectiveDirective &c) {
- return llvm::omp::OMPD_metadirective;
- },
- [](const parser::OpenMPDepobjConstruct &c) {
- return llvm::omp::OMPD_depobj;
- },
- [](const parser::OpenMPInteropConstruct &c) {
- return llvm::omp::OMPD_interop;
- }},
- c.u);
- },
- [](const parser::OpenMPUtilityConstruct &c) {
- return common::visit(
- common::visitors{[](const parser::OmpErrorDirective &c) {
- return llvm::omp::OMPD_error;
- },
- [](const parser::OmpNothingDirective &c) {
- return llvm::omp::OMPD_nothing;
- }},
- c.u);
- }},
- ompConstruct.u);
-}
} // namespace omp
} // namespace lower
} // namespace Fortran
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index 8e3ad5c3452e2..11641ba5e8606 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -167,8 +167,6 @@ bool collectLoopRelatedInfo(
mlir::omp::LoopRelatedClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
-llvm::omp::Directive
-extractOmpDirective(const parser::OpenMPConstruct &ompConstruct);
} // namespace omp
} // namespace lower
} // namespace Fortran
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of using tag parameters, is it possible to instead
template <typename T> OmpDirectiveName GetOmpDirectiveName(const T &x) {
if constexpr (WrapperTrait<T>) {
...
} else
return MakeName();
This seems to better represent the logic of the tagged overloads and is less costly in terms of compile-time.
Done |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…vm#150121) Fortran::parser::omp::GetOmpDirectiveName(t) will get the OmpDirectiveName object that corresponds to construct t. That object (an AST node) contains the enum id and the source information of the directive. Replace uses of extractOmpDirective and getOpenMPDirectiveEnum with the new function.
Fortran::parser::omp::GetOmpDirectiveName(t) will get the OmpDirectiveName object that corresponds to construct t. That object (an AST node) contains the enum id and the source information of the directive.
Replace uses of extractOmpDirective and getOpenMPDirectiveEnum with the new function.