Skip to content

Commit cbb4e5d

Browse files
committed
[Clang][OpenMP] Allow num_teams to accept multiple expressions
1 parent ad836c1 commit cbb4e5d

20 files changed

+193
-68
lines changed

clang/docs/OpenMPSupport.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,5 +361,7 @@ considered for standardization. Please post on the
361361
| device extension | `'ompx_bare' clause on 'target teams' construct | :good:`prototyped` | #66844, #70612 |
362362
| | <https://www.osti.gov/servlets/purl/2205717>`_ | | |
363363
+------------------------------+-----------------------------------------------------------------------------------+--------------------------+--------------------------------------------------------+
364+
| device extension | Multi-dim `'num_teams' clause on 'target teams ompx_bare' construct | :good:`partial` | #99732, #101407 |
365+
+------------------------------+-----------------------------------------------------------------------------------+--------------------------+--------------------------------------------------------+
364366

365367
.. _Discourse forums (Runtimes - OpenMP category): https://discourse.llvm.org/c/runtimes/openmp/35

clang/docs/ReleaseNotes.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,9 @@ Python Binding Changes
307307
OpenMP Support
308308
--------------
309309

310+
- `num_teams` now accepts multiple expressions when it is used along in ``target teams ompx_bare`` construct.
311+
This allows the target region to be launched with multi-dim grid on GPUs.
312+
310313
Additional Information
311314
======================
312315

clang/include/clang/AST/OpenMPClause.h

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6131,60 +6131,77 @@ class OMPMapClause final : public OMPMappableExprListClause<OMPMapClause>,
61316131
/// \endcode
61326132
/// In this example directive '#pragma omp teams' has clause 'num_teams'
61336133
/// with single expression 'n'.
6134-
class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
6135-
friend class OMPClauseReader;
6134+
///
6135+
/// When 'ompx_bare' clause exists on a 'target' directive, 'num_teams' clause
6136+
/// can accept up to three expressions.
6137+
///
6138+
/// \code
6139+
/// #pragma omp target teams ompx_bare num_teams(x, y, z)
6140+
/// \endcode
6141+
class OMPNumTeamsClause final
6142+
: public OMPVarListClause<OMPNumTeamsClause>,
6143+
public OMPClauseWithPreInit,
6144+
private llvm::TrailingObjects<OMPNumTeamsClause, Expr *> {
6145+
friend OMPVarListClause;
6146+
friend TrailingObjects;
61366147

61376148
/// Location of '('.
61386149
SourceLocation LParenLoc;
61396150

6140-
/// NumTeams number.
6141-
Stmt *NumTeams = nullptr;
6151+
OMPNumTeamsClause(const ASTContext &C, SourceLocation StartLoc,
6152+
SourceLocation LParenLoc, SourceLocation EndLoc, unsigned N)
6153+
: OMPVarListClause(llvm::omp::OMPC_num_teams, StartLoc, LParenLoc, EndLoc,
6154+
N),
6155+
OMPClauseWithPreInit(this) {}
61426156

6143-
/// Set the NumTeams number.
6144-
///
6145-
/// \param E NumTeams number.
6146-
void setNumTeams(Expr *E) { NumTeams = E; }
6157+
/// Build an empty clause.
6158+
OMPNumTeamsClause(unsigned N)
6159+
: OMPVarListClause(llvm::omp::OMPC_num_teams, SourceLocation(),
6160+
SourceLocation(), SourceLocation(), N),
6161+
OMPClauseWithPreInit(this) {}
61476162

61486163
public:
6149-
/// Build 'num_teams' clause.
6164+
/// Creates clause with a list of variables \a VL.
61506165
///
6151-
/// \param E Expression associated with this clause.
6152-
/// \param HelperE Helper Expression associated with this clause.
6153-
/// \param CaptureRegion Innermost OpenMP region where expressions in this
6154-
/// clause must be captured.
6166+
/// \param C AST context.
61556167
/// \param StartLoc Starting location of the clause.
61566168
/// \param LParenLoc Location of '('.
61576169
/// \param EndLoc Ending location of the clause.
6158-
OMPNumTeamsClause(Expr *E, Stmt *HelperE, OpenMPDirectiveKind CaptureRegion,
6159-
SourceLocation StartLoc, SourceLocation LParenLoc,
6160-
SourceLocation EndLoc)
6161-
: OMPClause(llvm::omp::OMPC_num_teams, StartLoc, EndLoc),
6162-
OMPClauseWithPreInit(this), LParenLoc(LParenLoc), NumTeams(E) {
6163-
setPreInitStmt(HelperE, CaptureRegion);
6164-
}
6170+
/// \param VL List of references to the variables.
6171+
/// \param PreInit
6172+
static OMPNumTeamsClause *
6173+
Create(const ASTContext &C, OpenMPDirectiveKind CaptureRegion,
6174+
SourceLocation StartLoc, SourceLocation LParenLoc,
6175+
SourceLocation EndLoc, ArrayRef<Expr *> VL, Stmt *PreInit);
61656176

6166-
/// Build an empty clause.
6167-
OMPNumTeamsClause()
6168-
: OMPClause(llvm::omp::OMPC_num_teams, SourceLocation(),
6169-
SourceLocation()),
6170-
OMPClauseWithPreInit(this) {}
6177+
/// Creates an empty clause with \a N variables.
6178+
///
6179+
/// \param C AST context.
6180+
/// \param N The number of variables.
6181+
static OMPNumTeamsClause *CreateEmpty(const ASTContext &C, unsigned N);
61716182

61726183
/// Sets the location of '('.
61736184
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
61746185

61756186
/// Returns the location of '('.
61766187
SourceLocation getLParenLoc() const { return LParenLoc; }
61776188

6178-
/// Return NumTeams number.
6179-
Expr *getNumTeams() { return cast<Expr>(NumTeams); }
6189+
/// Return NumTeams expressions.
6190+
ArrayRef<Expr *> getNumTeams() { return getVarRefs(); }
61806191

6181-
/// Return NumTeams number.
6182-
Expr *getNumTeams() const { return cast<Expr>(NumTeams); }
6192+
/// Return NumTeams expressions.
6193+
ArrayRef<Expr *> getNumTeams() const {
6194+
return const_cast<OMPNumTeamsClause *>(this)->getNumTeams();
6195+
}
61836196

6184-
child_range children() { return child_range(&NumTeams, &NumTeams + 1); }
6197+
child_range children() {
6198+
return child_range(reinterpret_cast<Stmt **>(varlist_begin()),
6199+
reinterpret_cast<Stmt **>(varlist_end()));
6200+
}
61856201

61866202
const_child_range children() const {
6187-
return const_child_range(&NumTeams, &NumTeams + 1);
6203+
auto Children = const_cast<OMPNumTeamsClause *>(this)->children();
6204+
return const_child_range(Children.begin(), Children.end());
61886205
}
61896206

61906207
child_range used_children() {

clang/include/clang/AST/RecursiveASTVisitor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3793,8 +3793,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPMapClause(OMPMapClause *C) {
37933793
template <typename Derived>
37943794
bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause(
37953795
OMPNumTeamsClause *C) {
3796+
TRY_TO(VisitOMPClauseList(C));
37963797
TRY_TO(VisitOMPClauseWithPreInit(C));
3797-
TRY_TO(TraverseStmt(C->getNumTeams()));
37983798
return true;
37993799
}
38003800

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11639,6 +11639,7 @@ def warn_omp_unterminated_declare_target : Warning<
1163911639
InGroup<SourceUsesOpenMP>;
1164011640
def err_ompx_bare_no_grid : Error<
1164111641
"'ompx_bare' clauses requires explicit grid size via 'num_teams' and 'thread_limit' clauses">;
11642+
def err_omp_multi_expr_not_allowed: Error<"only one expression allowed to '%0' clause">;
1164211643
} // end of OpenMP category
1164311644

1164411645
let CategoryName = "Related Result Type Issue" in {

clang/include/clang/Sema/SemaOpenMP.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1226,7 +1226,8 @@ class SemaOpenMP : public SemaBase {
12261226
const OMPVarListLocTy &Locs, bool NoDiagnose = false,
12271227
ArrayRef<Expr *> UnresolvedMappers = std::nullopt);
12281228
/// Called on well-formed 'num_teams' clause.
1229-
OMPClause *ActOnOpenMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc,
1229+
OMPClause *ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
1230+
SourceLocation StartLoc,
12301231
SourceLocation LParenLoc,
12311232
SourceLocation EndLoc);
12321233
/// Called on well-formed 'thread_limit' clause.

clang/lib/AST/OpenMPClause.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1720,6 +1720,24 @@ const Expr *OMPDoacrossClause::getLoopData(unsigned NumLoop) const {
17201720
return *It;
17211721
}
17221722

1723+
OMPNumTeamsClause *OMPNumTeamsClause::Create(
1724+
const ASTContext &C, OpenMPDirectiveKind CaptureRegion,
1725+
SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc,
1726+
ArrayRef<Expr *> VL, Stmt *PreInit) {
1727+
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size()));
1728+
OMPNumTeamsClause *Clause =
1729+
new (Mem) OMPNumTeamsClause(C, StartLoc, LParenLoc, EndLoc, VL.size());
1730+
Clause->setVarRefs(VL);
1731+
Clause->setPreInitStmt(PreInit, CaptureRegion);
1732+
return Clause;
1733+
}
1734+
1735+
OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C,
1736+
unsigned N) {
1737+
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N));
1738+
return new (Mem) OMPNumTeamsClause(N);
1739+
}
1740+
17231741
//===----------------------------------------------------------------------===//
17241742
// OpenMP clauses printing methods
17251743
//===----------------------------------------------------------------------===//
@@ -1977,9 +1995,11 @@ void OMPClausePrinter::VisitOMPDeviceClause(OMPDeviceClause *Node) {
19771995
}
19781996

19791997
void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) {
1980-
OS << "num_teams(";
1981-
Node->getNumTeams()->printPretty(OS, nullptr, Policy, 0);
1982-
OS << ")";
1998+
if (!Node->varlist_empty()) {
1999+
OS << "num_teams";
2000+
VisitOMPClauseList(Node, '(');
2001+
OS << ")";
2002+
}
19832003
}
19842004

19852005
void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) {

clang/lib/AST/StmtProfile.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -843,9 +843,8 @@ void OMPClauseProfiler::VisitOMPAllocateClause(const OMPAllocateClause *C) {
843843
VisitOMPClauseList(C);
844844
}
845845
void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
846+
VisitOMPClauseList(C);
846847
VistOMPClauseWithPreInit(C);
847-
if (C->getNumTeams())
848-
Profiler->VisitStmt(C->getNumTeams());
849848
}
850849
void OMPClauseProfiler::VisitOMPThreadLimitClause(
851850
const OMPThreadLimitClause *C) {

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6036,8 +6036,9 @@ const Expr *CGOpenMPRuntime::getNumTeamsExprForTargetDirective(
60366036
dyn_cast_or_null<OMPExecutableDirective>(ChildStmt)) {
60376037
if (isOpenMPTeamsDirective(NestedDir->getDirectiveKind())) {
60386038
if (NestedDir->hasClausesOfKind<OMPNumTeamsClause>()) {
6039-
const Expr *NumTeams =
6040-
NestedDir->getSingleClause<OMPNumTeamsClause>()->getNumTeams();
6039+
const Expr *NumTeams = NestedDir->getSingleClause<OMPNumTeamsClause>()
6040+
->getNumTeams()
6041+
.front();
60416042
if (NumTeams->isIntegerConstantExpr(CGF.getContext()))
60426043
if (auto Constant =
60436044
NumTeams->getIntegerConstantExpr(CGF.getContext()))
@@ -6062,7 +6063,7 @@ const Expr *CGOpenMPRuntime::getNumTeamsExprForTargetDirective(
60626063
case OMPD_target_teams_distribute_parallel_for_simd: {
60636064
if (D.hasClausesOfKind<OMPNumTeamsClause>()) {
60646065
const Expr *NumTeams =
6065-
D.getSingleClause<OMPNumTeamsClause>()->getNumTeams();
6066+
D.getSingleClause<OMPNumTeamsClause>()->getNumTeams().front();
60666067
if (NumTeams->isIntegerConstantExpr(CGF.getContext()))
60676068
if (auto Constant = NumTeams->getIntegerConstantExpr(CGF.getContext()))
60686069
MinTeamsVal = MaxTeamsVal = Constant->getExtValue();

clang/lib/CodeGen/CGStmtOpenMP.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6859,7 +6859,7 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF,
68596859
const auto *NT = S.getSingleClause<OMPNumTeamsClause>();
68606860
const auto *TL = S.getSingleClause<OMPThreadLimitClause>();
68616861
if (NT || TL) {
6862-
const Expr *NumTeams = NT ? NT->getNumTeams() : nullptr;
6862+
const Expr *NumTeams = NT ? NT->getNumTeams().front() : nullptr;
68636863
const Expr *ThreadLimit = TL ? TL->getThreadLimit() : nullptr;
68646864

68656865
CGF.CGM.getOpenMPRuntime().emitNumTeamsClause(CGF, NumTeams, ThreadLimit,

clang/lib/Parse/ParseOpenMP.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3098,7 +3098,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
30983098
case OMPC_simdlen:
30993099
case OMPC_collapse:
31003100
case OMPC_ordered:
3101-
case OMPC_num_teams:
31023101
case OMPC_thread_limit:
31033102
case OMPC_priority:
31043103
case OMPC_grainsize:
@@ -3252,6 +3251,13 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
32523251
? ParseOpenMPSimpleClause(CKind, WrongDirective)
32533252
: ParseOpenMPClause(CKind, WrongDirective);
32543253
break;
3254+
case OMPC_num_teams:
3255+
if (!FirstClause) {
3256+
Diag(Tok, diag::err_omp_more_one_clause)
3257+
<< getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0;
3258+
ErrorFound = true;
3259+
}
3260+
[[clang::fallthrough]];
32553261
case OMPC_private:
32563262
case OMPC_firstprivate:
32573263
case OMPC_lastprivate:

0 commit comments

Comments
 (0)