@@ -2591,17 +2591,24 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
25912591
25922592 builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
25932593
2594+ // Check whether we can generate a no-loop kernel
25942595 bool noLoopMode = false ;
25952596 omp::TargetOp targetOp = wsloopOp->getParentOfType <mlir::omp::TargetOp>();
25962597 if (targetOp) {
25972598 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp ();
2598- omp::TargetRegionFlags kernelFlags =
2599- targetOp.getKernelExecFlags (targetCapturedOp);
2600- if (omp::bitEnumContainsAll (kernelFlags,
2601- omp::TargetRegionFlags::spmd |
2602- omp::TargetRegionFlags::no_loop) &&
2603- !omp::bitEnumContainsAny (kernelFlags, omp::TargetRegionFlags::generic))
2604- noLoopMode = true ;
2599+ // We need this check because, without it, noLoopMode would be set to true
2600+ // for every omp.wsloop nested inside a no-loop SPMD target region, even if
2601+ // that loop is not the top-level SPMD one.
2602+ if (loopOp == targetCapturedOp) {
2603+ omp::TargetRegionFlags kernelFlags =
2604+ targetOp.getKernelExecFlags (targetCapturedOp);
2605+ if (omp::bitEnumContainsAll (kernelFlags,
2606+ omp::TargetRegionFlags::spmd |
2607+ omp::TargetRegionFlags::no_loop) &&
2608+ !omp::bitEnumContainsAny (kernelFlags,
2609+ omp::TargetRegionFlags::generic))
2610+ noLoopMode = true ;
2611+ }
26052612 }
26062613
26072614 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
0 commit comments