Skip to content

Commit d2e88db

Browse files
Applied remarks
1 parent 6176c95 commit d2e88db

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2591,17 +2591,24 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
25912591

25922592
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
25932593

2594+
// Check if we can generate no-loop kernel
25942595
bool noLoopMode = false;
25952596
omp::TargetOp targetOp = wsloopOp->getParentOfType<mlir::omp::TargetOp>();
25962597
if (targetOp) {
25972598
Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
2598-
omp::TargetRegionFlags kernelFlags =
2599-
targetOp.getKernelExecFlags(targetCapturedOp);
2600-
if (omp::bitEnumContainsAll(kernelFlags,
2601-
omp::TargetRegionFlags::spmd |
2602-
omp::TargetRegionFlags::no_loop) &&
2603-
!omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
2604-
noLoopMode = true;
2599+
// We need this check because, without it, noLoopMode would be set to true
2600+
// for every omp.wsloop nested inside a no-loop SPMD target region, even if
2601+
// that loop is not the top-level SPMD one.
2602+
if (loopOp == targetCapturedOp) {
2603+
omp::TargetRegionFlags kernelFlags =
2604+
targetOp.getKernelExecFlags(targetCapturedOp);
2605+
if (omp::bitEnumContainsAll(kernelFlags,
2606+
omp::TargetRegionFlags::spmd |
2607+
omp::TargetRegionFlags::no_loop) &&
2608+
!omp::bitEnumContainsAny(kernelFlags,
2609+
omp::TargetRegionFlags::generic))
2610+
noLoopMode = true;
2611+
}
26052612
}
26062613

26072614
llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =

0 commit comments

Comments
 (0)