Skip to content

Commit 0bd0ff8

Browse files
committed
[Clang][OpenMP] Allow num_teams to accept multiple expressions
1 parent d798d3b commit 0bd0ff8

20 files changed

+211
-68
lines changed

clang/docs/OpenMPSupport.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,5 +363,7 @@ considered for standardization. Please post on the
363363
| device extension | `'ompx_bare' clause on 'target teams' construct | :good:`prototyped` | #66844, #70612 |
364364
| | <https://www.osti.gov/servlets/purl/2205717>`_ | | |
365365
+------------------------------+-----------------------------------------------------------------------------------+--------------------------+--------------------------------------------------------+
366+
| device extension | Multi-dim `'num_teams' clause on 'target teams ompx_bare' construct | :good:`partial` | #99732, #101407 |
367+
+------------------------------+-----------------------------------------------------------------------------------+--------------------------+--------------------------------------------------------+
366368

367369
.. _Discourse forums (Runtimes - OpenMP category): https://discourse.llvm.org/c/runtimes/openmp/35

clang/docs/ReleaseNotes.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,9 @@ Improvements
331331
^^^^^^^^^^^^
332332
- Improve the handling of mapping array-section for struct containing nested structs with user defined mappers
333333

334+
- `num_teams` now accepts multiple expressions when it is used along in ``target teams ompx_bare`` construct.
335+
This allows the target region to be launched with multi-dim grid on GPUs.
336+
334337
Additional Information
335338
======================
336339

clang/include/clang/AST/OpenMPClause.h

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6369,60 +6369,77 @@ class OMPMapClause final : public OMPMappableExprListClause<OMPMapClause>,
63696369
/// \endcode
63706370
/// In this example directive '#pragma omp teams' has clause 'num_teams'
63716371
/// with single expression 'n'.
6372-
class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
6373-
friend class OMPClauseReader;
6372+
///
6373+
/// When 'ompx_bare' clause exists on a 'target' directive, 'num_teams' clause
6374+
/// can accept up to three expressions.
6375+
///
6376+
/// \code
6377+
/// #pragma omp target teams ompx_bare num_teams(x, y, z)
6378+
/// \endcode
6379+
class OMPNumTeamsClause final
6380+
: public OMPVarListClause<OMPNumTeamsClause>,
6381+
public OMPClauseWithPreInit,
6382+
private llvm::TrailingObjects<OMPNumTeamsClause, Expr *> {
6383+
friend OMPVarListClause;
6384+
friend TrailingObjects;
63746385

63756386
/// Location of '('.
63766387
SourceLocation LParenLoc;
63776388

6378-
/// NumTeams number.
6379-
Stmt *NumTeams = nullptr;
6389+
OMPNumTeamsClause(const ASTContext &C, SourceLocation StartLoc,
6390+
SourceLocation LParenLoc, SourceLocation EndLoc, unsigned N)
6391+
: OMPVarListClause(llvm::omp::OMPC_num_teams, StartLoc, LParenLoc, EndLoc,
6392+
N),
6393+
OMPClauseWithPreInit(this) {}
63806394

6381-
/// Set the NumTeams number.
6382-
///
6383-
/// \param E NumTeams number.
6384-
void setNumTeams(Expr *E) { NumTeams = E; }
6395+
/// Build an empty clause.
6396+
OMPNumTeamsClause(unsigned N)
6397+
: OMPVarListClause(llvm::omp::OMPC_num_teams, SourceLocation(),
6398+
SourceLocation(), SourceLocation(), N),
6399+
OMPClauseWithPreInit(this) {}
63856400

63866401
public:
6387-
/// Build 'num_teams' clause.
6402+
/// Creates clause with a list of variables \a VL.
63886403
///
6389-
/// \param E Expression associated with this clause.
6390-
/// \param HelperE Helper Expression associated with this clause.
6391-
/// \param CaptureRegion Innermost OpenMP region where expressions in this
6392-
/// clause must be captured.
6404+
/// \param C AST context.
63936405
/// \param StartLoc Starting location of the clause.
63946406
/// \param LParenLoc Location of '('.
63956407
/// \param EndLoc Ending location of the clause.
6396-
OMPNumTeamsClause(Expr *E, Stmt *HelperE, OpenMPDirectiveKind CaptureRegion,
6397-
SourceLocation StartLoc, SourceLocation LParenLoc,
6398-
SourceLocation EndLoc)
6399-
: OMPClause(llvm::omp::OMPC_num_teams, StartLoc, EndLoc),
6400-
OMPClauseWithPreInit(this), LParenLoc(LParenLoc), NumTeams(E) {
6401-
setPreInitStmt(HelperE, CaptureRegion);
6402-
}
6408+
/// \param VL List of references to the variables.
6409+
/// \param PreInit
6410+
static OMPNumTeamsClause *
6411+
Create(const ASTContext &C, OpenMPDirectiveKind CaptureRegion,
6412+
SourceLocation StartLoc, SourceLocation LParenLoc,
6413+
SourceLocation EndLoc, ArrayRef<Expr *> VL, Stmt *PreInit);
64036414

6404-
/// Build an empty clause.
6405-
OMPNumTeamsClause()
6406-
: OMPClause(llvm::omp::OMPC_num_teams, SourceLocation(),
6407-
SourceLocation()),
6408-
OMPClauseWithPreInit(this) {}
6415+
/// Creates an empty clause with \a N variables.
6416+
///
6417+
/// \param C AST context.
6418+
/// \param N The number of variables.
6419+
static OMPNumTeamsClause *CreateEmpty(const ASTContext &C, unsigned N);
64096420

64106421
/// Sets the location of '('.
64116422
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
64126423

64136424
/// Returns the location of '('.
64146425
SourceLocation getLParenLoc() const { return LParenLoc; }
64156426

6416-
/// Return NumTeams number.
6417-
Expr *getNumTeams() { return cast<Expr>(NumTeams); }
6427+
/// Return NumTeams expressions.
6428+
ArrayRef<Expr *> getNumTeams() { return getVarRefs(); }
64186429

6419-
/// Return NumTeams number.
6420-
Expr *getNumTeams() const { return cast<Expr>(NumTeams); }
6430+
/// Return NumTeams expressions.
6431+
ArrayRef<Expr *> getNumTeams() const {
6432+
return const_cast<OMPNumTeamsClause *>(this)->getNumTeams();
6433+
}
64216434

6422-
child_range children() { return child_range(&NumTeams, &NumTeams + 1); }
6435+
child_range children() {
6436+
return child_range(reinterpret_cast<Stmt **>(varlist_begin()),
6437+
reinterpret_cast<Stmt **>(varlist_end()));
6438+
}
64236439

64246440
const_child_range children() const {
6425-
return const_child_range(&NumTeams, &NumTeams + 1);
6441+
auto Children = const_cast<OMPNumTeamsClause *>(this)->children();
6442+
return const_child_range(Children.begin(), Children.end());
64266443
}
64276444

64286445
child_range used_children() {

clang/include/clang/AST/RecursiveASTVisitor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3828,8 +3828,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPMapClause(OMPMapClause *C) {
38283828
template <typename Derived>
38293829
bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause(
38303830
OMPNumTeamsClause *C) {
3831+
TRY_TO(VisitOMPClauseList(C));
38313832
TRY_TO(VisitOMPClauseWithPreInit(C));
3832-
TRY_TO(TraverseStmt(C->getNumTeams()));
38333833
return true;
38343834
}
38353835

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11639,6 +11639,8 @@ def warn_omp_unterminated_declare_target : Warning<
1163911639
InGroup<SourceUsesOpenMP>;
1164011640
def err_ompx_bare_no_grid : Error<
1164111641
"'ompx_bare' clauses requires explicit grid size via 'num_teams' and 'thread_limit' clauses">;
11642+
def err_omp_multi_expr_not_allowed: Error<"only one expression allowed in '%0' clause">;
11643+
def err_ompx_more_than_three_expr_not_allowed: Error<"at most three expressions are allowed in '%0' clause in 'target teams ompx_bare' construct">;
1164211644
} // end of OpenMP category
1164311645

1164411646
let CategoryName = "Related Result Type Issue" in {

clang/include/clang/Sema/SemaOpenMP.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1259,7 +1259,8 @@ class SemaOpenMP : public SemaBase {
12591259
const OMPVarListLocTy &Locs, bool NoDiagnose = false,
12601260
ArrayRef<Expr *> UnresolvedMappers = std::nullopt);
12611261
/// Called on well-formed 'num_teams' clause.
1262-
OMPClause *ActOnOpenMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc,
1262+
OMPClause *ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
1263+
SourceLocation StartLoc,
12631264
SourceLocation LParenLoc,
12641265
SourceLocation EndLoc);
12651266
/// Called on well-formed 'thread_limit' clause.

clang/lib/AST/OpenMPClause.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1755,6 +1755,24 @@ OMPContainsClause *OMPContainsClause::CreateEmpty(const ASTContext &C,
17551755
return new (Mem) OMPContainsClause(K);
17561756
}
17571757

1758+
OMPNumTeamsClause *OMPNumTeamsClause::Create(
1759+
const ASTContext &C, OpenMPDirectiveKind CaptureRegion,
1760+
SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc,
1761+
ArrayRef<Expr *> VL, Stmt *PreInit) {
1762+
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size()));
1763+
OMPNumTeamsClause *Clause =
1764+
new (Mem) OMPNumTeamsClause(C, StartLoc, LParenLoc, EndLoc, VL.size());
1765+
Clause->setVarRefs(VL);
1766+
Clause->setPreInitStmt(PreInit, CaptureRegion);
1767+
return Clause;
1768+
}
1769+
1770+
OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C,
1771+
unsigned N) {
1772+
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N));
1773+
return new (Mem) OMPNumTeamsClause(N);
1774+
}
1775+
17581776
//===----------------------------------------------------------------------===//
17591777
// OpenMP clauses printing methods
17601778
//===----------------------------------------------------------------------===//
@@ -2055,9 +2073,11 @@ void OMPClausePrinter::VisitOMPDeviceClause(OMPDeviceClause *Node) {
20552073
}
20562074

20572075
void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) {
2058-
OS << "num_teams(";
2059-
Node->getNumTeams()->printPretty(OS, nullptr, Policy, 0);
2060-
OS << ")";
2076+
if (!Node->varlist_empty()) {
2077+
OS << "num_teams";
2078+
VisitOMPClauseList(Node, '(');
2079+
OS << ")";
2080+
}
20612081
}
20622082

20632083
void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) {

clang/lib/AST/StmtProfile.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -857,9 +857,8 @@ void OMPClauseProfiler::VisitOMPAllocateClause(const OMPAllocateClause *C) {
857857
VisitOMPClauseList(C);
858858
}
859859
void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
860+
VisitOMPClauseList(C);
860861
VistOMPClauseWithPreInit(C);
861-
if (C->getNumTeams())
862-
Profiler->VisitStmt(C->getNumTeams());
863862
}
864863
void OMPClauseProfiler::VisitOMPThreadLimitClause(
865864
const OMPThreadLimitClause *C) {

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6036,8 +6036,9 @@ const Expr *CGOpenMPRuntime::getNumTeamsExprForTargetDirective(
60366036
dyn_cast_or_null<OMPExecutableDirective>(ChildStmt)) {
60376037
if (isOpenMPTeamsDirective(NestedDir->getDirectiveKind())) {
60386038
if (NestedDir->hasClausesOfKind<OMPNumTeamsClause>()) {
6039-
const Expr *NumTeams =
6040-
NestedDir->getSingleClause<OMPNumTeamsClause>()->getNumTeams();
6039+
const Expr *NumTeams = NestedDir->getSingleClause<OMPNumTeamsClause>()
6040+
->getNumTeams()
6041+
.front();
60416042
if (NumTeams->isIntegerConstantExpr(CGF.getContext()))
60426043
if (auto Constant =
60436044
NumTeams->getIntegerConstantExpr(CGF.getContext()))
@@ -6062,7 +6063,7 @@ const Expr *CGOpenMPRuntime::getNumTeamsExprForTargetDirective(
60626063
case OMPD_target_teams_distribute_parallel_for_simd: {
60636064
if (D.hasClausesOfKind<OMPNumTeamsClause>()) {
60646065
const Expr *NumTeams =
6065-
D.getSingleClause<OMPNumTeamsClause>()->getNumTeams();
6066+
D.getSingleClause<OMPNumTeamsClause>()->getNumTeams().front();
60666067
if (NumTeams->isIntegerConstantExpr(CGF.getContext()))
60676068
if (auto Constant = NumTeams->getIntegerConstantExpr(CGF.getContext()))
60686069
MinTeamsVal = MaxTeamsVal = Constant->getExtValue();

clang/lib/CodeGen/CGStmtOpenMP.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6859,7 +6859,7 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF,
68596859
const auto *NT = S.getSingleClause<OMPNumTeamsClause>();
68606860
const auto *TL = S.getSingleClause<OMPThreadLimitClause>();
68616861
if (NT || TL) {
6862-
const Expr *NumTeams = NT ? NT->getNumTeams() : nullptr;
6862+
const Expr *NumTeams = NT ? NT->getNumTeams().front() : nullptr;
68636863
const Expr *ThreadLimit = TL ? TL->getThreadLimit() : nullptr;
68646864

68656865
CGF.CGM.getOpenMPRuntime().emitNumTeamsClause(CGF, NumTeams, ThreadLimit,

clang/lib/Parse/ParseOpenMP.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3175,7 +3175,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
31753175
case OMPC_simdlen:
31763176
case OMPC_collapse:
31773177
case OMPC_ordered:
3178-
case OMPC_num_teams:
31793178
case OMPC_thread_limit:
31803179
case OMPC_priority:
31813180
case OMPC_grainsize:
@@ -3332,6 +3331,13 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
33323331
? ParseOpenMPSimpleClause(CKind, WrongDirective)
33333332
: ParseOpenMPClause(CKind, WrongDirective);
33343333
break;
3334+
case OMPC_num_teams:
3335+
if (!FirstClause) {
3336+
Diag(Tok, diag::err_omp_more_one_clause)
3337+
<< getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0;
3338+
ErrorFound = true;
3339+
}
3340+
[[clang::fallthrough]];
33353341
case OMPC_private:
33363342
case OMPC_firstprivate:
33373343
case OMPC_lastprivate:

0 commit comments

Comments
 (0)