diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td
index edf6525377957..aa09354bc753d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td
@@ -99,6 +99,42 @@ def SubsetInsertionOpInterface : OpInterface<"SubsetInsertionOpInterface"> {
          "::mlir::Value":$candidate,
          "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the subset of the destination tensor that this operation
+        inserts into.
+
+        Example:
+        ```
+        // SubsetOpInterface op:
+        %0 = tensor.insert_slice %t0 into %t1[%pos][5][1]
+            : tensor<5xf32> into tensor<?xf32>
+        // Subset (built by this function):
+        %1 = tensor.extract_slice %t1[%pos][5][1]
+            : tensor<?xf32> to tensor<5xf32>
+        ```
+
+        Note: Implementations do not necessarily have to build new IR. They
+        may return existing SSA values.
+      }],
+      /*retType=*/"::mlir::Value",
+      /*methodName=*/"buildSubsetExtraction",
+      /*args=*/(ins "::mlir::OpBuilder &":$builder, "Location":$loc)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return all SSA values that are needed (i.e., must be in scope) at the
+        insertion point of the builder when calling `buildSubsetExtraction`.
+        Users of `buildSubsetExtraction` can use this helper method to find a
+        suitable insertion point.
+
+        Example: The SSA values needed to build the subset in the example of
+        `buildSubsetExtraction` are %t1 and %pos.
+      }],
+      /*retType=*/"::llvm::SmallVector<::mlir::Value>",
+      /*methodName=*/"getValuesNeededToBuildSubsetExtraction",
+      /*args=*/(ins)
+    >,
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
index 46a95ad8db2a6..84bd047e6d51e 100644
--- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
@@ -109,19 +109,17 @@ def EliminateEmptyTensorsOp
     DeclareOpInterfaceMethods]> {
   let description = [{
     Try to eliminate all `tensor.empty` ops within the targeted op by replacing
-    them with a destination tensor.
+    them with another destination tensor.
 
-    `tensor.empty` ops cannot be bufferizes. They can either be converted to
-    `bufferization.alloc_tensor` or replaced with another tensor (via this
-    transform). `tensor.empty` does not specify the contents of the returned
+    "tensor.empty" ops cannot be bufferized. They can either be converted to
+    "bufferization.alloc_tensor" or replaced with another tensor (via this
+    transform). "tensor.empty" does not specify the contents of the returned
     tensor so their results can be replaced with arbitrary tensor values as long
     as the dimensions match.
 
-    This transform looks for `tensor.empty` ops where the SSA use-def chain of
-    the result ends in a supported "anchor op" (always following the aliasing
-    OpOperand/OpResult chain). Currently supported anchor ops are:
-      - `tensor.insert_slice`
-      - `bufferization.yield` (inside `bufferization.alloc_tensor`)
+    This transformation looks for subset ops that insert a tensor that
+    originates from a "tensor.empty" (as per the reverse use-def chain). Such
+    "tensor.empty" ops are replaced with the destination subset.
 
     Example:
 
@@ -138,6 +136,10 @@
       %2 = tensor.insert_slice %1 into %t[1][5][1]
     ```
 
+    In the above example, the subset op is "tensor.insert_slice". When tracing
+    back the reverse use-def chain of the source, we end up at a
+    "tensor.empty" op.
+
     The above example can bufferize without an allocation (in the absence of
     other conflicts) because there is no longer a `tensor.empty` op.
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index df9bfcbfc5488..ff43cff817b64 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -402,11 +402,22 @@ def PromoteBuffersToStack : Pass<"promote-buffers-to-stack", "func::FuncOp"> {
 def EmptyTensorElimination : Pass<"eliminate-empty-tensors"> {
   let summary = "Try to eliminate all tensor.empty ops.";
   let description = [{
-    This pass tries to eliminate all insert_slice op-anchored tensor.empty ops.
-    I.e., when a value that is equivalent to an tensor.empty op is inserted into
-    another tensor, this pass tries to rewrite the IR in such a way that the
-    destination tensor of the insert_slice op is used directly instead of the
-    tensor.empty result.
+    Try to eliminate "tensor.empty" ops inside `op`. This transformation looks
+    for subset ops that insert a tensor that originates from a "tensor.empty"
+    (as per the reverse use-def chain). Such "tensor.empty" ops are replaced
+    with the destination subset.
+
+    E.g.:
+    ```
+    %0 = tensor.empty() : tensor<10xf32>
+    %1 = linalg.fill ... outs(%0 : tensor<10xf32>)
+    %2 = tensor.insert_slice %0 into %t ...
+    ```
+
+    In the above example, the subset op is "tensor.insert_slice". When tracing
+    back the reverse use-def chain of the source, we end up at a
+    "tensor.empty" op. The "tensor.empty" op is replaced with a
+    "tensor.extract_slice" op.
   }];
   let constructor = "mlir::bufferization::createEmptyTensorEliminationPass()";
 }
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
index a0cfc811a0b50..df866daf1ab1f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -19,38 +19,26 @@
 struct BufferizationStatistics;
 class OneShotAnalysisState;
 struct OneShotBufferizationOptions;
 
-/// A function that matches anchor OpOperands for tensor::EmptyOp elimination.
-/// If an OpOperand is matched, the function should populate the SmallVector
-/// with all values that are needed during `RewriteFn` to produce the
-/// replacement value.
-using AnchorMatchFn = std::function<bool(OpOperand &, SmallVector<Value> &)>;
-
-/// A function that rewrites matched anchors.
-using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
-
-/// Try to eliminate tensor::EmptyOps inside `op`.
+/// Try to eliminate "tensor.empty" ops inside `op`. This transformation looks
+/// for subset ops that insert a tensor that originates from a "tensor.empty"
+/// (as per the reverse use-def chain). Such "tensor.empty" ops are replaced
+/// with the destination subset.
 ///
-/// * `rewriteFunc` generates the replacement for the tensor::EmptyOp.
-/// * Only tensor::EmptyOps that are anchored on a matching OpOperand as per
-///   `anchorMatchFunc` are considered. "Anchored" means that there is a path
"Anchored" means that there is a path -/// on the reverse SSA use-def chain, starting from the OpOperand and always -/// following the aliasing OpOperand, that eventually ends at a single -/// tensor::EmptyOp. +/// E.g.: +/// %0 = tensor.empty() : tensor<10xf32> +/// %1 = linalg.fill ... outs(%0 : tensor<10xf32>) +/// %2 = tensor.insert_slice %0 into %t ... +/// +/// In the above example, the subset op is "tensor.insert_slice". When tracing +/// back the reverse use-def chain of a the source, we end up at a +/// "tensor.empty" op. LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op, - OneShotAnalysisState &state, - AnchorMatchFn anchorMatchFunc, - RewriteFn rewriteFunc); + OneShotAnalysisState &state); /// Within the given operation, hoist buffers from loops where possible. See /// "BufferLoopHoistingPass" for more information. void hoistBuffersFromLoops(Operation *op); -/// Try to eliminate tensor::EmptyOps inside `op` that are anchored on an -/// InsertSliceOp, i.e., if it is eventually inserted into another tensor -/// (and some other conditions are met). -LogicalResult insertSliceAnchoredEmptyTensorEliminationStep( - RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state); - /// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops. /// After applying this transform, the IR can be bufferized without inserting /// additional buffer allocations. diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp index 097f75a7bc50f..b84cc452d0141 100644 --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -121,8 +121,7 @@ DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply( if (failed(analyzeOp(target, state))) return mlir::emitSilenceableFailure(target->getLoc()) << "failed to analyze op"; - if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep( - rewriter, target, state))) + if (failed(bufferization::eliminateEmptyTensors(rewriter, target, state))) return mlir::emitSilenceableFailure(target->getLoc()) << "failed to eliminate insert_slice anchored tensor.empty ops"; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp index 4e0781dae0c25..1662b52968d35 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -99,154 +100,67 @@ findValidInsertionPoint(Operation *emptyTensorOp, return nullptr; } -/// Try to eliminate tensor::EmptyOps inside `op`. A tensor::EmptyOp is replaced -/// with the result of `rewriteFunc` if it is anchored on a matching -/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def -/// chain, starting from the OpOperand and always following the aliasing -/// OpOperand, that eventually ends at the tensor::EmptyOp. 
-///
-/// E.g.:
-/// %0 = tensor.empty() : tensor<10xf32>
-/// %1 = linalg.fill ... outs(%0 : tensor<10xf32>)
-/// %2 = tensor.insert_slice %0 into %t ...
-///
-/// In the above example, the anchor is the source operand of the insert_slice
-/// op. When tracing back the reverse use-def chain, we end up at a
-/// tensor.empty op.
 LogicalResult mlir::bufferization::eliminateEmptyTensors(
-    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
-    AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) {
+    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
   OpBuilder::InsertionGuard g(rewriter);
 
-  op->walk([&](Operation *op) {
-    for (OpOperand &operand : op->getOpOperands()) {
-      // Skip operands that do not bufferize inplace.
-      if (!state.isInPlace(operand))
-        continue;
-      // All values that are needed to create the replacement op.
-      SmallVector<Value> neededValues;
-      // Is this an anchor?
-      if (!anchorMatchFunc(operand, neededValues))
+  op->walk([&](SubsetInsertionOpInterface op) {
+    OpOperand &source = op.getSourceOperand();
+    // Skip operands that do not bufferize inplace. "tensor.empty" could still
+    // be replaced, but the transformation may not be beneficial.
+    if (!state.isInPlace(source))
+      return WalkResult::skip();
+    // All values that are needed to create the replacement op.
+    SmallVector<Value> neededValues =
+        op.getValuesNeededToBuildSubsetExtraction();
+
+    // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
+    // equivalent tensors. I.e., stop when there are ops such as extract_slice
+    // on the path.
+    TraversalConfig config;
+    config.followEquivalentOnly = true;
+    config.alwaysIncludeLeaves = false;
+    // Replace only if the types match or are static <-> dynamic casts. We do
+    // not support slices or reshapes.
+    // TODO: This could be extended to support IR such as:
+    // %0 = tensor.empty() : tensor<128xf32>
+    // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
+    // %2 = tensor.expand_shape %1 ...
+    // %3 = tensor.insert_slice %2 into ...
+    config.followSameTypeOrCastsOnly = true;
+    SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
+        source.get(), /*condition=*/
+        [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
+        config);
+
+    for (Value v : emptyTensors) {
+      Operation *emptyTensorOp = v.getDefiningOp();
+
+      // Find a suitable insertion point. If no suitable insertion point for
+      // the replacement can be found, skip this replacement.
+      Operation *insertionPoint =
+          findValidInsertionPoint(emptyTensorOp, neededValues);
+      if (!insertionPoint)
         continue;
 
-      // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
-      // equivalent tensors. I.e., stop when there are ops such as extract_slice
-      // on the path.
-      TraversalConfig config;
-      config.followEquivalentOnly = true;
-      config.alwaysIncludeLeaves = false;
-      // Replace only if the types match or are static <-> dynamic casts. We do
-      // not support slices or reshapes.
-      // TODO: This could be extended to support IR such as:
-      // %0 = tensor.empty() : tensor<128xf32>
-      // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
-      // %2 = tensor.expand_shape %1 ...
-      // %3 = tensor.insert_slice %2 into ...
-      config.followSameTypeOrCastsOnly = true;
-      SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
-          operand.get(), /*condition=*/
-          [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
-          config);
-
-      for (Value v : emptyTensors) {
-        Operation *emptyTensorOp = v.getDefiningOp();
-
-        // Find a suitable insertion point. If no suitable insertion point for
-        // the replacement can be found, skip this replacement.
-        Operation *insertionPoint =
-            findValidInsertionPoint(emptyTensorOp, neededValues);
-        if (!insertionPoint)
-          continue;
-
-        rewriter.setInsertionPoint(insertionPoint);
-        Value replacement =
-            rewriteFunc(rewriter, emptyTensorOp->getLoc(), operand);
-        if (!replacement)
-          continue;
-        if (replacement.getType() != v.getType()) {
-          rewriter.setInsertionPointAfterValue(replacement);
-          replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
-                                                        replacement);
-        }
-        // Replace the tensor::EmptyOp.
-        rewriter.replaceOp(emptyTensorOp, replacement);
-        state.resetCache();
+      rewriter.setInsertionPoint(insertionPoint);
+      Value replacement =
+          op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
+      if (!replacement)
+        continue;
+      if (replacement.getType() != v.getType()) {
+        rewriter.setInsertionPointAfterValue(replacement);
+        replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
+                                                      replacement);
       }
+      // Replace the tensor::EmptyOp.
+      rewriter.replaceOp(emptyTensorOp, replacement);
+      state.resetCache();
     }
-  });
-
-  return success();
-}
-
-/// Try to eliminate tensor::EmptyOps inside `op`. An tensor::EmptyOp can be
-/// eliminated if it is eventually inserted into another tensor (and some other
-/// conditions are met).
-///
-/// E.g.:
-/// %0 = tensor.empty()
-/// %1 = linalg.fill(%cst, %0) {inplace = [true]}
-/// %2 = tensor.insert_slice %1 into %t[10][20][1]
-///
-/// tensor::EmptyOp elimination will try to fill %t inplace instead of filling a
-/// new allocation %0 and inserting it into %t. This is done by replacing the
-/// tensor::EmptyOp with:
-///
-/// %0 = tensor.extract_slice %t[10][20][1]
-///
-/// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets
-/// those bufferize inplace in the absence of other conflicts.
-///
-/// Starting from an InsertSliceOp, an tensor::EmptyOp at the end of the insert
-/// source's reverse use-def chain is eliminated if:
-/// * On the reverse use-def chain path from the InsertSliceOp to the
-///   tensor::EmptyOp, all ops were decided to bufferize inplace and the buffer
-///   relation is "equivalent" (TODO: can be relaxed if needed).
-/// * The reverse use-def chain has exactly one end, which is the
-///   tensor::EmptyOp.
-template <typename OpTy>
-static LogicalResult insertSliceLikeAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
-  return eliminateEmptyTensors(
-      rewriter, op, state,
-      /*anchorMatchFunc=*/
-      [&](OpOperand &operand, SmallVector<Value> &neededValues) {
-        auto insertSliceOp = dyn_cast<OpTy>(operand.getOwner());
-        if (!insertSliceOp)
-          return false;
-        if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
-          return false;
-        // Collect all values that are needed to construct the replacement op.
-        neededValues.append(insertSliceOp.getOffsets().begin(),
-                            insertSliceOp.getOffsets().end());
-        neededValues.append(insertSliceOp.getSizes().begin(),
-                            insertSliceOp.getSizes().end());
-        neededValues.append(insertSliceOp.getStrides().begin(),
-                            insertSliceOp.getStrides().end());
-        neededValues.push_back(insertSliceOp.getDest());
-
-        return true;
-      },
-      /*rewriteFunc=*/
-      [](OpBuilder &b, Location loc, OpOperand &operand) {
-        auto insertOp = cast<OpTy>(operand.getOwner());
-        auto extractOp = b.create<tensor::ExtractSliceOp>(
-            loc, insertOp.getSourceType(), insertOp.getDest(),
-            insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
-            insertOp.getMixedStrides());
-        return extractOp.getResult();
-      });
-}
+    return WalkResult::advance();
+  });
 
-LogicalResult
-mlir::bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
-  if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep<
-          tensor::InsertSliceOp>(rewriter, op, state)))
-    return failure();
-  if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep<
-          tensor::ParallelInsertSliceOp>(rewriter, op, state)))
-    return failure();
   return success();
 }
 
@@ -276,8 +190,7 @@ void EmptyTensorElimination::runOnOperation() {
   }
 
   IRRewriter rewriter(op->getContext());
-  if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
-          rewriter, op, state)))
+  if (failed(bufferization::eliminateEmptyTensors(rewriter, op, state)))
     signalPassFailure();
 }
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
index 1156f2501a96e..dff9f64169d49 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
@@ -34,6 +34,31 @@ bool isSubsetEquivalentToInsertSliceLikeOp(
       isEqualConstantIntOrValue);
 }
 
+template <typename OpTy>
+Value buildSubsetExtractionOfInsertSliceLikeOp(OpBuilder &b, Location loc,
+                                               OpTy insertSliceOp) {
+  auto extractOp = b.create<tensor::ExtractSliceOp>(
+      loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
+      insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+      insertSliceOp.getMixedStrides());
+  return extractOp.getResult();
+}
+
+template <typename OpTy>
+SmallVector<Value>
+getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(OpTy insertSliceOp) {
+  SmallVector<Value> neededValues;
+  // Collect all values that are needed to construct the replacement op.
+  neededValues.append(insertSliceOp.getOffsets().begin(),
+                      insertSliceOp.getOffsets().end());
+  neededValues.append(insertSliceOp.getSizes().begin(),
+                      insertSliceOp.getSizes().end());
+  neededValues.append(insertSliceOp.getStrides().begin(),
+                      insertSliceOp.getStrides().end());
+  neededValues.push_back(insertSliceOp.getDest());
+  return neededValues;
+}
+
 struct InsertSliceOpInterface
     : public SubsetInsertionOpInterface::ExternalModel<InsertSliceOpInterface,
                                                        tensor::InsertSliceOp> {
@@ -48,6 +73,18 @@ struct InsertSliceOpInterface
     return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
                                                  equivalenceFn);
   }
+
+  Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
+                              Location loc) const {
+    return buildSubsetExtractionOfInsertSliceLikeOp(
+        builder, loc, cast<tensor::InsertSliceOp>(op));
+  }
+
+  SmallVector<Value>
+  getValuesNeededToBuildSubsetExtraction(Operation *op) const {
+    return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(
+        cast<tensor::InsertSliceOp>(op));
+  }
 };
 
 struct ParallelInsertSliceOpInterface
@@ -68,6 +105,18 @@ struct ParallelInsertSliceOpInterface
     return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
                                                  equivalenceFn);
   }
+
+  Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
+                              Location loc) const {
+    return buildSubsetExtractionOfInsertSliceLikeOp(
+        builder, loc, cast<tensor::ParallelInsertSliceOp>(op));
+  }
+
+  SmallVector<Value>
+  getValuesNeededToBuildSubsetExtraction(Operation *op) const {
+    return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(
+        cast<tensor::ParallelInsertSliceOp>(op));
+  }
 };
 
 } // namespace
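
Usage note (illustrative, not part of the patch): after this change, empty tensor elimination is driven entirely by `SubsetInsertionOpInterface`; the `AnchorMatchFn`/`RewriteFn` callbacks are gone. The sketch below shows how custom code might call the new `bufferization::eliminateEmptyTensors` entry point, mirroring the updated pass body above. The wrapper function name and the default-constructed options are assumptions for illustration.

```
// Minimal sketch (assumptions: default options are sufficient for the input).
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

static LogicalResult runEmptyTensorElimination(Operation *op) {
  bufferization::OneShotBufferizationOptions options;
  bufferization::OneShotAnalysisState state(op, options);
  // The elimination only considers source operands that bufferize in place,
  // so the one-shot analysis must run first (as in the transform op above).
  if (failed(bufferization::analyzeOp(op, state)))
    return failure();
  IRRewriter rewriter(op->getContext());
  // New interface-driven entry point: walks SubsetInsertionOpInterface ops and
  // replaces "tensor.empty" producers of their source with a subset extraction.
  return bufferization::eliminateEmptyTensors(rewriter, op, state);
}
```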