@@ -9815,6 +9815,25 @@ static Stmt *buildPreInits(ASTContext &Context,
9815
9815
return nullptr;
9816
9816
}
9817
9817
9818
+ /// Append the \p Item or the content of a CompoundStmt to the list \p
9819
+ /// TargetList.
9820
+ ///
9821
+ /// A CompoundStmt is used as container in case multiple statements need to be
9822
+ /// stored in lieu of using an explicit list. Flattening is necessary because
9823
+ /// contained DeclStmts need to be visible after the execution of the list. Used
9824
+ /// for OpenMP pre-init declarations/statements.
9825
+ static void appendFlattendedStmtList(SmallVectorImpl<Stmt *> &TargetList,
9826
+ Stmt *Item) {
9827
+ // nullptr represents an empty list.
9828
+ if (!Item)
9829
+ return;
9830
+
9831
+ if (auto *CS = dyn_cast<CompoundStmt>(Item))
9832
+ llvm::append_range(TargetList, CS->body());
9833
+ else
9834
+ TargetList.push_back(Item);
9835
+ }
9836
+
9818
9837
/// Build preinits statement for the given declarations.
9819
9838
static Stmt *
9820
9839
buildPreInits(ASTContext &Context,
@@ -9828,6 +9847,17 @@ buildPreInits(ASTContext &Context,
9828
9847
return nullptr;
9829
9848
}
9830
9849
9850
+ /// Build pre-init statement for the given statements.
9851
+ static Stmt *buildPreInits(ASTContext &Context, ArrayRef<Stmt *> PreInits) {
9852
+ if (PreInits.empty())
9853
+ return nullptr;
9854
+
9855
+ SmallVector<Stmt *> Stmts;
9856
+ for (Stmt *S : PreInits)
9857
+ appendFlattendedStmtList(Stmts, S);
9858
+ return CompoundStmt::Create(Context, PreInits, FPOptionsOverride(), {}, {});
9859
+ }
9860
+
9831
9861
/// Build postupdate expression for the given list of postupdates expressions.
9832
9862
static Expr *buildPostUpdate(Sema &S, ArrayRef<Expr *> PostUpdates) {
9833
9863
Expr *PostUpdate = nullptr;
@@ -9924,11 +9954,21 @@ checkOpenMPLoop(OpenMPDirectiveKind DKind, Expr *CollapseLoopCountExpr,
9924
9954
Stmt *DependentPreInits = Transform->getPreInits();
9925
9955
if (!DependentPreInits)
9926
9956
return;
9927
- for (Decl *C : cast<DeclStmt>(DependentPreInits)->getDeclGroup()) {
9928
- auto *D = cast<VarDecl>(C);
9929
- DeclRefExpr *Ref = buildDeclRefExpr(SemaRef, D, D->getType(),
9930
- Transform->getBeginLoc());
9931
- Captures[Ref] = Ref;
9957
+
9958
+ // Search for pre-init declared variables that need to be captured
9959
+ // to be referenceable inside the directive.
9960
+ SmallVector<Stmt *> Constituents;
9961
+ appendFlattendedStmtList(Constituents, DependentPreInits);
9962
+ for (Stmt *S : Constituents) {
9963
+ if (auto *DC = dyn_cast<DeclStmt>(S)) {
9964
+ for (Decl *C : DC->decls()) {
9965
+ auto *D = cast<VarDecl>(C);
9966
+ DeclRefExpr *Ref = buildDeclRefExpr(
9967
+ SemaRef, D, D->getType().getNonReferenceType(),
9968
+ Transform->getBeginLoc());
9969
+ Captures[Ref] = Ref;
9970
+ }
9971
+ }
9932
9972
}
9933
9973
}))
9934
9974
return 0;
@@ -15059,9 +15099,7 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeSimdDirective(
15059
15099
bool SemaOpenMP::checkTransformableLoopNest(
15060
15100
OpenMPDirectiveKind Kind, Stmt *AStmt, int NumLoops,
15061
15101
SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers,
15062
- Stmt *&Body,
15063
- SmallVectorImpl<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>>
15064
- &OriginalInits) {
15102
+ Stmt *&Body, SmallVectorImpl<SmallVector<Stmt *, 0>> &OriginalInits) {
15065
15103
OriginalInits.emplace_back();
15066
15104
bool Result = OMPLoopBasedDirective::doForAllLoops(
15067
15105
AStmt->IgnoreContainers(), /*TryImperfectlyNestedLoops=*/false, NumLoops,
@@ -15095,16 +15133,70 @@ bool SemaOpenMP::checkTransformableLoopNest(
15095
15133
DependentPreInits = Dir->getPreInits();
15096
15134
else
15097
15135
llvm_unreachable("Unhandled loop transformation");
15098
- if (!DependentPreInits)
15099
- return;
15100
- llvm::append_range(OriginalInits.back(),
15101
- cast<DeclStmt>(DependentPreInits)->getDeclGroup());
15136
+
15137
+ appendFlattendedStmtList(OriginalInits.back(), DependentPreInits);
15102
15138
});
15103
15139
assert(OriginalInits.back().empty() && "No preinit after innermost loop");
15104
15140
OriginalInits.pop_back();
15105
15141
return Result;
15106
15142
}
15107
15143
15144
+ /// Add preinit statements that need to be propageted from the selected loop.
15145
+ static void addLoopPreInits(ASTContext &Context,
15146
+ OMPLoopBasedDirective::HelperExprs &LoopHelper,
15147
+ Stmt *LoopStmt, ArrayRef<Stmt *> OriginalInit,
15148
+ SmallVectorImpl<Stmt *> &PreInits) {
15149
+
15150
+ // For range-based for-statements, ensure that their syntactic sugar is
15151
+ // executed by adding them as pre-init statements.
15152
+ if (auto *CXXRangeFor = dyn_cast<CXXForRangeStmt>(LoopStmt)) {
15153
+ Stmt *RangeInit = CXXRangeFor->getInit();
15154
+ if (RangeInit)
15155
+ PreInits.push_back(RangeInit);
15156
+
15157
+ DeclStmt *RangeStmt = CXXRangeFor->getRangeStmt();
15158
+ PreInits.push_back(new (Context) DeclStmt(RangeStmt->getDeclGroup(),
15159
+ RangeStmt->getBeginLoc(),
15160
+ RangeStmt->getEndLoc()));
15161
+
15162
+ DeclStmt *RangeEnd = CXXRangeFor->getEndStmt();
15163
+ PreInits.push_back(new (Context) DeclStmt(RangeEnd->getDeclGroup(),
15164
+ RangeEnd->getBeginLoc(),
15165
+ RangeEnd->getEndLoc()));
15166
+ }
15167
+
15168
+ llvm::append_range(PreInits, OriginalInit);
15169
+
15170
+ // List of OMPCapturedExprDecl, for __begin, __end, and NumIterations
15171
+ if (auto *PI = cast_or_null<DeclStmt>(LoopHelper.PreInits)) {
15172
+ PreInits.push_back(new (Context) DeclStmt(
15173
+ PI->getDeclGroup(), PI->getBeginLoc(), PI->getEndLoc()));
15174
+ }
15175
+
15176
+ // Gather declarations for the data members used as counters.
15177
+ for (Expr *CounterRef : LoopHelper.Counters) {
15178
+ auto *CounterDecl = cast<DeclRefExpr>(CounterRef)->getDecl();
15179
+ if (isa<OMPCapturedExprDecl>(CounterDecl))
15180
+ PreInits.push_back(new (Context) DeclStmt(
15181
+ DeclGroupRef(CounterDecl), SourceLocation(), SourceLocation()));
15182
+ }
15183
+ }
15184
+
15185
+ /// Collect the loop statements (ForStmt or CXXRangeForStmt) of the affected
15186
+ /// loop of a construct.
15187
+ static void collectLoopStmts(Stmt *AStmt, MutableArrayRef<Stmt *> LoopStmts) {
15188
+ size_t NumLoops = LoopStmts.size();
15189
+ OMPLoopBasedDirective::doForAllLoops(
15190
+ AStmt, /*TryImperfectlyNestedLoops=*/false, NumLoops,
15191
+ [LoopStmts](unsigned Cnt, Stmt *CurStmt) {
15192
+ assert(!LoopStmts[Cnt] && "Loop statement must not yet be assigned");
15193
+ LoopStmts[Cnt] = CurStmt;
15194
+ return false;
15195
+ });
15196
+ assert(!is_contained(LoopStmts, nullptr) &&
15197
+ "Expecting a loop statement for each affected loop");
15198
+ }
15199
+
15108
15200
StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
15109
15201
Stmt *AStmt,
15110
15202
SourceLocation StartLoc,
@@ -15126,8 +15218,7 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
15126
15218
// Verify and diagnose loop nest.
15127
15219
SmallVector<OMPLoopBasedDirective::HelperExprs, 4> LoopHelpers(NumLoops);
15128
15220
Stmt *Body = nullptr;
15129
- SmallVector<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>, 4>
15130
- OriginalInits;
15221
+ SmallVector<SmallVector<Stmt *, 0>, 4> OriginalInits;
15131
15222
if (!checkTransformableLoopNest(OMPD_tile, AStmt, NumLoops, LoopHelpers, Body,
15132
15223
OriginalInits))
15133
15224
return StmtError();
@@ -15144,7 +15235,11 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
15144
15235
"Expecting loop iteration space dimensionality to match number of "
15145
15236
"affected loops");
15146
15237
15147
- SmallVector<Decl *, 4> PreInits;
15238
+ // Collect all affected loop statements.
15239
+ SmallVector<Stmt *> LoopStmts(NumLoops, nullptr);
15240
+ collectLoopStmts(AStmt, LoopStmts);
15241
+
15242
+ SmallVector<Stmt *, 4> PreInits;
15148
15243
CaptureVars CopyTransformer(SemaRef);
15149
15244
15150
15245
// Create iteration variables for the generated loops.
@@ -15184,20 +15279,9 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
15184
15279
&SemaRef.PP.getIdentifierTable().get(TileCntName));
15185
15280
TileIndVars[I] = TileCntDecl;
15186
15281
}
15187
- for (auto &P : OriginalInits[I]) {
15188
- if (auto *D = P.dyn_cast<Decl *>())
15189
- PreInits.push_back(D);
15190
- else if (auto *PI = dyn_cast_or_null<DeclStmt>(P.dyn_cast<Stmt *>()))
15191
- PreInits.append(PI->decl_begin(), PI->decl_end());
15192
- }
15193
- if (auto *PI = cast_or_null<DeclStmt>(LoopHelper.PreInits))
15194
- PreInits.append(PI->decl_begin(), PI->decl_end());
15195
- // Gather declarations for the data members used as counters.
15196
- for (Expr *CounterRef : LoopHelper.Counters) {
15197
- auto *CounterDecl = cast<DeclRefExpr>(CounterRef)->getDecl();
15198
- if (isa<OMPCapturedExprDecl>(CounterDecl))
15199
- PreInits.push_back(CounterDecl);
15200
- }
15282
+
15283
+ addLoopPreInits(Context, LoopHelper, LoopStmts[I], OriginalInits[I],
15284
+ PreInits);
15201
15285
}
15202
15286
15203
15287
// Once the original iteration values are set, append the innermost body.
@@ -15246,19 +15330,20 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
15246
15330
OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers[I];
15247
15331
Expr *NumIterations = LoopHelper.NumIterations;
15248
15332
auto *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters[0]);
15249
- QualType CntTy = OrigCntVar->getType();
15333
+ QualType IVTy = NumIterations->getType();
15334
+ Stmt *LoopStmt = LoopStmts[I];
15250
15335
15251
15336
// Commonly used variables. One of the constraints of an AST is that every
15252
15337
// node object must appear at most once, hence we define lamdas that create
15253
15338
// a new AST node at every use.
15254
- auto MakeTileIVRef = [&SemaRef = this->SemaRef, &TileIndVars, I, CntTy ,
15339
+ auto MakeTileIVRef = [&SemaRef = this->SemaRef, &TileIndVars, I, IVTy ,
15255
15340
OrigCntVar]() {
15256
- return buildDeclRefExpr(SemaRef, TileIndVars[I], CntTy ,
15341
+ return buildDeclRefExpr(SemaRef, TileIndVars[I], IVTy ,
15257
15342
OrigCntVar->getExprLoc());
15258
15343
};
15259
- auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, CntTy ,
15344
+ auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, IVTy ,
15260
15345
OrigCntVar]() {
15261
- return buildDeclRefExpr(SemaRef, FloorIndVars[I], CntTy ,
15346
+ return buildDeclRefExpr(SemaRef, FloorIndVars[I], IVTy ,
15262
15347
OrigCntVar->getExprLoc());
15263
15348
};
15264
15349
@@ -15320,6 +15405,8 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
15320
15405
// further into the inner loop.
15321
15406
SmallVector<Stmt *, 4> BodyParts;
15322
15407
BodyParts.append(LoopHelper.Updates.begin(), LoopHelper.Updates.end());
15408
+ if (auto *SourceCXXFor = dyn_cast<CXXForRangeStmt>(LoopStmt))
15409
+ BodyParts.push_back(SourceCXXFor->getLoopVarStmt());
15323
15410
BodyParts.push_back(Inner);
15324
15411
Inner = CompoundStmt::Create(Context, BodyParts, FPOptionsOverride(),
15325
15412
Inner->getBeginLoc(), Inner->getEndLoc());
@@ -15334,12 +15421,14 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
15334
15421
auto &LoopHelper = LoopHelpers[I];
15335
15422
Expr *NumIterations = LoopHelper.NumIterations;
15336
15423
DeclRefExpr *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters[0]);
15337
- QualType CntTy = OrigCntVar ->getType();
15424
+ QualType IVTy = NumIterations ->getType();
15338
15425
15339
- // Commonly used variables.
15340
- auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, CntTy,
15426
+ // Commonly used variables. One of the constraints of an AST is that every
15427
+ // node object must appear at most once, hence we define lamdas that create
15428
+ // a new AST node at every use.
15429
+ auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, IVTy,
15341
15430
OrigCntVar]() {
15342
- return buildDeclRefExpr(SemaRef, FloorIndVars[I], CntTy ,
15431
+ return buildDeclRefExpr(SemaRef, FloorIndVars[I], IVTy ,
15343
15432
OrigCntVar->getExprLoc());
15344
15433
};
15345
15434
@@ -15405,8 +15494,7 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
15405
15494
Stmt *Body = nullptr;
15406
15495
SmallVector<OMPLoopBasedDirective::HelperExprs, NumLoops> LoopHelpers(
15407
15496
NumLoops);
15408
- SmallVector<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>, NumLoops + 1>
15409
- OriginalInits;
15497
+ SmallVector<SmallVector<Stmt *, 0>, NumLoops + 1> OriginalInits;
15410
15498
if (!checkTransformableLoopNest(OMPD_unroll, AStmt, NumLoops, LoopHelpers,
15411
15499
Body, OriginalInits))
15412
15500
return StmtError();
@@ -15418,6 +15506,10 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
15418
15506
return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
15419
15507
NumGeneratedLoops, nullptr, nullptr);
15420
15508
15509
+ assert(LoopHelpers.size() == NumLoops &&
15510
+ "Expecting a single-dimensional loop iteration space");
15511
+ assert(OriginalInits.size() == NumLoops &&
15512
+ "Expecting a single-dimensional loop iteration space");
15421
15513
OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers.front();
15422
15514
15423
15515
if (FullClause) {
@@ -15481,24 +15573,13 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
15481
15573
// of a canonical loop nest where these PreInits are emitted before the
15482
15574
// outermost directive.
15483
15575
15576
+ // Find the loop statement.
15577
+ Stmt *LoopStmt = nullptr;
15578
+ collectLoopStmts(AStmt, {LoopStmt});
15579
+
15484
15580
// Determine the PreInit declarations.
15485
- SmallVector<Decl *, 4> PreInits;
15486
- assert(OriginalInits.size() == 1 &&
15487
- "Expecting a single-dimensional loop iteration space");
15488
- for (auto &P : OriginalInits[0]) {
15489
- if (auto *D = P.dyn_cast<Decl *>())
15490
- PreInits.push_back(D);
15491
- else if (auto *PI = dyn_cast_or_null<DeclStmt>(P.dyn_cast<Stmt *>()))
15492
- PreInits.append(PI->decl_begin(), PI->decl_end());
15493
- }
15494
- if (auto *PI = cast_or_null<DeclStmt>(LoopHelper.PreInits))
15495
- PreInits.append(PI->decl_begin(), PI->decl_end());
15496
- // Gather declarations for the data members used as counters.
15497
- for (Expr *CounterRef : LoopHelper.Counters) {
15498
- auto *CounterDecl = cast<DeclRefExpr>(CounterRef)->getDecl();
15499
- if (isa<OMPCapturedExprDecl>(CounterDecl))
15500
- PreInits.push_back(CounterDecl);
15501
- }
15581
+ SmallVector<Stmt *, 4> PreInits;
15582
+ addLoopPreInits(Context, LoopHelper, LoopStmt, OriginalInits[0], PreInits);
15502
15583
15503
15584
auto *IterationVarRef = cast<DeclRefExpr>(LoopHelper.IterationVarRef);
15504
15585
QualType IVTy = IterationVarRef->getType();
@@ -15604,6 +15685,8 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
15604
15685
// Inner For statement.
15605
15686
SmallVector<Stmt *> InnerBodyStmts;
15606
15687
InnerBodyStmts.append(LoopHelper.Updates.begin(), LoopHelper.Updates.end());
15688
+ if (auto *CXXRangeFor = dyn_cast<CXXForRangeStmt>(LoopStmt))
15689
+ InnerBodyStmts.push_back(CXXRangeFor->getLoopVarStmt());
15607
15690
InnerBodyStmts.push_back(Body);
15608
15691
CompoundStmt *InnerBody =
15609
15692
CompoundStmt::Create(getASTContext(), InnerBodyStmts, FPOptionsOverride(),
0 commit comments