From a18ac6599fb602c8b0abb17b9887a39c51d84e8e Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Fri, 8 Nov 2024 12:46:34 +0000 Subject: [PATCH] [MLIR][OpenMP] Prevent loop wrapper translation crashes This patch updates the `convertOmpOpRegions` translation function to prevent calling it for a loop wrapper region from causing a compiler crash due to a lack of terminator operations. This problem is currently not triggered because there are no cases for which the region of a loop wrapper is passed to that function. This might have to change in order to support composite construct translation to LLVM IR. --- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 53 ++++++++++++------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index da11ee9960e1f..b507fa656d601 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -391,6 +391,8 @@ static llvm::Expected convertOmpOpRegions( Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, SmallVectorImpl *continuationBlockPHIs = nullptr) { + bool isLoopWrapper = isa(region.getParentOp()); + llvm::BasicBlock *continuationBlock = splitBB(builder, true, "omp.region.cont"); llvm::BasicBlock *sourceBlock = builder.GetInsertBlock(); @@ -407,30 +409,34 @@ static llvm::Expected convertOmpOpRegions( // Terminators (namely YieldOp) may be forwarding values to the region that // need to be available in the continuation block. Collect the types of these - // operands in preparation of creating PHI nodes. + // operands in preparation of creating PHI nodes. This is skipped for loop + // wrapper operations, for which we know in advance they have no terminators. SmallVector continuationBlockPHITypes; - bool operandsProcessed = false; unsigned numYields = 0; - for (Block &bb : region.getBlocks()) { - if (omp::YieldOp yield = dyn_cast(bb.getTerminator())) { - if (!operandsProcessed) { - for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) { - continuationBlockPHITypes.push_back( - moduleTranslation.convertType(yield->getOperand(i).getType())); - } - operandsProcessed = true; - } else { - assert(continuationBlockPHITypes.size() == yield->getNumOperands() && - "mismatching number of values yielded from the region"); - for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) { - llvm::Type *operandType = - moduleTranslation.convertType(yield->getOperand(i).getType()); - (void)operandType; - assert(continuationBlockPHITypes[i] == operandType && - "values of mismatching types yielded from the region"); + + if (!isLoopWrapper) { + bool operandsProcessed = false; + for (Block &bb : region.getBlocks()) { + if (omp::YieldOp yield = dyn_cast(bb.getTerminator())) { + if (!operandsProcessed) { + for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) { + continuationBlockPHITypes.push_back( + moduleTranslation.convertType(yield->getOperand(i).getType())); + } + operandsProcessed = true; + } else { + assert(continuationBlockPHITypes.size() == yield->getNumOperands() && + "mismatching number of values yielded from the region"); + for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) { + llvm::Type *operandType = + moduleTranslation.convertType(yield->getOperand(i).getType()); + (void)operandType; + assert(continuationBlockPHITypes[i] == operandType && + "values of mismatching types yielded from the region"); + } } + numYields++; } - numYields++; } } @@ -468,6 +474,13 @@ static llvm::Expected convertOmpOpRegions( moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) return llvm::make_error(); + // Create a direct branch here for loop wrappers to prevent their lack of a + // terminator from causing a crash below. + if (isLoopWrapper) { + builder.CreateBr(continuationBlock); + continue; + } + // Special handling for `omp.yield` and `omp.terminator` (we may have more // than one): they return the control to the parent OpenMP dialect operation // so replace them with the branch to the continuation block. We handle this