Skip to content

[OpenACC] implement loop 'worker' clause. #112206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 14, 2024

Conversation

erichkeane
Copy link
Collaborator

The worker clause specifies iterations of the loop/ that are executed in parallel by distributing the iterations among the multiple works within a single gang.

The sema rules for this type are simply that it cannot be combined with a kernel construct with a num_workers clause, child loop clauses cannot contain a gang or worker clause, and that the argument is oly allowed when associated with a kernel.

The worker clause specifies iterations of the loop/ that are executed in
parallel by distributing the iterations among the multiple works within
a single gang.

The sema rules for this type are simply that it cannot be combined with
a `kerne` construct with a `num_workers` clause, child `loop` clauses
cannot contain a `gang` or `worker` clause, and that the argument is oly
allowed when associated with a `kernel`.
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:modules C++20 modules and Clang Header Modules clang:as-a-library libclang and C++ API labels Oct 14, 2024
@llvmbot
Copy link
Member

llvmbot commented Oct 14, 2024

@llvm/pr-subscribers-clang

Author: Erich Keane (erichkeane)

Changes

The worker clause specifies iterations of the loop/ that are executed in parallel by distributing the iterations among the multiple works within a single gang.

The sema rules for this type are simply that it cannot be combined with a kernel construct with a num_workers clause, child loop clauses cannot contain a gang or worker clause, and that the argument is oly allowed when associated with a kernel.


Patch is 49.73 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112206.diff

20 Files Affected:

  • (modified) clang/include/clang/AST/OpenACCClause.h (+16-26)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+11-10)
  • (modified) clang/include/clang/Basic/OpenACCClauses.def (+1)
  • (modified) clang/include/clang/Sema/SemaOpenACC.h (+16)
  • (modified) clang/lib/AST/OpenACCClause.cpp (+27-3)
  • (modified) clang/lib/AST/StmtProfile.cpp (+6)
  • (modified) clang/lib/AST/TextNodeDumper.cpp (+1)
  • (modified) clang/lib/Parse/ParseOpenACC.cpp (+1)
  • (modified) clang/lib/Sema/SemaOpenACC.cpp (+115-10)
  • (modified) clang/lib/Sema/TreeTransform.h (+28)
  • (modified) clang/lib/Serialization/ASTReader.cpp (+6-1)
  • (modified) clang/lib/Serialization/ASTWriter.cpp (+8-1)
  • (modified) clang/test/AST/ast-print-openacc-loop-construct.cpp (+39)
  • (modified) clang/test/ParserOpenACC/parse-clauses.c (+9-10)
  • (modified) clang/test/SemaOpenACC/compute-construct-device_type-clause.c (+1-2)
  • (modified) clang/test/SemaOpenACC/loop-construct-auto_seq_independent-clauses.c (+4-11)
  • (modified) clang/test/SemaOpenACC/loop-construct-device_type-clause.c (-1)
  • (added) clang/test/SemaOpenACC/loop-construct-worker-ast.cpp (+270)
  • (added) clang/test/SemaOpenACC/loop-construct-worker-clause.cpp (+202)
  • (modified) clang/tools/libclang/CIndex.cpp (+5)
diff --git a/clang/include/clang/AST/OpenACCClause.h b/clang/include/clang/AST/OpenACCClause.h
index f3a09eb651458d..e8b8f477f91ae7 100644
--- a/clang/include/clang/AST/OpenACCClause.h
+++ b/clang/include/clang/AST/OpenACCClause.h
@@ -145,32 +145,6 @@ class OpenACCVectorClause : public OpenACCClause {
   }
 };
 
-// Not yet implemented, but the type name is necessary for 'seq' diagnostics, so
-// this provides a basic, do-nothing implementation. We still need to add this
-// type to the visitors/etc, as well as get it to take its proper arguments.
-class OpenACCWorkerClause : public OpenACCClause {
-protected:
-  OpenACCWorkerClause(SourceLocation BeginLoc, SourceLocation EndLoc)
-      : OpenACCClause(OpenACCClauseKind::Worker, BeginLoc, EndLoc) {
-    llvm_unreachable("Not yet implemented");
-  }
-
-public:
-  static bool classof(const OpenACCClause *C) {
-    return C->getClauseKind() == OpenACCClauseKind::Worker;
-  }
-
-  static OpenACCWorkerClause *
-  Create(const ASTContext &Ctx, SourceLocation BeginLoc, SourceLocation EndLoc);
-
-  child_range children() {
-    return child_range(child_iterator(), child_iterator());
-  }
-  const_child_range children() const {
-    return const_child_range(const_child_iterator(), const_child_iterator());
-  }
-};
-
 /// Represents a clause that has a list of parameters.
 class OpenACCClauseWithParams : public OpenACCClause {
   /// Location of the '('.
@@ -541,6 +515,22 @@ class OpenACCGangClause final
          ArrayRef<Expr *> IntExprs, SourceLocation EndLoc);
 };
 
+class OpenACCWorkerClause : public OpenACCClauseWithSingleIntExpr {
+protected:
+  OpenACCWorkerClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
+                      Expr *IntExpr, SourceLocation EndLoc);
+
+public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::Worker;
+  }
+
+  static OpenACCWorkerClause *Create(const ASTContext &Ctx,
+                                     SourceLocation BeginLoc,
+                                     SourceLocation LParenLoc, Expr *IntExpr,
+                                     SourceLocation EndLoc);
+};
+
 class OpenACCNumWorkersClause : public OpenACCClauseWithSingleIntExpr {
   OpenACCNumWorkersClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
                           Expr *IntExpr, SourceLocation EndLoc);
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 3c62a017005e59..00201a7f192f17 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12686,21 +12686,22 @@ def err_acc_intervening_code
 def err_acc_gang_multiple_elt
     : Error<"OpenACC 'gang' clause may have at most one %select{unnamed or "
             "'num'|'dim'|'static'}0 argument">;
-def err_acc_gang_arg_invalid
-    : Error<"'%0' argument on 'gang' clause is not permitted on a%select{n "
-            "orphaned|||}1 'loop' construct %select{|associated with a "
+def err_acc_int_arg_invalid
+    : Error<"'%1' argument on '%0' clause is not permitted on a%select{n "
+            "orphaned|||}2 'loop' construct %select{|associated with a "
             "'parallel' compute construct|associated with a 'kernels' compute "
-            "construct|associated with a 'serial' compute construct}1">;
+            "construct|associated with a 'serial' compute construct}2">;
 def err_acc_gang_dim_value
     : Error<"argument to 'gang' clause dimension must be %select{a constant "
             "expression|1, 2, or 3: evaluated to %1}0">;
-def err_acc_gang_num_gangs_conflict
-    : Error<"'num' argument to 'gang' clause not allowed on a 'loop' construct "
-            "associated with a 'kernels' construct that has a 'num_gangs' "
+def err_acc_num_arg_conflict
+    : Error<"'num' argument to '%0' clause not allowed on a 'loop' construct "
+            "associated with a 'kernels' construct that has a "
+            "'%select{num_gangs|num_workers}1' "
             "clause">;
-def err_acc_gang_inside_gang
-    : Error<"loop with a 'gang' clause may not exist in the region of a 'gang' "
-            "clause on a 'kernels' compute construct">;
+def err_acc_clause_in_clause_region
+    : Error<"loop with a '%0' clause may not exist in the region of a '%1' "
+            "clause%select{| on a 'kernels' compute construct}2">;
 
 // AMDGCN builtins diagnostics
 def err_amdgcn_global_load_lds_size_invalid_value : Error<"invalid size value">;
diff --git a/clang/include/clang/Basic/OpenACCClauses.def b/clang/include/clang/Basic/OpenACCClauses.def
index 2a098de31eb618..4c0b56dc13e625 100644
--- a/clang/include/clang/Basic/OpenACCClauses.def
+++ b/clang/include/clang/Basic/OpenACCClauses.def
@@ -56,6 +56,7 @@ VISIT_CLAUSE(Seq)
 VISIT_CLAUSE(Tile)
 VISIT_CLAUSE(VectorLength)
 VISIT_CLAUSE(Wait)
+VISIT_CLAUSE(Worker)
 
 #undef VISIT_CLAUSE
 #undef CLAUSE_ALIAS
diff --git a/clang/include/clang/Sema/SemaOpenACC.h b/clang/include/clang/Sema/SemaOpenACC.h
index 59a9648d5f9380..e253610a84b0bf 100644
--- a/clang/include/clang/Sema/SemaOpenACC.h
+++ b/clang/include/clang/Sema/SemaOpenACC.h
@@ -118,6 +118,11 @@ class SemaOpenACC : public SemaBase {
   /// 'kernel' construct, this will have the source location for it. This
   /// permits us to implement the restriction of no further 'gang' clauses.
   SourceLocation LoopGangClauseOnKernelLoc;
+  /// If there is a current 'active' loop construct with a 'worker' clause on it
+  /// (on any sort of construct), this has the source location for it.  This
+  /// permits us to implement the restriction of no further 'gang' or 'worker'
+  /// clauses.
+  SourceLocation LoopWorkerClauseLoc;
 
   // Redeclaration of the version in OpenACCClause.h.
   using DeviceTypeArgument = std::pair<IdentifierInfo *, SourceLocation>;
@@ -224,11 +229,15 @@ class SemaOpenACC : public SemaBase {
               ClauseKind == OpenACCClauseKind::NumWorkers ||
               ClauseKind == OpenACCClauseKind::Async ||
               ClauseKind == OpenACCClauseKind::Tile ||
+              ClauseKind == OpenACCClauseKind::Worker ||
+              ClauseKind == OpenACCClauseKind::Vector ||
               ClauseKind == OpenACCClauseKind::VectorLength) &&
              "Parsed clause kind does not have a int exprs");
 
       // 'async' and 'wait' have an optional IntExpr, so be tolerant of that.
       if ((ClauseKind == OpenACCClauseKind::Async ||
+           ClauseKind == OpenACCClauseKind::Worker ||
+           ClauseKind == OpenACCClauseKind::Vector ||
            ClauseKind == OpenACCClauseKind::Wait) &&
           std::holds_alternative<std::monostate>(Details))
         return 0;
@@ -271,6 +280,8 @@ class SemaOpenACC : public SemaBase {
               ClauseKind == OpenACCClauseKind::Async ||
               ClauseKind == OpenACCClauseKind::Tile ||
               ClauseKind == OpenACCClauseKind::Gang ||
+              ClauseKind == OpenACCClauseKind::Worker ||
+              ClauseKind == OpenACCClauseKind::Vector ||
               ClauseKind == OpenACCClauseKind::VectorLength) &&
              "Parsed clause kind does not have a int exprs");
 
@@ -401,6 +412,8 @@ class SemaOpenACC : public SemaBase {
               ClauseKind == OpenACCClauseKind::NumWorkers ||
               ClauseKind == OpenACCClauseKind::Async ||
               ClauseKind == OpenACCClauseKind::Tile ||
+              ClauseKind == OpenACCClauseKind::Worker ||
+              ClauseKind == OpenACCClauseKind::Vector ||
               ClauseKind == OpenACCClauseKind::VectorLength) &&
              "Parsed clause kind does not have a int exprs");
       Details = IntExprDetails{{IntExprs.begin(), IntExprs.end()}};
@@ -410,6 +423,8 @@ class SemaOpenACC : public SemaBase {
               ClauseKind == OpenACCClauseKind::NumWorkers ||
               ClauseKind == OpenACCClauseKind::Async ||
               ClauseKind == OpenACCClauseKind::Tile ||
+              ClauseKind == OpenACCClauseKind::Worker ||
+              ClauseKind == OpenACCClauseKind::Vector ||
               ClauseKind == OpenACCClauseKind::VectorLength) &&
              "Parsed clause kind does not have a int exprs");
       Details = IntExprDetails{std::move(IntExprs)};
@@ -663,6 +678,7 @@ class SemaOpenACC : public SemaBase {
     ComputeConstructInfo OldActiveComputeConstructInfo;
     OpenACCDirectiveKind DirKind;
     SourceLocation OldLoopGangClauseOnKernelLoc;
+    SourceLocation OldLoopWorkerClauseLoc;
     llvm::SmallVector<OpenACCLoopConstruct *> ParentlessLoopConstructs;
     LoopInConstructRAII LoopRAII;
 
diff --git a/clang/lib/AST/OpenACCClause.cpp b/clang/lib/AST/OpenACCClause.cpp
index 6fb8fe0b8cfeef..638252fd811f1d 100644
--- a/clang/lib/AST/OpenACCClause.cpp
+++ b/clang/lib/AST/OpenACCClause.cpp
@@ -44,7 +44,8 @@ bool OpenACCClauseWithCondition::classof(const OpenACCClause *C) {
 bool OpenACCClauseWithSingleIntExpr::classof(const OpenACCClause *C) {
   return OpenACCNumWorkersClause::classof(C) ||
          OpenACCVectorLengthClause::classof(C) ||
-         OpenACCCollapseClause::classof(C) || OpenACCAsyncClause::classof(C);
+         OpenACCWorkerClause::classof(C) || OpenACCCollapseClause::classof(C) ||
+         OpenACCAsyncClause::classof(C);
 }
 OpenACCDefaultClause *OpenACCDefaultClause::Create(const ASTContext &C,
                                                    OpenACCDefaultClauseKind K,
@@ -403,11 +404,24 @@ OpenACCGangClause::Create(const ASTContext &C, SourceLocation BeginLoc,
       OpenACCGangClause(BeginLoc, LParenLoc, GangKinds, IntExprs, EndLoc);
 }
 
+OpenACCWorkerClause::OpenACCWorkerClause(SourceLocation BeginLoc,
+                                         SourceLocation LParenLoc,
+                                         Expr *IntExpr, SourceLocation EndLoc)
+    : OpenACCClauseWithSingleIntExpr(OpenACCClauseKind::Worker, BeginLoc,
+                                     LParenLoc, IntExpr, EndLoc) {
+  assert((!IntExpr || IntExpr->isInstantiationDependent() ||
+          IntExpr->getType()->isIntegerType()) &&
+         "Int expression type not scalar/dependent");
+}
+
 OpenACCWorkerClause *OpenACCWorkerClause::Create(const ASTContext &C,
                                                  SourceLocation BeginLoc,
+                                                 SourceLocation LParenLoc,
+                                                 Expr *IntExpr,
                                                  SourceLocation EndLoc) {
-  void *Mem = C.Allocate(sizeof(OpenACCWorkerClause));
-  return new (Mem) OpenACCWorkerClause(BeginLoc, EndLoc);
+  void *Mem =
+      C.Allocate(sizeof(OpenACCWorkerClause), alignof(OpenACCWorkerClause));
+  return new (Mem) OpenACCWorkerClause(BeginLoc, LParenLoc, IntExpr, EndLoc);
 }
 
 OpenACCVectorClause *OpenACCVectorClause::Create(const ASTContext &C,
@@ -638,3 +652,13 @@ void OpenACCClausePrinter::VisitGangClause(const OpenACCGangClause &C) {
     OS << ")";
   }
 }
+
+void OpenACCClausePrinter::VisitWorkerClause(const OpenACCWorkerClause &C) {
+  OS << "worker";
+
+  if (C.hasIntExpr()) {
+    OS << "(num: ";
+    printExpr(C.getIntExpr());
+    OS << ")";
+  }
+}
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 6161b1403ed35d..25b1cbb8590869 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -2629,6 +2629,12 @@ void OpenACCClauseProfiler::VisitAsyncClause(const OpenACCAsyncClause &Clause) {
     Profiler.VisitStmt(Clause.getIntExpr());
 }
 
+void OpenACCClauseProfiler::VisitWorkerClause(
+    const OpenACCWorkerClause &Clause) {
+  if (Clause.hasIntExpr())
+    Profiler.VisitStmt(Clause.getIntExpr());
+}
+
 void OpenACCClauseProfiler::VisitWaitClause(const OpenACCWaitClause &Clause) {
   if (Clause.hasDevNumExpr())
     Profiler.VisitStmt(Clause.getDevNumExpr());
diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp
index ac8c196777f9b8..beccb0615f0e9c 100644
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -420,6 +420,7 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
     case OpenACCClauseKind::Self:
     case OpenACCClauseKind::Seq:
     case OpenACCClauseKind::Tile:
+    case OpenACCClauseKind::Worker:
     case OpenACCClauseKind::VectorLength:
       // The condition expression will be printed as a part of the 'children',
       // but print 'clause' here so it is clear what is happening from the dump.
diff --git a/clang/lib/Parse/ParseOpenACC.cpp b/clang/lib/Parse/ParseOpenACC.cpp
index 635039b724e6a0..51d4dc38c17f67 100644
--- a/clang/lib/Parse/ParseOpenACC.cpp
+++ b/clang/lib/Parse/ParseOpenACC.cpp
@@ -1130,6 +1130,7 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
           Parens.skipToEnd();
           return OpenACCCanContinue();
         }
+        ParsedClause.setIntExprDetails(IntExpr.get());
         break;
       }
       case OpenACCClauseKind::Async: {
diff --git a/clang/lib/Sema/SemaOpenACC.cpp b/clang/lib/Sema/SemaOpenACC.cpp
index 30d73d621db69b..1b24331cbd87ca 100644
--- a/clang/lib/Sema/SemaOpenACC.cpp
+++ b/clang/lib/Sema/SemaOpenACC.cpp
@@ -377,6 +377,18 @@ bool doesClauseApplyToDirective(OpenACCDirectiveKind DirectiveKind,
     default:
       return false;
     }
+  case OpenACCClauseKind::Worker: {
+    switch (DirectiveKind) {
+    case OpenACCDirectiveKind::Loop:
+    case OpenACCDirectiveKind::ParallelLoop:
+    case OpenACCDirectiveKind::SerialLoop:
+    case OpenACCDirectiveKind::KernelsLoop:
+    case OpenACCDirectiveKind::Routine:
+      return true;
+    default:
+      return false;
+    }
+  }
   }
 
   default:
@@ -500,7 +512,6 @@ class SemaOpenACCClauseVisitor {
 
   OpenACCClause *Visit(SemaOpenACC::OpenACCParsedClause &Clause) {
     switch (Clause.getClauseKind()) {
-    case OpenACCClauseKind::Worker:
     case OpenACCClauseKind::Vector: {
       // TODO OpenACC: These are only implemented enough for the 'seq'
       // diagnostic, otherwise treats itself as unimplemented.  When we
@@ -1024,6 +1035,75 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitIndependentClause(
                                           Clause.getEndLoc());
 }
 
+OpenACCClause *SemaOpenACCClauseVisitor::VisitWorkerClause(
+    SemaOpenACC::OpenACCParsedClause &Clause) {
+  if (DiagIfSeqClause(Clause))
+    return nullptr;
+
+  // Restrictions only properly implemented on 'loop' constructs, and it is
+  // the only construct that can do anything with this, so skip/treat as
+  // unimplemented for the combined constructs.
+  if (Clause.getDirectiveKind() != OpenACCDirectiveKind::Loop)
+    return isNotImplemented();
+
+  Expr *IntExpr =
+      Clause.getNumIntExprs() != 0 ? Clause.getIntExprs()[0] : nullptr;
+
+  if (IntExpr) {
+    switch (SemaRef.getActiveComputeConstructInfo().Kind) {
+    case OpenACCDirectiveKind::Invalid:
+      SemaRef.Diag(IntExpr->getBeginLoc(), diag::err_acc_int_arg_invalid)
+          << OpenACCClauseKind::Worker << "num" << /*orphan=*/0;
+      IntExpr = nullptr;
+      break;
+    case OpenACCDirectiveKind::Parallel:
+      SemaRef.Diag(IntExpr->getBeginLoc(), diag::err_acc_int_arg_invalid)
+          << OpenACCClauseKind::Worker << "num" << /*parallel=*/1;
+      IntExpr = nullptr;
+      break;
+    case OpenACCDirectiveKind::Serial:
+      SemaRef.Diag(IntExpr->getBeginLoc(), diag::err_acc_int_arg_invalid)
+          << OpenACCClauseKind::Worker << "num" << /*serial=*/3;
+      IntExpr = nullptr;
+      break;
+    case OpenACCDirectiveKind::Kernels: {
+      const auto *Itr =
+          llvm::find_if(SemaRef.getActiveComputeConstructInfo().Clauses,
+                        llvm::IsaPred<OpenACCNumWorkersClause>);
+      if (Itr != SemaRef.getActiveComputeConstructInfo().Clauses.end()) {
+        SemaRef.Diag(IntExpr->getBeginLoc(), diag::err_acc_num_arg_conflict)
+            << OpenACCClauseKind::Worker << /*num_workers=*/1;
+        SemaRef.Diag((*Itr)->getBeginLoc(),
+                     diag::note_acc_previous_clause_here);
+
+        IntExpr = nullptr;
+      }
+      break;
+    }
+    default:
+      llvm_unreachable("Non compute construct in active compute construct");
+    }
+  }
+
+  // OpenACC 3.3 2.9.3: The region of a loop with a 'worker' clause may not
+  // contain a loop with a gang or worker clause unless within a nested compute
+  // region.
+  if (SemaRef.LoopWorkerClauseLoc.isValid()) {
+    // This handles the 'inner loop' diagnostic, but we cannot set that we're on
+    // one of these until we get to the end of the construct.
+    SemaRef.Diag(Clause.getBeginLoc(), diag::err_acc_clause_in_clause_region)
+        << OpenACCClauseKind::Worker << OpenACCClauseKind::Worker
+        << /*skip kernels construct info*/ 0;
+    SemaRef.Diag(SemaRef.LoopWorkerClauseLoc,
+                 diag::note_acc_previous_clause_here);
+    return nullptr;
+  }
+
+  return OpenACCWorkerClause::Create(Ctx, Clause.getBeginLoc(),
+                                     Clause.getLParenLoc(), IntExpr,
+                                     Clause.getEndLoc());
+}
+
 OpenACCClause *SemaOpenACCClauseVisitor::VisitGangClause(
     SemaOpenACC::OpenACCParsedClause &Clause) {
   if (DiagIfSeqClause(Clause))
@@ -1061,8 +1141,8 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitGangClause(
                         llvm::IsaPred<OpenACCNumGangsClause>);
 
       if (Itr != SemaRef.getActiveComputeConstructInfo().Clauses.end()) {
-        SemaRef.Diag(ER.get()->getBeginLoc(),
-                     diag::err_acc_gang_num_gangs_conflict);
+        SemaRef.Diag(ER.get()->getBeginLoc(), diag::err_acc_num_arg_conflict)
+            << OpenACCClauseKind::Gang << /*num_gangs=*/0;
         SemaRef.Diag((*Itr)->getBeginLoc(),
                      diag::note_acc_previous_clause_here);
         continue;
@@ -1091,12 +1171,28 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitGangClause(
   if (SemaRef.LoopGangClauseOnKernelLoc.isValid()) {
     // This handles the 'inner loop' diagnostic, but we cannot set that we're on
     // one of these until we get to the end of the construct.
-    SemaRef.Diag(Clause.getBeginLoc(), diag::err_acc_gang_inside_gang);
+    SemaRef.Diag(Clause.getBeginLoc(), diag::err_acc_clause_in_clause_region)
+        << OpenACCClauseKind::Gang << OpenACCClauseKind::Gang
+        << /*kernels construct info*/ 1;
     SemaRef.Diag(SemaRef.LoopGangClauseOnKernelLoc,
                  diag::note_acc_previous_clause_here);
     return nullptr;
   }
 
+  // OpenACC 3.3 2.9.3: The region of a loop with a 'worker' clause may not
+  // contain a loop with a gang or worker clause unless within a nested compute
+  // region.
+  if (SemaRef.LoopWorkerClauseLoc.isValid()) {
+    // This handles the 'inner loop' diagnostic, but we cannot set that we're on
+    // one of these until we get to the end of the construct.
+    SemaRef.Diag(Clause.getBeginLoc(), diag::err_acc_clause_in_clause_region)
+        << OpenACCClauseKind::Gang << OpenACCClauseKind::Worker
+        << /*kernels construct info*/ 1;
+    SemaRef.Diag(SemaRef.LoopWorkerClauseLoc,
+                 diag::note_acc_previous_clause_here);
+    return nullptr;
+  }
+
   return OpenACCGangClause::Create(Ctx, Clause.getBeginLoc(),
                                    Clause.getLParenLoc(), GangKinds, IntExprs,
                                    Clause.getEndLoc());
@@ -1216,6 +1312,7 @@ SemaOpenACC::AssociatedStmtRAII::AssociatedStmtRAII(
     ArrayRef<OpenACCClause *> Clauses)
     : SemaRef(S), OldActiveComputeConstructInfo(S.ActiveComputeConstructInfo),
       DirKind(DK), OldLoopGangClauseOnKernelLoc(S.LoopGangClauseOnKernelLoc),
+      OldLoopWorkerClauseLoc(S.LoopWorkerClauseLoc),
       LoopRAII(SemaRef, /*PreserveDepth=*/false) {
   // Compute constructs end up taking their 'loop'.
   if (DirKind == OpenACCDirectiveKind::Parallel ||
@@ -1232,6 +1329,7 @@ SemaOpenACC::AssociatedStmtRAII::AssociatedStmtRAII(
     //
     // Implement the 'unless within a nested compute region' part.
     SemaRef.LoopGangClauseOnKernelLoc = {};
+    SemaRef.LoopWorkerClauseLoc = {};
   } else if (DirKind == OpenACCDirectiveKind::Loop) {
     SetCollapseInfoBeforeAssociatedStmt(UnInstClauses, Clauses);
     SetTileInfoBeforeAssociatedStmt(UnInstClauses, Clauses);
@@ -1252,6 +1350,12 @@ SemaOpenACC::AssociatedStmtRAII::AssociatedStmtRAII(
       if (Itr != Clauses.end())
         SemaRef.LoopGangClauseOnKernelLoc = (*Itr)->getBeginLoc();
     }
+
+    if (UnI...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Oct 14, 2024

@llvm/pr-subscribers-clang-modules

Author: Erich Keane (erichkeane)

Changes

The worker clause specifies iterations of the loop/ that are executed in parallel by distributing the iterations among the multiple works within a single gang.

The sema rules for this type are simply that it cannot be combined with a kernel construct with a num_workers clause, child loop clauses cannot contain a gang or worker clause, and that the argument is oly allowed when associated with a kernel.


Patch is 49.73 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112206.diff

20 Files Affected:

  • (modified) clang/include/clang/AST/OpenACCClause.h (+16-26)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+11-10)
  • (modified) clang/include/clang/Basic/OpenACCClauses.def (+1)
  • (modified) clang/include/clang/Sema/SemaOpenACC.h (+16)
  • (modified) clang/lib/AST/OpenACCClause.cpp (+27-3)
  • (modified) clang/lib/AST/StmtProfile.cpp (+6)
  • (modified) clang/lib/AST/TextNodeDumper.cpp (+1)
  • (modified) clang/lib/Parse/ParseOpenACC.cpp (+1)
  • (modified) clang/lib/Sema/SemaOpenACC.cpp (+115-10)
  • (modified) clang/lib/Sema/TreeTransform.h (+28)
  • (modified) clang/lib/Serialization/ASTReader.cpp (+6-1)
  • (modified) clang/lib/Serialization/ASTWriter.cpp (+8-1)
  • (modified) clang/test/AST/ast-print-openacc-loop-construct.cpp (+39)
  • (modified) clang/test/ParserOpenACC/parse-clauses.c (+9-10)
  • (modified) clang/test/SemaOpenACC/compute-construct-device_type-clause.c (+1-2)
  • (modified) clang/test/SemaOpenACC/loop-construct-auto_seq_independent-clauses.c (+4-11)
  • (modified) clang/test/SemaOpenACC/loop-construct-device_type-clause.c (-1)
  • (added) clang/test/SemaOpenACC/loop-construct-worker-ast.cpp (+270)
  • (added) clang/test/SemaOpenACC/loop-construct-worker-clause.cpp (+202)
  • (modified) clang/tools/libclang/CIndex.cpp (+5)
diff --git a/clang/include/clang/AST/OpenACCClause.h b/clang/include/clang/AST/OpenACCClause.h
index f3a09eb651458d..e8b8f477f91ae7 100644
--- a/clang/include/clang/AST/OpenACCClause.h
+++ b/clang/include/clang/AST/OpenACCClause.h
@@ -145,32 +145,6 @@ class OpenACCVectorClause : public OpenACCClause {
   }
 };
 
-// Not yet implemented, but the type name is necessary for 'seq' diagnostics, so
-// this provides a basic, do-nothing implementation. We still need to add this
-// type to the visitors/etc, as well as get it to take its proper arguments.
-class OpenACCWorkerClause : public OpenACCClause {
-protected:
-  OpenACCWorkerClause(SourceLocation BeginLoc, SourceLocation EndLoc)
-      : OpenACCClause(OpenACCClauseKind::Worker, BeginLoc, EndLoc) {
-    llvm_unreachable("Not yet implemented");
-  }
-
-public:
-  static bool classof(const OpenACCClause *C) {
-    return C->getClauseKind() == OpenACCClauseKind::Worker;
-  }
-
-  static OpenACCWorkerClause *
-  Create(const ASTContext &Ctx, SourceLocation BeginLoc, SourceLocation EndLoc);
-
-  child_range children() {
-    return child_range(child_iterator(), child_iterator());
-  }
-  const_child_range children() const {
-    return const_child_range(const_child_iterator(), const_child_iterator());
-  }
-};
-
 /// Represents a clause that has a list of parameters.
 class OpenACCClauseWithParams : public OpenACCClause {
   /// Location of the '('.
@@ -541,6 +515,22 @@ class OpenACCGangClause final
          ArrayRef<Expr *> IntExprs, SourceLocation EndLoc);
 };
 
+class OpenACCWorkerClause : public OpenACCClauseWithSingleIntExpr {
+protected:
+  OpenACCWorkerClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
+                      Expr *IntExpr, SourceLocation EndLoc);
+
+public:
+  static bool classof(const OpenACCClause *C) {
+    return C->getClauseKind() == OpenACCClauseKind::Worker;
+  }
+
+  static OpenACCWorkerClause *Create(const ASTContext &Ctx,
+                                     SourceLocation BeginLoc,
+                                     SourceLocation LParenLoc, Expr *IntExpr,
+                                     SourceLocation EndLoc);
+};
+
 class OpenACCNumWorkersClause : public OpenACCClauseWithSingleIntExpr {
   OpenACCNumWorkersClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
                           Expr *IntExpr, SourceLocation EndLoc);
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 3c62a017005e59..00201a7f192f17 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12686,21 +12686,22 @@ def err_acc_intervening_code
 def err_acc_gang_multiple_elt
     : Error<"OpenACC 'gang' clause may have at most one %select{unnamed or "
             "'num'|'dim'|'static'}0 argument">;
-def err_acc_gang_arg_invalid
-    : Error<"'%0' argument on 'gang' clause is not permitted on a%select{n "
-            "orphaned|||}1 'loop' construct %select{|associated with a "
+def err_acc_int_arg_invalid
+    : Error<"'%1' argument on '%0' clause is not permitted on a%select{n "
+            "orphaned|||}2 'loop' construct %select{|associated with a "
             "'parallel' compute construct|associated with a 'kernels' compute "
-            "construct|associated with a 'serial' compute construct}1">;
+            "construct|associated with a 'serial' compute construct}2">;
 def err_acc_gang_dim_value
     : Error<"argument to 'gang' clause dimension must be %select{a constant "
             "expression|1, 2, or 3: evaluated to %1}0">;
-def err_acc_gang_num_gangs_conflict
-    : Error<"'num' argument to 'gang' clause not allowed on a 'loop' construct "
-            "associated with a 'kernels' construct that has a 'num_gangs' "
+def err_acc_num_arg_conflict
+    : Error<"'num' argument to '%0' clause not allowed on a 'loop' construct "
+            "associated with a 'kernels' construct that has a "
+            "'%select{num_gangs|num_workers}1' "
             "clause">;
-def err_acc_gang_inside_gang
-    : Error<"loop with a 'gang' clause may not exist in the region of a 'gang' "
-            "clause on a 'kernels' compute construct">;
+def err_acc_clause_in_clause_region
+    : Error<"loop with a '%0' clause may not exist in the region of a '%1' "
+            "clause%select{| on a 'kernels' compute construct}2">;
 
 // AMDGCN builtins diagnostics
 def err_amdgcn_global_load_lds_size_invalid_value : Error<"invalid size value">;
diff --git a/clang/include/clang/Basic/OpenACCClauses.def b/clang/include/clang/Basic/OpenACCClauses.def
index 2a098de31eb618..4c0b56dc13e625 100644
--- a/clang/include/clang/Basic/OpenACCClauses.def
+++ b/clang/include/clang/Basic/OpenACCClauses.def
@@ -56,6 +56,7 @@ VISIT_CLAUSE(Seq)
 VISIT_CLAUSE(Tile)
 VISIT_CLAUSE(VectorLength)
 VISIT_CLAUSE(Wait)
+VISIT_CLAUSE(Worker)
 
 #undef VISIT_CLAUSE
 #undef CLAUSE_ALIAS
diff --git a/clang/include/clang/Sema/SemaOpenACC.h b/clang/include/clang/Sema/SemaOpenACC.h
index 59a9648d5f9380..e253610a84b0bf 100644
--- a/clang/include/clang/Sema/SemaOpenACC.h
+++ b/clang/include/clang/Sema/SemaOpenACC.h
@@ -118,6 +118,11 @@ class SemaOpenACC : public SemaBase {
   /// 'kernel' construct, this will have the source location for it. This
   /// permits us to implement the restriction of no further 'gang' clauses.
   SourceLocation LoopGangClauseOnKernelLoc;
+  /// If there is a current 'active' loop construct with a 'worker' clause on it
+  /// (on any sort of construct), this has the source location for it.  This
+  /// permits us to implement the restriction of no further 'gang' or 'worker'
+  /// clauses.
+  SourceLocation LoopWorkerClauseLoc;
 
   // Redeclaration of the version in OpenACCClause.h.
   using DeviceTypeArgument = std::pair<IdentifierInfo *, SourceLocation>;
@@ -224,11 +229,15 @@ class SemaOpenACC : public SemaBase {
               ClauseKind == OpenACCClauseKind::NumWorkers ||
               ClauseKind == OpenACCClauseKind::Async ||
               ClauseKind == OpenACCClauseKind::Tile ||
+              ClauseKind == OpenACCClauseKind::Worker ||
+              ClauseKind == OpenACCClauseKind::Vector ||
               ClauseKind == OpenACCClauseKind::VectorLength) &&
              "Parsed clause kind does not have a int exprs");
 
       // 'async' and 'wait' have an optional IntExpr, so be tolerant of that.
       if ((ClauseKind == OpenACCClauseKind::Async ||
+           ClauseKind == OpenACCClauseKind::Worker ||
+           ClauseKind == OpenACCClauseKind::Vector ||
            ClauseKind == OpenACCClauseKind::Wait) &&
           std::holds_alternative<std::monostate>(Details))
         return 0;
@@ -271,6 +280,8 @@ class SemaOpenACC : public SemaBase {
               ClauseKind == OpenACCClauseKind::Async ||
               ClauseKind == OpenACCClauseKind::Tile ||
               ClauseKind == OpenACCClauseKind::Gang ||
+              ClauseKind == OpenACCClauseKind::Worker ||
+              ClauseKind == OpenACCClauseKind::Vector ||
               ClauseKind == OpenACCClauseKind::VectorLength) &&
              "Parsed clause kind does not have a int exprs");
 
@@ -401,6 +412,8 @@ class SemaOpenACC : public SemaBase {
               ClauseKind == OpenACCClauseKind::NumWorkers ||
               ClauseKind == OpenACCClauseKind::Async ||
               ClauseKind == OpenACCClauseKind::Tile ||
+              ClauseKind == OpenACCClauseKind::Worker ||
+              ClauseKind == OpenACCClauseKind::Vector ||
               ClauseKind == OpenACCClauseKind::VectorLength) &&
              "Parsed clause kind does not have a int exprs");
       Details = IntExprDetails{{IntExprs.begin(), IntExprs.end()}};
@@ -410,6 +423,8 @@ class SemaOpenACC : public SemaBase {
               ClauseKind == OpenACCClauseKind::NumWorkers ||
               ClauseKind == OpenACCClauseKind::Async ||
               ClauseKind == OpenACCClauseKind::Tile ||
+              ClauseKind == OpenACCClauseKind::Worker ||
+              ClauseKind == OpenACCClauseKind::Vector ||
               ClauseKind == OpenACCClauseKind::VectorLength) &&
              "Parsed clause kind does not have a int exprs");
       Details = IntExprDetails{std::move(IntExprs)};
@@ -663,6 +678,7 @@ class SemaOpenACC : public SemaBase {
     ComputeConstructInfo OldActiveComputeConstructInfo;
     OpenACCDirectiveKind DirKind;
     SourceLocation OldLoopGangClauseOnKernelLoc;
+    SourceLocation OldLoopWorkerClauseLoc;
     llvm::SmallVector<OpenACCLoopConstruct *> ParentlessLoopConstructs;
     LoopInConstructRAII LoopRAII;
 
diff --git a/clang/lib/AST/OpenACCClause.cpp b/clang/lib/AST/OpenACCClause.cpp
index 6fb8fe0b8cfeef..638252fd811f1d 100644
--- a/clang/lib/AST/OpenACCClause.cpp
+++ b/clang/lib/AST/OpenACCClause.cpp
@@ -44,7 +44,8 @@ bool OpenACCClauseWithCondition::classof(const OpenACCClause *C) {
 bool OpenACCClauseWithSingleIntExpr::classof(const OpenACCClause *C) {
   return OpenACCNumWorkersClause::classof(C) ||
          OpenACCVectorLengthClause::classof(C) ||
-         OpenACCCollapseClause::classof(C) || OpenACCAsyncClause::classof(C);
+         OpenACCWorkerClause::classof(C) || OpenACCCollapseClause::classof(C) ||
+         OpenACCAsyncClause::classof(C);
 }
 OpenACCDefaultClause *OpenACCDefaultClause::Create(const ASTContext &C,
                                                    OpenACCDefaultClauseKind K,
@@ -403,11 +404,24 @@ OpenACCGangClause::Create(const ASTContext &C, SourceLocation BeginLoc,
       OpenACCGangClause(BeginLoc, LParenLoc, GangKinds, IntExprs, EndLoc);
 }
 
+OpenACCWorkerClause::OpenACCWorkerClause(SourceLocation BeginLoc,
+                                         SourceLocation LParenLoc,
+                                         Expr *IntExpr, SourceLocation EndLoc)
+    : OpenACCClauseWithSingleIntExpr(OpenACCClauseKind::Worker, BeginLoc,
+                                     LParenLoc, IntExpr, EndLoc) {
+  assert((!IntExpr || IntExpr->isInstantiationDependent() ||
+          IntExpr->getType()->isIntegerType()) &&
+         "Int expression type not scalar/dependent");
+}
+
 OpenACCWorkerClause *OpenACCWorkerClause::Create(const ASTContext &C,
                                                  SourceLocation BeginLoc,
+                                                 SourceLocation LParenLoc,
+                                                 Expr *IntExpr,
                                                  SourceLocation EndLoc) {
-  void *Mem = C.Allocate(sizeof(OpenACCWorkerClause));
-  return new (Mem) OpenACCWorkerClause(BeginLoc, EndLoc);
+  void *Mem =
+      C.Allocate(sizeof(OpenACCWorkerClause), alignof(OpenACCWorkerClause));
+  return new (Mem) OpenACCWorkerClause(BeginLoc, LParenLoc, IntExpr, EndLoc);
 }
 
 OpenACCVectorClause *OpenACCVectorClause::Create(const ASTContext &C,
@@ -638,3 +652,13 @@ void OpenACCClausePrinter::VisitGangClause(const OpenACCGangClause &C) {
     OS << ")";
   }
 }
+
+void OpenACCClausePrinter::VisitWorkerClause(const OpenACCWorkerClause &C) {
+  OS << "worker";
+
+  if (C.hasIntExpr()) {
+    OS << "(num: ";
+    printExpr(C.getIntExpr());
+    OS << ")";
+  }
+}
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 6161b1403ed35d..25b1cbb8590869 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -2629,6 +2629,12 @@ void OpenACCClauseProfiler::VisitAsyncClause(const OpenACCAsyncClause &Clause) {
     Profiler.VisitStmt(Clause.getIntExpr());
 }
 
+void OpenACCClauseProfiler::VisitWorkerClause(
+    const OpenACCWorkerClause &Clause) {
+  if (Clause.hasIntExpr())
+    Profiler.VisitStmt(Clause.getIntExpr());
+}
+
 void OpenACCClauseProfiler::VisitWaitClause(const OpenACCWaitClause &Clause) {
   if (Clause.hasDevNumExpr())
     Profiler.VisitStmt(Clause.getDevNumExpr());
diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp
index ac8c196777f9b8..beccb0615f0e9c 100644
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -420,6 +420,7 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
     case OpenACCClauseKind::Self:
     case OpenACCClauseKind::Seq:
     case OpenACCClauseKind::Tile:
+    case OpenACCClauseKind::Worker:
     case OpenACCClauseKind::VectorLength:
       // The condition expression will be printed as a part of the 'children',
       // but print 'clause' here so it is clear what is happening from the dump.
diff --git a/clang/lib/Parse/ParseOpenACC.cpp b/clang/lib/Parse/ParseOpenACC.cpp
index 635039b724e6a0..51d4dc38c17f67 100644
--- a/clang/lib/Parse/ParseOpenACC.cpp
+++ b/clang/lib/Parse/ParseOpenACC.cpp
@@ -1130,6 +1130,7 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
           Parens.skipToEnd();
           return OpenACCCanContinue();
         }
+        ParsedClause.setIntExprDetails(IntExpr.get());
         break;
       }
       case OpenACCClauseKind::Async: {
diff --git a/clang/lib/Sema/SemaOpenACC.cpp b/clang/lib/Sema/SemaOpenACC.cpp
index 30d73d621db69b..1b24331cbd87ca 100644
--- a/clang/lib/Sema/SemaOpenACC.cpp
+++ b/clang/lib/Sema/SemaOpenACC.cpp
@@ -377,6 +377,18 @@ bool doesClauseApplyToDirective(OpenACCDirectiveKind DirectiveKind,
     default:
       return false;
     }
+  case OpenACCClauseKind::Worker: {
+    switch (DirectiveKind) {
+    case OpenACCDirectiveKind::Loop:
+    case OpenACCDirectiveKind::ParallelLoop:
+    case OpenACCDirectiveKind::SerialLoop:
+    case OpenACCDirectiveKind::KernelsLoop:
+    case OpenACCDirectiveKind::Routine:
+      return true;
+    default:
+      return false;
+    }
+  }
   }
 
   default:
@@ -500,7 +512,6 @@ class SemaOpenACCClauseVisitor {
 
   OpenACCClause *Visit(SemaOpenACC::OpenACCParsedClause &Clause) {
     switch (Clause.getClauseKind()) {
-    case OpenACCClauseKind::Worker:
     case OpenACCClauseKind::Vector: {
       // TODO OpenACC: These are only implemented enough for the 'seq'
       // diagnostic, otherwise treats itself as unimplemented.  When we
@@ -1024,6 +1035,75 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitIndependentClause(
                                           Clause.getEndLoc());
 }
 
+OpenACCClause *SemaOpenACCClauseVisitor::VisitWorkerClause(
+    SemaOpenACC::OpenACCParsedClause &Clause) {
+  if (DiagIfSeqClause(Clause))
+    return nullptr;
+
+  // Restrictions only properly implemented on 'loop' constructs, and it is
+  // the only construct that can do anything with this, so skip/treat as
+  // unimplemented for the combined constructs.
+  if (Clause.getDirectiveKind() != OpenACCDirectiveKind::Loop)
+    return isNotImplemented();
+
+  Expr *IntExpr =
+      Clause.getNumIntExprs() != 0 ? Clause.getIntExprs()[0] : nullptr;
+
+  if (IntExpr) {
+    switch (SemaRef.getActiveComputeConstructInfo().Kind) {
+    case OpenACCDirectiveKind::Invalid:
+      SemaRef.Diag(IntExpr->getBeginLoc(), diag::err_acc_int_arg_invalid)
+          << OpenACCClauseKind::Worker << "num" << /*orphan=*/0;
+      IntExpr = nullptr;
+      break;
+    case OpenACCDirectiveKind::Parallel:
+      SemaRef.Diag(IntExpr->getBeginLoc(), diag::err_acc_int_arg_invalid)
+          << OpenACCClauseKind::Worker << "num" << /*parallel=*/1;
+      IntExpr = nullptr;
+      break;
+    case OpenACCDirectiveKind::Serial:
+      SemaRef.Diag(IntExpr->getBeginLoc(), diag::err_acc_int_arg_invalid)
+          << OpenACCClauseKind::Worker << "num" << /*serial=*/3;
+      IntExpr = nullptr;
+      break;
+    case OpenACCDirectiveKind::Kernels: {
+      const auto *Itr =
+          llvm::find_if(SemaRef.getActiveComputeConstructInfo().Clauses,
+                        llvm::IsaPred<OpenACCNumWorkersClause>);
+      if (Itr != SemaRef.getActiveComputeConstructInfo().Clauses.end()) {
+        SemaRef.Diag(IntExpr->getBeginLoc(), diag::err_acc_num_arg_conflict)
+            << OpenACCClauseKind::Worker << /*num_workers=*/1;
+        SemaRef.Diag((*Itr)->getBeginLoc(),
+                     diag::note_acc_previous_clause_here);
+
+        IntExpr = nullptr;
+      }
+      break;
+    }
+    default:
+      llvm_unreachable("Non compute construct in active compute construct");
+    }
+  }
+
+  // OpenACC 3.3 2.9.3: The region of a loop with a 'worker' clause may not
+  // contain a loop with a gang or worker clause unless within a nested compute
+  // region.
+  if (SemaRef.LoopWorkerClauseLoc.isValid()) {
+    // This handles the 'inner loop' diagnostic, but we cannot set that we're on
+    // one of these until we get to the end of the construct.
+    SemaRef.Diag(Clause.getBeginLoc(), diag::err_acc_clause_in_clause_region)
+        << OpenACCClauseKind::Worker << OpenACCClauseKind::Worker
+        << /*skip kernels construct info*/ 0;
+    SemaRef.Diag(SemaRef.LoopWorkerClauseLoc,
+                 diag::note_acc_previous_clause_here);
+    return nullptr;
+  }
+
+  return OpenACCWorkerClause::Create(Ctx, Clause.getBeginLoc(),
+                                     Clause.getLParenLoc(), IntExpr,
+                                     Clause.getEndLoc());
+}
+
 OpenACCClause *SemaOpenACCClauseVisitor::VisitGangClause(
     SemaOpenACC::OpenACCParsedClause &Clause) {
   if (DiagIfSeqClause(Clause))
@@ -1061,8 +1141,8 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitGangClause(
                         llvm::IsaPred<OpenACCNumGangsClause>);
 
       if (Itr != SemaRef.getActiveComputeConstructInfo().Clauses.end()) {
-        SemaRef.Diag(ER.get()->getBeginLoc(),
-                     diag::err_acc_gang_num_gangs_conflict);
+        SemaRef.Diag(ER.get()->getBeginLoc(), diag::err_acc_num_arg_conflict)
+            << OpenACCClauseKind::Gang << /*num_gangs=*/0;
         SemaRef.Diag((*Itr)->getBeginLoc(),
                      diag::note_acc_previous_clause_here);
         continue;
@@ -1091,12 +1171,28 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitGangClause(
   if (SemaRef.LoopGangClauseOnKernelLoc.isValid()) {
     // This handles the 'inner loop' diagnostic, but we cannot set that we're on
     // one of these until we get to the end of the construct.
-    SemaRef.Diag(Clause.getBeginLoc(), diag::err_acc_gang_inside_gang);
+    SemaRef.Diag(Clause.getBeginLoc(), diag::err_acc_clause_in_clause_region)
+        << OpenACCClauseKind::Gang << OpenACCClauseKind::Gang
+        << /*kernels construct info*/ 1;
     SemaRef.Diag(SemaRef.LoopGangClauseOnKernelLoc,
                  diag::note_acc_previous_clause_here);
     return nullptr;
   }
 
+  // OpenACC 3.3 2.9.3: The region of a loop with a 'worker' clause may not
+  // contain a loop with a gang or worker clause unless within a nested compute
+  // region.
+  if (SemaRef.LoopWorkerClauseLoc.isValid()) {
+    // This handles the 'inner loop' diagnostic, but we cannot set that we're on
+    // one of these until we get to the end of the construct.
+    SemaRef.Diag(Clause.getBeginLoc(), diag::err_acc_clause_in_clause_region)
+        << OpenACCClauseKind::Gang << OpenACCClauseKind::Worker
+        << /*kernels construct info*/ 1;
+    SemaRef.Diag(SemaRef.LoopWorkerClauseLoc,
+                 diag::note_acc_previous_clause_here);
+    return nullptr;
+  }
+
   return OpenACCGangClause::Create(Ctx, Clause.getBeginLoc(),
                                    Clause.getLParenLoc(), GangKinds, IntExprs,
                                    Clause.getEndLoc());
@@ -1216,6 +1312,7 @@ SemaOpenACC::AssociatedStmtRAII::AssociatedStmtRAII(
     ArrayRef<OpenACCClause *> Clauses)
     : SemaRef(S), OldActiveComputeConstructInfo(S.ActiveComputeConstructInfo),
       DirKind(DK), OldLoopGangClauseOnKernelLoc(S.LoopGangClauseOnKernelLoc),
+      OldLoopWorkerClauseLoc(S.LoopWorkerClauseLoc),
       LoopRAII(SemaRef, /*PreserveDepth=*/false) {
   // Compute constructs end up taking their 'loop'.
   if (DirKind == OpenACCDirectiveKind::Parallel ||
@@ -1232,6 +1329,7 @@ SemaOpenACC::AssociatedStmtRAII::AssociatedStmtRAII(
     //
     // Implement the 'unless within a nested compute region' part.
     SemaRef.LoopGangClauseOnKernelLoc = {};
+    SemaRef.LoopWorkerClauseLoc = {};
   } else if (DirKind == OpenACCDirectiveKind::Loop) {
     SetCollapseInfoBeforeAssociatedStmt(UnInstClauses, Clauses);
     SetTileInfoBeforeAssociatedStmt(UnInstClauses, Clauses);
@@ -1252,6 +1350,12 @@ SemaOpenACC::AssociatedStmtRAII::AssociatedStmtRAII(
       if (Itr != Clauses.end())
         SemaRef.LoopGangClauseOnKernelLoc = (*Itr)->getBeginLoc();
     }
+
+    if (UnI...
[truncated]

@erichkeane erichkeane merged commit cf456ed into llvm:main Oct 14, 2024
15 checks passed
DanielCChen pushed a commit to DanielCChen/llvm-project that referenced this pull request Oct 16, 2024
The worker clause specifies iterations of the loop/ that are executed in
parallel by distributing the iterations among the multiple works within
a single gang.

The sema rules for this type are simply that it cannot be combined with
a `kernel` construct with a `num_workers` clause, child `loop` clauses
cannot contain a `gang` or `worker` clause, and that the argument is oly
allowed when associated with a `kernel`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:as-a-library libclang and C++ API clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:modules C++20 modules and Clang Header Modules clang Clang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants