From f40a969a5fd1ea4d9cfa6516fcc40ff611f3a369 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 14 Sep 2023 09:25:38 +0200 Subject: [PATCH] [mlir][bufferization] Empty tensor elimination based on SubsetOpInterface This commit generalizes empty tensor elimination to operate on subset ops. No new test cases are added because all current subset ops were already supported previously. From this perspective, this change is NFC. A new interface method (and a helper method) are added to `SubsetOpInterface` to build the subset of the destination tensor. --- .../IR/SubsetInsertionOpInterface.td | 36 ++++ .../TransformOps/BufferizationTransformOps.td | 20 +- .../Bufferization/Transforms/Passes.td | 21 +- .../Bufferization/Transforms/Transforms.h | 38 ++-- .../BufferizationTransformOps.cpp | 3 +- .../Transforms/EmptyTensorElimination.cpp | 195 +++++------------- .../SubsetInsertionOpInterfaceImpl.cpp | 49 +++++ 7 files changed, 180 insertions(+), 182 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td index edf6525377957..aa09354bc753d 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td @@ -99,6 +99,42 @@ def SubsetInsertionOpInterface : OpInterface<"SubsetInsertionOpInterface"> { "::mlir::Value":$candidate, "::llvm::function_ref":$equivalenceFn) >, + InterfaceMethod< + /*desc=*/[{ + Return the subset of the destination tensor that this operation + inserts into. + + Example: + ``` + // SubsetOpInterface op: + %0 = tensor.insert_slice %t0 into %t1[%pos][5][1] + : tensor<5xf32> into tensor<?xf32> + // Subset (built by this function): + %1 = tensor.extract_slice %t1[%pos][5][1] + : tensor<?xf32> to tensor<5xf32> + ``` + + Note: Implementations do not necessarily have to build new IR. They + may return existing SSA values. 
+ }], + /*retType=*/"::mlir::Value", + /*methodName=*/"buildSubsetExtraction", + /*args=*/(ins "::mlir::OpBuilder &":$builder, "Location":$loc) + >, + InterfaceMethod< + /*desc=*/[{ + Return all SSA values that are needed (i.e., must be in scope) at the + insertion point of the builder when calling `buildSubsetExtraction`. Users + of `buildSubsetExtraction` can use this helper method to find a + suitable insertion point. + + Example: The SSA values needed to build the subset in the example of + `buildSubsetExtraction` are %t1 and %pos. + }], + /*retType=*/"::llvm::SmallVector<::mlir::Value>", + /*methodName=*/"getValuesNeededToBuildSubsetExtraction", + /*args=*/(ins) + >, ]; let extraClassDeclaration = [{ diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td index 46a95ad8db2a6..84bd047e6d51e 100644 --- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td @@ -109,19 +109,17 @@ def EliminateEmptyTensorsOp DeclareOpInterfaceMethods]> { let description = [{ Try to eliminate all `tensor.empty` ops within the targeted op by replacing - them with a destination tensor. + them with another destination tensor. - `tensor.empty` ops cannot be bufferizes. They can either be converted to - `bufferization.alloc_tensor` or replaced with another tensor (via this - transform). `tensor.empty` does not specify the contents of the returned + "tensor.empty" ops cannot be bufferized. They can either be converted to + "bufferization.alloc_tensor" or replaced with another tensor (via this + transform). "tensor.empty" does not specify the contents of the returned tensor so their results can be replaced with arbitrary tensor values as long as the dimensions match. 
- This transform looks for `tensor.empty` ops where the SSA use-def chain of - the result ends in a supported "anchor op" (always following the aliasing - OpOperand/OpResult chain). Currently supported anchor ops are: - - `tensor.insert_slice` - - `bufferization.yield` (inside `bufferization.alloc_tensor`) + This transformation looks for subset ops that insert a tensor that + originates from a "tensor.empty" (as per the reverse use-def chain). Such + "tensor.empty" ops are replaced with the destination subset. Example: @@ -138,6 +136,10 @@ def EliminateEmptyTensorsOp %2 = tensor.insert_slice %1 into %t[1][5][1] ``` + In the above example, the subset op is "tensor.insert_slice". When tracing + back the reverse use-def chain of the source, we end up at a + "tensor.empty" op. + The above example can bufferize without an allocation (in the absence of other conflicts) because there is no longer a `tensor.empty` op. diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td index df9bfcbfc5488..ff43cff817b64 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -402,11 +402,22 @@ def PromoteBuffersToStack : Pass<"promote-buffers-to-stack", "func::FuncOp"> { def EmptyTensorElimination : Pass<"eliminate-empty-tensors"> { let summary = "Try to eliminate all tensor.empty ops."; let description = [{ - This pass tries to eliminate all insert_slice op-anchored tensor.empty ops. - I.e., when a value that is equivalent to an tensor.empty op is inserted into - another tensor, this pass tries to rewrite the IR in such a way that the - destination tensor of the insert_slice op is used directly instead of the - tensor.empty result. + Try to eliminate "tensor.empty" ops inside `op`. 
This transformation looks + for subset ops that insert a tensor that originates from a "tensor.empty" + (as per the reverse use-def chain). Such "tensor.empty" ops are replaced + with the destination subset. + + E.g.: + ``` + %0 = tensor.empty() : tensor<10xf32> + %1 = linalg.fill ... outs(%0 : tensor<10xf32>) + %2 = tensor.insert_slice %0 into %t ... + ``` + + In the above example, the subset op is "tensor.insert_slice". When tracing + back the reverse use-def chain of the source, we end up at a + "tensor.empty" op. The "tensor.empty" op is replaced with a + "tensor.extract_slice" op. }]; let constructor = "mlir::bufferization::createEmptyTensorEliminationPass()"; } diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h index a0cfc811a0b50..df866daf1ab1f 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h @@ -19,38 +19,26 @@ struct BufferizationStatistics; class OneShotAnalysisState; struct OneShotBufferizationOptions; -/// A function that matches anchor OpOperands for tensor::EmptyOp elimination. -/// If an OpOperand is matched, the function should populate the SmallVector -/// with all values that are needed during `RewriteFn` to produce the -/// replacement value. -using AnchorMatchFn = std::function &)>; - -/// A function that rewrites matched anchors. -using RewriteFn = std::function; - -/// Try to eliminate tensor::EmptyOps inside `op`. +/// Try to eliminate "tensor.empty" ops inside `op`. This transformation looks +/// for subset ops that insert a tensor that originates from a "tensor.empty" +/// (as per the reverse use-def chain). Such "tensor.empty" ops are replaced +/// with the destination subset. /// -/// * `rewriteFunc` generates the replacement for the tensor::EmptyOp. 
-/// * Only tensor::EmptyOps that are anchored on a matching OpOperand as per -/// `anchorMatchFunc` are considered. "Anchored" means that there is a path -/// on the reverse SSA use-def chain, starting from the OpOperand and always -/// following the aliasing OpOperand, that eventually ends at a single -/// tensor::EmptyOp. +/// E.g.: +/// %0 = tensor.empty() : tensor<10xf32> +/// %1 = linalg.fill ... outs(%0 : tensor<10xf32>) +/// %2 = tensor.insert_slice %0 into %t ... +/// +/// In the above example, the subset op is "tensor.insert_slice". When tracing +/// back the reverse use-def chain of the source, we end up at a +/// "tensor.empty" op. LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op, - OneShotAnalysisState &state, - AnchorMatchFn anchorMatchFunc, - RewriteFn rewriteFunc); + OneShotAnalysisState &state); /// Within the given operation, hoist buffers from loops where possible. See /// "BufferLoopHoistingPass" for more information. void hoistBuffersFromLoops(Operation *op); -/// Try to eliminate tensor::EmptyOps inside `op` that are anchored on an -/// InsertSliceOp, i.e., if it is eventually inserted into another tensor -/// (and some other conditions are met). -LogicalResult insertSliceAnchoredEmptyTensorEliminationStep( - RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state); - /// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops. /// After applying this transform, the IR can be bufferized without inserting /// additional buffer allocations. 
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp index 097f75a7bc50f..b84cc452d0141 100644 --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -121,8 +121,7 @@ DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply( if (failed(analyzeOp(target, state))) return mlir::emitSilenceableFailure(target->getLoc()) << "failed to analyze op"; - if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep( - rewriter, target, state))) + if (failed(bufferization::eliminateEmptyTensors(rewriter, target, state))) return mlir::emitSilenceableFailure(target->getLoc()) << "failed to eliminate insert_slice anchored tensor.empty ops"; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp index 4e0781dae0c25..1662b52968d35 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -99,154 +100,67 @@ findValidInsertionPoint(Operation *emptyTensorOp, return nullptr; } -/// Try to eliminate tensor::EmptyOps inside `op`. A tensor::EmptyOp is replaced -/// with the result of `rewriteFunc` if it is anchored on a matching -/// OpOperand. 
"Anchored" means that there is a path on the reverse SSA use-def -/// chain, starting from the OpOperand and always following the aliasing -/// OpOperand, that eventually ends at the tensor::EmptyOp. -/// -/// E.g.: -/// %0 = tensor.empty() : tensor<10xf32> -/// %1 = linalg.fill ... outs(%0 : tensor<10xf32>) -/// %2 = tensor.insert_slice %0 into %t ... -/// -/// In the above example, the anchor is the source operand of the insert_slice -/// op. When tracing back the reverse use-def chain, we end up at a -/// tensor.empty op. LogicalResult mlir::bufferization::eliminateEmptyTensors( - RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state, - AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) { + RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) { OpBuilder::InsertionGuard g(rewriter); - op->walk([&](Operation *op) { - for (OpOperand &operand : op->getOpOperands()) { - // Skip operands that do not bufferize inplace. - if (!state.isInPlace(operand)) - continue; - // All values that are needed to create the replacement op. - SmallVector neededValues; - // Is this an anchor? - if (!anchorMatchFunc(operand, neededValues)) + op->walk([&](SubsetInsertionOpInterface op) { + OpOperand &source = op.getSourceOperand(); + // Skip operands that do not bufferize inplace. "tensor.empty" could still + // be replaced, but the transformation may not be beneficial. + if (!state.isInPlace(source)) + return WalkResult::skip(); + // All values that are needed to create the replacement op. + SmallVector neededValues = + op.getValuesNeededToBuildSubsetExtraction(); + + // Find tensor.empty ops on the reverse SSA use-def chain. Only follow + // equivalent tensors. I.e., stop when there are ops such as extract_slice + // on the path. + TraversalConfig config; + config.followEquivalentOnly = true; + config.alwaysIncludeLeaves = false; + // Replace only if the types match or are static <-> dynamic casts. We do + // not support slices or reshapes. 
+ // TODO: This could be extended to support IR such as: + // %0 = tensor.empty() : tensor<128xf32> + // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>) + // %2 = tensor.expand_shape %1 ... + // %3 = tensor.insert_slice %2 into ... + config.followSameTypeOrCastsOnly = true; + SetVector emptyTensors = state.findValueInReverseUseDefChain( + source.get(), /*condition=*/ + [&](Value val) { return val.getDefiningOp(); }, + config); + + for (Value v : emptyTensors) { + Operation *emptyTensorOp = v.getDefiningOp(); + + // Find a suitable insertion point. If no suitable insertion point for + // the replacement can be found, skip this replacement. + Operation *insertionPoint = + findValidInsertionPoint(emptyTensorOp, neededValues); + if (!insertionPoint) continue; - // Find tensor.empty ops on the reverse SSA use-def chain. Only follow - // equivalent tensors. I.e., stop when there are ops such as extract_slice - // on the path. - TraversalConfig config; - config.followEquivalentOnly = true; - config.alwaysIncludeLeaves = false; - // Replace only if the types match or are static <-> dynamic casts. We do - // not support slices or reshapes. - // TODO: This could be extended to support IR such as: - // %0 = tensor.empty() : tensor<128xf32> - // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>) - // %2 = tensor.expand_shape %1 ... - // %3 = tensor.insert_slice %2 into ... - config.followSameTypeOrCastsOnly = true; - SetVector emptyTensors = state.findValueInReverseUseDefChain( - operand.get(), /*condition=*/ - [&](Value val) { return val.getDefiningOp(); }, - config); - - for (Value v : emptyTensors) { - Operation *emptyTensorOp = v.getDefiningOp(); - - // Find a suitable insertion point. If no suitable insertion point for - // the replacement can be found, skip this replacement. 
- Operation *insertionPoint = - findValidInsertionPoint(emptyTensorOp, neededValues); - if (!insertionPoint) - continue; - - rewriter.setInsertionPoint(insertionPoint); - Value replacement = - rewriteFunc(rewriter, emptyTensorOp->getLoc(), operand); - if (!replacement) - continue; - if (replacement.getType() != v.getType()) { - rewriter.setInsertionPointAfterValue(replacement); - replacement = rewriter.create(v.getLoc(), v.getType(), - replacement); - } - // Replace the tensor::EmptyOp. - rewriter.replaceOp(emptyTensorOp, replacement); - state.resetCache(); + rewriter.setInsertionPoint(insertionPoint); + Value replacement = + op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc()); + if (!replacement) + continue; + if (replacement.getType() != v.getType()) { + rewriter.setInsertionPointAfterValue(replacement); + replacement = rewriter.create(v.getLoc(), v.getType(), + replacement); } + // Replace the tensor::EmptyOp. + rewriter.replaceOp(emptyTensorOp, replacement); + state.resetCache(); } - }); - - return success(); -} - -/// Try to eliminate tensor::EmptyOps inside `op`. An tensor::EmptyOp can be -/// eliminated if it is eventually inserted into another tensor (and some other -/// conditions are met). -/// -/// E.g.: -/// %0 = tensor.empty() -/// %1 = linalg.fill(%cst, %0) {inplace = [true]} -/// %2 = tensor.insert_slice %1 into %t[10][20][1] -/// -/// tensor::EmptyOp elimination will try to fill %t inplace instead of filling a -/// new allocation %0 and inserting it into %t. This is done by replacing the -/// tensor::EmptyOp with: -/// -/// %0 = tensor.extract_slice %t[10][20][1] -/// -/// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets -/// those bufferize inplace in the absence of other conflicts. 
-/// -/// Starting from an InsertSliceOp, an tensor::EmptyOp at the end of the insert -/// source's reverse use-def chain is eliminated if: -/// * On the reverse use-def chain path from the InsertSliceOp to the -/// tensor::EmptyOp, all ops were decided to bufferize inplace and the buffer -/// relation is "equivalent" (TODO: can be relaxed if needed). -/// * The reverse use-def chain has exactly one end, which is the -/// tensor::EmptyOp. -template -static LogicalResult insertSliceLikeAnchoredEmptyTensorEliminationStep( - RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) { - return eliminateEmptyTensors( - rewriter, op, state, - /*anchorMatchFunc=*/ - [&](OpOperand &operand, SmallVector &neededValues) { - auto insertSliceOp = dyn_cast(operand.getOwner()); - if (!insertSliceOp) - return false; - if (&operand != &insertSliceOp->getOpOperand(0) /*source*/) - return false; - // Collect all values that are needed to construct the replacement op. - neededValues.append(insertSliceOp.getOffsets().begin(), - insertSliceOp.getOffsets().end()); - neededValues.append(insertSliceOp.getSizes().begin(), - insertSliceOp.getSizes().end()); - neededValues.append(insertSliceOp.getStrides().begin(), - insertSliceOp.getStrides().end()); - neededValues.push_back(insertSliceOp.getDest()); - - return true; - }, - /*rewriteFunc=*/ - [](OpBuilder &b, Location loc, OpOperand &operand) { - auto insertOp = cast(operand.getOwner()); - auto extractOp = b.create( - loc, insertOp.getSourceType(), insertOp.getDest(), - insertOp.getMixedOffsets(), insertOp.getMixedSizes(), - insertOp.getMixedStrides()); - return extractOp.getResult(); - }); -} + return WalkResult::advance(); + }); -LogicalResult -mlir::bufferization::insertSliceAnchoredEmptyTensorEliminationStep( - RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) { - if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep< - tensor::InsertSliceOp>(rewriter, op, state))) - return failure(); - if 
(failed(insertSliceLikeAnchoredEmptyTensorEliminationStep< - tensor::ParallelInsertSliceOp>(rewriter, op, state))) - return failure(); return success(); } @@ -276,8 +190,7 @@ void EmptyTensorElimination::runOnOperation() { } IRRewriter rewriter(op->getContext()); - if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep( - rewriter, op, state))) + if (failed(bufferization::eliminateEmptyTensors(rewriter, op, state))) signalPassFailure(); } diff --git a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp index 1156f2501a96e..dff9f64169d49 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp @@ -34,6 +34,31 @@ bool isSubsetEquivalentToInsertSliceLikeOp( isEqualConstantIntOrValue); } +template +Value buildSubsetExtractionOfInsertSliceLikeOp(OpBuilder &b, Location loc, + OpTy insertSliceOp) { + auto extractOp = b.create( + loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(), + insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), + insertSliceOp.getMixedStrides()); + return extractOp.getResult(); +} + +template +SmallVector +getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(OpTy insertSliceOp) { + SmallVector neededValues; + // Collect all values that are needed to construct the replacement op. 
+ neededValues.append(insertSliceOp.getOffsets().begin(), + insertSliceOp.getOffsets().end()); + neededValues.append(insertSliceOp.getSizes().begin(), + insertSliceOp.getSizes().end()); + neededValues.append(insertSliceOp.getStrides().begin(), + insertSliceOp.getStrides().end()); + neededValues.push_back(insertSliceOp.getDest()); + return neededValues; +} + struct InsertSliceOpInterface : public SubsetInsertionOpInterface::ExternalModel { @@ -48,6 +73,18 @@ struct InsertSliceOpInterface return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate, equivalenceFn); } + + Value buildSubsetExtraction(Operation *op, OpBuilder &builder, + Location loc) const { + return buildSubsetExtractionOfInsertSliceLikeOp( + builder, loc, cast(op)); + } + + SmallVector + getValuesNeededToBuildSubsetExtraction(Operation *op) const { + return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp( + cast(op)); + } }; struct ParallelInsertSliceOpInterface @@ -68,6 +105,18 @@ struct ParallelInsertSliceOpInterface return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate, equivalenceFn); } + + Value buildSubsetExtraction(Operation *op, OpBuilder &builder, + Location loc) const { + return buildSubsetExtractionOfInsertSliceLikeOp( + builder, loc, cast(op)); + } + + SmallVector + getValuesNeededToBuildSubsetExtraction(Operation *op) const { + return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp( + cast(op)); + } }; } // namespace