diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 3429008b2f241..3e7a0cca31c77 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -72,15 +72,50 @@ using namespace mlir::dataflow; namespace { +// Set of structures below to be filled with operations and arguments to erase. +// This is done to separate analysis and tree modification phases, +// otherwise analysis is operating on half-deleted tree which is incorrect. + +struct FunctionToCleanUp { + FunctionOpInterface funcOp; + BitVector nonLiveArgs; + BitVector nonLiveRets; +}; + +struct OperationToCleanup { + Operation *op; + BitVector nonLive; +}; + +struct BlockArgsToCleanup { + Block *b; + BitVector nonLiveArgs; +}; + +struct SuccessorOperandsToCleanup { + BranchOpInterface branch; + unsigned successorIndex; + BitVector nonLiveOperands; +}; + +struct RDVFinalCleanupList { + SmallVector operations; + SmallVector values; + SmallVector functions; + SmallVector operands; + SmallVector results; + SmallVector blocks; + SmallVector successorOperands; +}; + // Some helper functions... /// Return true iff at least one value in `values` is live, given the liveness /// information in `la`. -static bool hasLive(ValueRange values, RunLivenessAnalysis &la) { +static bool hasLive(ValueRange values, const DenseSet &nonLiveSet, + RunLivenessAnalysis &la) { for (Value value : values) { - // If there is a null value, it implies that it was dropped during the - // execution of this pass, implying that it was non-live. - if (!value) + if (nonLiveSet.contains(value)) continue; const Liveness *liveness = la.getLiveness(value); @@ -92,11 +127,12 @@ static bool hasLive(ValueRange values, RunLivenessAnalysis &la) { /// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the /// i-th value in `values` is live, given the liveness information in `la`. -static BitVector markLives(ValueRange values, RunLivenessAnalysis &la) { +static BitVector markLives(ValueRange values, const DenseSet &nonLiveSet, + RunLivenessAnalysis &la) { BitVector lives(values.size(), true); for (auto [index, value] : llvm::enumerate(values)) { - if (!value) { + if (nonLiveSet.contains(value)) { lives.reset(index); continue; } @@ -115,6 +151,18 @@ static BitVector markLives(ValueRange values, RunLivenessAnalysis &la) { return lives; } +/// Collects values marked as "non-live" in the provided range and inserts them +/// into the nonLiveSet. A value is considered "non-live" if the corresponding +/// index in the `nonLive` bit vector is set. +static void collectNonLiveValues(DenseSet &nonLiveSet, ValueRange range, + const BitVector &nonLive) { + for (auto [index, result] : llvm::enumerate(range)) { + if (!nonLive[index]) + continue; + nonLiveSet.insert(result); + } +} + /// Drop the uses of the i-th result of `op` and then erase it iff toErase[i] /// is 1. static void dropUsesAndEraseResults(Operation *op, BitVector toErase) { @@ -165,52 +213,59 @@ static SmallVector operandsToOpOperands(OperandRange operands) { return opOperands; } -/// Clean a simple op `op`, given the liveness analysis information in `la`. -/// Here, cleaning means: -/// (1) Dropping all its uses, AND -/// (2) Erasing it -/// iff it has no memory effects and none of its results are live. +/// Process a simple operation `op` using the liveness analysis `la`. +/// If the operation has no memory effects and none of its results are live: +/// 1. Add the operation to a list for future removal, and +/// 2. Mark all its results as non-live values /// -/// It is assumed that `op` is simple. Here, a simple op is one which isn't a -/// function-like op, a call-like op, a region branch op, a branch op, a region -/// branch terminator op, or return-like. -static void cleanSimpleOp(Operation *op, RunLivenessAnalysis &la) { - if (!isMemoryEffectFree(op) || hasLive(op->getResults(), la)) +/// The operation `op` is assumed to be simple. A simple operation is one that +/// is NOT: +/// - Function-like +/// - Call-like +/// - A region branch operation +/// - A branch operation +/// - A region branch terminator +/// - Return-like +static void processSimpleOp(Operation *op, RunLivenessAnalysis &la, + DenseSet &nonLiveSet, + RDVFinalCleanupList &cl) { + if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) return; - op->dropAllUses(); - op->erase(); + cl.operations.push_back(op); + collectNonLiveValues(nonLiveSet, op->getResults(), + BitVector(op->getNumResults(), true)); } -/// Clean a function-like op `funcOp`, given the liveness information in `la` -/// and the IR in `module`. Here, cleaning means: -/// (1) Dropping the uses of its unnecessary (non-live) arguments, -/// (2) Erasing these arguments, -/// (3) Erasing their corresponding operands from its callers, -/// (4) Erasing its unnecessary terminator operands (return values that are -/// non-live across all callers), -/// (5) Dropping the uses of these return values from its callers, AND -/// (6) Erasing these return values -/// iff it is not public or external. -static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module, - RunLivenessAnalysis &la) { +/// Process a function-like operation `funcOp` using the liveness analysis `la` +/// and the IR in `module`. If it is not public or external: +/// (1) Adding its non-live arguments to a list for future removal. +/// (2) Marking their corresponding operands in its callers for removal. +/// (3) Identifying and enqueueing unnecessary terminator operands +/// (return values that are non-live across all callers) for removal. +/// (4) Enqueueing the non-live arguments and return values for removal. +/// (5) Collecting the uses of these return values in its callers for future +/// removal. +/// (6) Marking all its results as non-live values. +static void processFuncOp(FunctionOpInterface funcOp, Operation *module, + RunLivenessAnalysis &la, DenseSet &nonLiveSet, + RDVFinalCleanupList &cl) { if (funcOp.isPublic() || funcOp.isExternal()) return; // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`. SmallVector arguments(funcOp.getArguments()); - BitVector nonLiveArgs = markLives(arguments, la); + BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la); nonLiveArgs = nonLiveArgs.flip(); // Do (1). for (auto [index, arg] : llvm::enumerate(arguments)) - if (arg && nonLiveArgs[index]) - arg.dropAllUses(); + if (arg && nonLiveArgs[index]) { + cl.values.push_back(arg); + nonLiveSet.insert(arg); + } // Do (2). - funcOp.eraseArguments(nonLiveArgs); - - // Do (3). SymbolTable::UseRange uses = *funcOp.getSymbolUses(module); for (SymbolTable::SymbolUse use : uses) { Operation *callOp = use.getUser(); @@ -222,9 +277,10 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module, operandsToOpOperands(cast(callOp).getArgOperands()); for (int index : nonLiveArgs.set_bits()) nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber()); - callOp->eraseOperands(nonLiveCallOperands); + cl.operands.push_back({callOp, nonLiveCallOperands}); } + // Do (3). // Get the list of unnecessary terminator operands (return values that are // non-live across all callers) in `nonLiveRets`. There is a very important // subtlety here. Unnecessary terminator operands are NOT the operands of the @@ -253,62 +309,74 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module, for (SymbolTable::SymbolUse use : uses) { Operation *callOp = use.getUser(); assert(isa(callOp) && "expected a call-like user"); - BitVector liveCallRets = markLives(callOp->getResults(), la); + BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la); nonLiveRets &= liveCallRets.flip(); } - // Do (4). // Note that in the absence of control flow ops forcing the control to go from // the entry (first) block to the other blocks, the control never reaches any // block other than the entry block, because every block has a terminator. for (Block &block : funcOp.getBlocks()) { Operation *returnOp = block.getTerminator(); if (returnOp && returnOp->getNumOperands() == numReturns) - returnOp->eraseOperands(nonLiveRets); + cl.operands.push_back({returnOp, nonLiveRets}); } - funcOp.eraseResults(nonLiveRets); + + // Do (4). + cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets}); // Do (5) and (6). for (SymbolTable::SymbolUse use : uses) { Operation *callOp = use.getUser(); assert(isa(callOp) && "expected a call-like user"); - dropUsesAndEraseResults(callOp, nonLiveRets); + cl.results.push_back({callOp, nonLiveRets}); + collectNonLiveValues(nonLiveSet, callOp->getResults(), nonLiveRets); } } -/// Clean a region branch op `regionBranchOp`, given the liveness information in -/// `la`. Here, cleaning means: -/// (1') Dropping all its uses, AND -/// (2') Erasing it -/// if it has no memory effects and none of its results are live, AND -/// (1) Erasing its unnecessary operands (operands that are forwarded to -/// unneccesary results and arguments), -/// (2) Cleaning each of its regions, -/// (3) Dropping the uses of its unnecessary results (results that are -/// forwarded from unnecessary operands and terminator operands), AND -/// (4) Erasing these results -/// otherwise. -/// Note that here, cleaning a region means: -/// (2.a) Dropping the uses of its unnecessary arguments (arguments that are -/// forwarded from unneccesary operands and terminator operands), -/// (2.b) Erasing these arguments, AND -/// (2.c) Erasing its unnecessary terminator operands (terminator operands -/// that are forwarded to unneccesary results and arguments). -/// It is important to note that values in this op flow from operands and -/// terminator operands (successor operands) to arguments and results (successor -/// inputs). -static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp, - RunLivenessAnalysis &la) { +/// Process a region branch operation `regionBranchOp` using the liveness +/// information in `la`. The processing involves two scenarios: +/// +/// Scenario 1: If the operation has no memory effects and none of its results +/// are live: +/// (1') Enqueue all its uses for deletion. +/// (2') Enqueue the branch itself for deletion. +/// +/// Scenario 2: Otherwise: +/// (1) Collect its unnecessary operands (operands forwarded to unnecessary +/// results or arguments). +/// (2) Process each of its regions. +/// (3) Collect the uses of its unnecessary results (results forwarded from +/// unnecessary operands +/// or terminator operands). +/// (4) Add these results to the deletion list. +/// +/// Processing a region includes: +/// (a) Collecting the uses of its unnecessary arguments (arguments forwarded +/// from unnecessary operands +/// or terminator operands). +/// (b) Collecting these unnecessary arguments. +/// (c) Collecting its unnecessary terminator operands (terminator operands +/// forwarded to unnecessary results +/// or arguments). +/// +/// Value Flow Note: In this operation, values flow as follows: +/// - From operands and terminator operands (successor operands) +/// - To arguments and results (successor inputs). +static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, + RunLivenessAnalysis &la, + DenseSet &nonLiveSet, + RDVFinalCleanupList &cl) { // Mark live results of `regionBranchOp` in `liveResults`. auto markLiveResults = [&](BitVector &liveResults) { - liveResults = markLives(regionBranchOp->getResults(), la); + liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la); }; // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`. auto markLiveArgs = [&](DenseMap &liveArgs) { for (Region ®ion : regionBranchOp->getRegions()) { SmallVector arguments(region.front().getArguments()); - BitVector regionLiveArgs = markLives(arguments, la); + BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la); liveArgs[®ion] = regionLiveArgs; } }; @@ -491,18 +559,19 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp, } }; - // Do (1') and (2'). This is the only case where the entire `regionBranchOp` + // Scenario 1. This is the only case where the entire `regionBranchOp` // is removed. It will not happen in any other scenario. Note that in this // case, a non-forwarded operand of `regionBranchOp` could be live/non-live. // It could never be live because of this op but its liveness could have been // attributed to something else. + // Do (1') and (2'). if (isMemoryEffectFree(regionBranchOp.getOperation()) && - !hasLive(regionBranchOp->getResults(), la)) { - regionBranchOp->dropAllUses(); - regionBranchOp->erase(); + !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) { + cl.operations.push_back(regionBranchOp.getOperation()); return; } + // Scenario 2. // At this point, we know that every non-forwarded operand of `regionBranchOp` // is live. @@ -538,48 +607,49 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp, terminatorOperandsToKeep); // Do (1). - regionBranchOp->eraseOperands(operandsToKeep.flip()); + cl.operands.push_back({regionBranchOp, operandsToKeep.flip()}); // Do (2.a) and (2.b). for (Region ®ion : regionBranchOp->getRegions()) { assert(!region.empty() && "expected a non-empty region in an op " "implementing `RegionBranchOpInterface`"); - for (auto [index, arg] : llvm::enumerate(region.front().getArguments())) { - if (argsToKeep[®ion][index]) - continue; - if (arg) - arg.dropAllUses(); - } - region.front().eraseArguments(argsToKeep[®ion].flip()); + BitVector argsToRemove = argsToKeep[®ion].flip(); + cl.blocks.push_back({®ion.front(), argsToRemove}); + collectNonLiveValues(nonLiveSet, region.front().getArguments(), + argsToRemove); } // Do (2.c). for (Region ®ion : regionBranchOp->getRegions()) { Operation *terminator = region.front().getTerminator(); - terminator->eraseOperands(terminatorOperandsToKeep[terminator].flip()); + cl.operands.push_back( + {terminator, terminatorOperandsToKeep[terminator].flip()}); } // Do (3) and (4). - dropUsesAndEraseResults(regionBranchOp.getOperation(), resultsToKeep.flip()); + BitVector resultsToRemove = resultsToKeep.flip(); + collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(), + resultsToRemove); + cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove}); } -// 1. Iterate over each successor block of the given BranchOpInterface -// operation. -// 2. For each successor block: -// a. Retrieve the operands passed to the successor. -// b. Use the provided liveness analysis (`RunLivenessAnalysis`) to determine -// which operands are live in the successor block. -// c. Mark each operand as live or dead based on the analysis. -// 3. Remove dead operands from the branch operation and arguments accordingly - -static void cleanBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la) { +/// Steps to process a `BranchOpInterface` operation: +/// Iterate through each successor block of `branchOp`. +/// (1) For each successor block, gather all operands from all successors. +/// (2) Fetch their associated liveness analysis data and collect for future +/// removal. +/// (3) Identify and collect the dead operands from the successor block +/// as well as their corresponding arguments. + +static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, + DenseSet &nonLiveSet, + RDVFinalCleanupList &cl) { unsigned numSuccessors = branchOp->getNumSuccessors(); - // Do (1) for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) { Block *successorBlock = branchOp->getSuccessor(succIdx); - // Do (2) + // Do (1) SuccessorOperands successorOperands = branchOp.getSuccessorOperands(succIdx); SmallVector operandValues; @@ -588,22 +658,74 @@ static void cleanBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la) { operandValues.push_back(successorOperands[operandIdx]); } - BitVector successorLiveOperands = markLives(operandValues, la); + // Do (2) + BitVector successorNonLive = + markLives(operandValues, nonLiveSet, la).flip(); + collectNonLiveValues(nonLiveSet, successorBlock->getArguments(), + successorNonLive); // Do (3) - for (int argIdx = successorLiveOperands.size() - 1; argIdx >= 0; --argIdx) { - if (!successorLiveOperands[argIdx]) { - if (successorBlock->getNumArguments() < successorOperands.size()) { - // if block was cleaned through a different code path - // we only need to remove operands from the invokation - successorOperands.erase(argIdx); - continue; - } + cl.blocks.push_back({successorBlock, successorNonLive}); + cl.successorOperands.push_back({branchOp, succIdx, successorNonLive}); + } +} + +/// Removes dead values collected in RDVFinalCleanupList. +/// To be run once when all dead values have been collected. +static void cleanUpDeadVals(RDVFinalCleanupList &list) { + // 1. Operations + for (auto &op : list.operations) { + op->dropAllUses(); + op->erase(); + } + + // 2. Values + for (auto &v : list.values) { + v.dropAllUses(); + } + + // 3. Functions + for (auto &f : list.functions) { + f.funcOp.eraseArguments(f.nonLiveArgs); + f.funcOp.eraseResults(f.nonLiveRets); + } + + // 4. Operands + for (auto &o : list.operands) { + o.op->eraseOperands(o.nonLive); + } - successorBlock->getArgument(argIdx).dropAllUses(); - successorOperands.erase(argIdx); - successorBlock->eraseArgument(argIdx); - } + // 5. Results + for (auto &r : list.results) { + dropUsesAndEraseResults(r.op, r.nonLive); + } + + // 6. Blocks + for (auto &b : list.blocks) { + // blocks that are accessed via multiple codepaths processed once + if (b.b->getNumArguments() != b.nonLiveArgs.size()) + continue; + // it iterates backwards because erase invalidates all successor indexes + for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) { + if (!b.nonLiveArgs[i]) + continue; + b.b->getArgument(i).dropAllUses(); + b.b->eraseArgument(i); + } + } + + // 7. Successor Operands + for (auto &op : list.successorOperands) { + SuccessorOperands successorOperands = + op.branch.getSuccessorOperands(op.successorIndex); + // blocks that are accessed via multiple codepaths processed once + if (successorOperands.size() != op.nonLiveOperands.size()) + continue; + // it iterates backwards because erase invalidates all successor indexes + for (int i = successorOperands.size() - 1; i >= 0; --i) { + if (!op.nonLiveOperands[i]) + continue; + successorOperands.erase(i); } } } @@ -617,13 +739,21 @@ void RemoveDeadValues::runOnOperation() { auto &la = getAnalysis(); Operation *module = getOperation(); + // Tracks values eligible for erasure - complements liveness analysis to + // identify "droppable" values. + DenseSet deadVals; + + // Maintains a list of Ops, values, branches, etc., slated for cleanup at the + // end of this pass. + RDVFinalCleanupList finalCleanupList; + module->walk([&](Operation *op) { if (auto funcOp = dyn_cast(op)) { - cleanFuncOp(funcOp, module, la); + processFuncOp(funcOp, module, la, deadVals, finalCleanupList); } else if (auto regionBranchOp = dyn_cast(op)) { - cleanRegionBranchOp(regionBranchOp, la); + processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList); } else if (auto branchOp = dyn_cast(op)) { - cleanBranchOp(branchOp, la); + processBranchOp(branchOp, la, deadVals, finalCleanupList); } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) { // Nothing to do here because this is a terminator op and it should be // honored with respect to its parent @@ -631,9 +761,11 @@ void RemoveDeadValues::runOnOperation() { // Nothing to do because this op is associated with a function op and gets // cleaned when the latter is cleaned. } else { - cleanSimpleOp(op, la); + processSimpleOp(op, la, deadVals, finalCleanupList); } }); + + cleanUpDeadVals(finalCleanupList); } std::unique_ptr mlir::createRemoveDeadValuesPass() { diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index 9273ac01e7cce..fe7bcbc7c490b 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -73,6 +73,32 @@ func.func @cleanable_loop_iter_args_value(%arg0: index) -> index { // ----- +// Checking that the arguments of linalg.generic are properly handled +// All code below is removed as a result of the pass +// +#map = affine_map<(d0, d1, d2) -> (0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +module { + func.func @main() { + %cst_3 = arith.constant dense<54> : tensor<1x25x13xi32> + %cst_7 = arith.constant dense<11> : tensor<1x25x13xi32> + // CHECK-NOT: arith.constant + %0 = tensor.empty() : tensor<1x25x13xi32> + // CHECK-NOT: tensor + %1 = linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_3, %cst_7 : tensor<1x25x13xi32>, tensor<1x25x13xi32>) outs(%0 : tensor<1x25x13xi32>) { + // CHECK-NOT: linalg.generic + ^bb0(%in: i32, %in_15: i32, %out: i32): + %29 = arith.xori %in, %in_15 : i32 + // CHECK-NOT: arith.xori + linalg.yield %29 : i32 + // CHECK-NOT: linalg.yield + } -> tensor<1x25x13xi32> + return + } +} + +// ----- + // Note that this cleanup cannot be done by the `canonicalize` pass. // // CHECK-LABEL: func.func private @clean_func_op_remove_argument_and_return_value() {