diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td index 66382f29c2424..84f7dec2f4003 100644 --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> { The method returns the operation that is the tiled implementation. }], - /*retType=*/"FailureOr", + /*retType=*/"FailureOr<::mlir::TilingResult>", /*methodName=*/"getTiledImplementation", /*args=*/(ins "OpBuilder &":$b, @@ -82,15 +82,34 @@ def TilingInterface : OpInterface<"TilingInterface"> { by the tiled implementation. Expects the same `offsets` and `sizes` as used to obtain the tiled implementation of the operation. }], - /*retType=*/"LogicalResult", + /*retType=*/"::mlir::LogicalResult", /*methodName=*/"getResultTilePosition", /*args=*/(ins "OpBuilder &":$b, "unsigned":$resultNumber, "ArrayRef ":$offsets, "ArrayRef ":$sizes, - "SmallVector &":$resultOffsets, - "SmallVector &":$resultSizes), + "SmallVectorImpl &":$resultOffsets, + "SmallVectorImpl &":$resultSizes), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return failure(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Method to return the offsets and sizes of the iteration domain tile + that corresponds to the given tile of operand `operandNumber`. + }], + /*retType=*/"::mlir::LogicalResult", + /*methodName=*/"getIterationDomainTileFromOperandTile", + /*args=*/(ins + "OpBuilder &":$b, + "unsigned":$operandNumber, + "ArrayRef ":$offsets, + "ArrayRef ":$sizes, + "SmallVectorImpl &":$iterDomainOffsets, + "SmallVectorImpl &":$iterDomainSizes), /*methodBody=*/"", /*defaultImplementation=*/[{ return failure(); @@ -119,7 +138,7 @@ def TilingInterface : OpInterface<"TilingInterface"> { iteration space). - `sizes` provides the size of the tile. 
}], - /*retType=*/"FailureOr", + /*retType=*/"FailureOr<::mlir::TilingResult>", /*methodName=*/"generateResultTileValue", /*args=*/(ins "OpBuilder &":$b, @@ -131,6 +150,42 @@ def TilingInterface : OpInterface<"TilingInterface"> { return failure(); }] >, + InterfaceMethod< + /*desc=*/[{ + Method to generate the tiled implementation of an operation from + operand tile position. + + Generates the IR that computes the tiled implementation of an + operation from operand tile. The `offsets` and `sizes` + describe the tile of the operand required. This is different from + `getTiledImplementation` which generates the tiled + implementation of the operation given a tile of the + iteration space. This method generates a tiled + implementation of the operation based on the tile of the + operand required. This method enables consumer fusion by using + tile and fuse. The method returns failure if the operation + can't be tiled to generate the operand tile. In practical terms + this implies it cannot be tiled and fused with its producers. + + - `offsets` provides the offset of the operand tile in the coordinate + system of the original iteration space, i.e., if an iteration space + dimension had non-zero offset, it must be included in the offset + provided here (as opposed to zero-based offset "relative" to the + iteration space). + - `sizes` provides the size of the operand tile. + }], + /*retType=*/"FailureOr<::mlir::TilingResult>", + /*methodName=*/"getTiledImplementationFromOperandTile", + /*args=*/(ins + "OpBuilder &":$b, + "unsigned":$operandNumber, + "ArrayRef":$offsets, + "ArrayRef":$sizes), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return failure(); + }] + >, InterfaceMethod< /*desc=*/[{ Generates the scalar implementation of the operation. @@ -142,7 +197,7 @@ def TilingInterface : OpInterface<"TilingInterface"> { transformations are done, this method can be used to lower to scalar code that can then be lowered to LLVM or SPIR-V dialects. 
}], - /*retType=*/"LogicalResult", + /*retType=*/"::mlir::LogicalResult", /*methodName=*/"generateScalarImplementation", /*args=*/(ins "OpBuilder &":$b, diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 9c5c58fa1fabf..e9999c34d0fac 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2425,8 +2425,8 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder, LogicalResult SoftmaxOp::getResultTilePosition( OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, - ArrayRef sizes, SmallVector &resultOffsets, - SmallVector &resultSizes) { + ArrayRef sizes, SmallVectorImpl &resultOffsets, + SmallVectorImpl &resultSizes) { if (resultNumber == 0) { resultOffsets.assign(offsets.begin(), offsets.end()); resultSizes.assign(sizes.begin(), sizes.end()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index bd870d4f982e5..71e9c3771dcde 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -110,7 +110,7 @@ struct LinalgOpTilingInterface })); } - // Instantiate the tiled implementation of the operation. + /// Instantiate the tiled implementation of the operation. FailureOr getTiledImplementation(Operation *op, OpBuilder &b, ArrayRef offsets, @@ -132,14 +132,66 @@ struct LinalgOpTilingInterface return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; } - // Return the details of the output tile generated by the tiled - // implementation. 
+ void + getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap, + ArrayRef offsets, + ArrayRef sizes, + SmallVectorImpl &mappedOffsets, + SmallVectorImpl &mappedSizes) const { + unsigned numLoops = linalgOp.getNumLoops(); + auto tilingInterfaceOp = cast(linalgOp.getOperation()); + mappedOffsets.resize(numLoops); + mappedSizes.resize(numLoops); + if (!indexingMap.isPermutation()) { + SmallVector iterationDomain = + tilingInterfaceOp.getIterationDomain(b); + for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) { + mappedOffsets[index] = value.offset; + mappedSizes[index] = value.size; + } + } + for (const auto &&[index, value] : + llvm::enumerate(indexingMap.getResults())) { + unsigned dimPosition = cast(value).getPosition(); + mappedOffsets[dimPosition] = offsets[index]; + mappedSizes[dimPosition] = sizes[index]; + } + } + + /// Return the offsets and sizes of the iteration domain tile that + /// corresponds to the given tile of operand `operandNumber`. + LogicalResult getIterationDomainTileFromOperandTile( + Operation *op, OpBuilder &b, unsigned operandNumber, + ArrayRef offsets, ArrayRef sizes, + SmallVectorImpl &iterDomainOffsets, + SmallVectorImpl &iterDomainSizes) const { + auto linalgOp = cast(op); + + // Check that the indexing map used for the operand is a projected + // permutation. This could be relaxed with a more general approach that can + // map the offsets and sizes from the operand to iteration space tiles + // (filling in full extent for dimensions not used to access the operand). 
+ AffineMap indexingMap = + linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber)); + if (!indexingMap.isProjectedPermutation()) { + return emitError(op->getLoc(), + "unhandled get iter domain position when operand is not " + "accessed using a permuted projection"); + } + + getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes, + iterDomainOffsets, iterDomainSizes); + return success(); + } + + /// Return the details of the output tile generated by the tiled + /// implementation. LogicalResult getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, ArrayRef offsets, ArrayRef sizes, - SmallVector &resultOffsets, - SmallVector &resultSizes) const { + SmallVectorImpl &resultOffsets, + SmallVectorImpl &resultSizes) const { Location loc = op->getLoc(); LinalgOp linalgOp = cast(op); @@ -160,6 +212,21 @@ struct LinalgOpTilingInterface return success(); } + FailureOr getTiledImplementationFromOperandTile( + Operation *op, OpBuilder &b, unsigned operandNumber, + ArrayRef offsets, ArrayRef sizes) const { + SmallVector mappedOffsets, mappedSizes; + auto tilingInterfaceOp = cast(op); + if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile( + b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) { + return emitError( + op->getLoc(), + "unable to obtain the iter domain position of the operation."); + } + return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, + mappedSizes); + } + FailureOr generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, ArrayRef offsets, @@ -177,29 +244,16 @@ struct LinalgOpTilingInterface "unhandled tiled implementation generation when result is not " "accessed using a permuted projection"); } - - auto numLoops = linalgOp.getNumLoops(); + SmallVector mappedOffsets, mappedSizes; + getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes, + mappedOffsets, mappedSizes); auto tilingInterfaceOp = cast(op); - SmallVector iterationTileOffsets(numLoops), - 
iterationTileSizes(numLoops); - if (!indexingMap.isPermutation()) { - SmallVector iterationDomain = - tilingInterfaceOp.getIterationDomain(b); - for (const auto &range : llvm::enumerate(iterationDomain)) { - iterationTileOffsets[range.index()] = range.value().offset; - iterationTileSizes[range.index()] = range.value().size; - } - } - for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) { - unsigned dimPosition = - cast(resultExpr.value()).getPosition(); - iterationTileOffsets[dimPosition] = offsets[resultExpr.index()]; - iterationTileSizes[dimPosition] = sizes[resultExpr.index()]; - } - FailureOr tilingResult = - tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets, - iterationTileSizes); + tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes); + + if (failed(tilingResult)) + return failure(); + if (tilingResult->tiledOps.size() != 1) return op->emitOpError("failed to generate tiled implementation"); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp index d25efcf50ec56..296c5fc7a5c2b 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -61,8 +61,8 @@ struct PadOpTiling : public TilingInterface::ExternalModel { getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, ArrayRef offsets, ArrayRef sizes, - SmallVector &resultOffsets, - SmallVector &resultSizes) const { + SmallVectorImpl &resultOffsets, + SmallVectorImpl &resultSizes) const { resultOffsets.assign(offsets.begin(), offsets.end()); resultSizes.assign(sizes.begin(), sizes.end()); return success(); @@ -199,8 +199,8 @@ struct PackOpTiling getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, ArrayRef offsets, ArrayRef sizes, - SmallVector &resultOffsets, - SmallVector &resultSizes) const { + SmallVectorImpl &resultOffsets, + SmallVectorImpl &resultSizes) const { // 
The iteration domain is over outer dimensions of packed layout. In this // context, the outer dimensions of `resultOffsets` are `offsets`. The // inner dimensions of `resultOffsets` are zeros because tiling is not @@ -452,8 +452,8 @@ struct UnPackOpTiling getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, ArrayRef offsets, ArrayRef sizes, - SmallVector &resultOffsets, - SmallVector &resultSizes) const { + SmallVectorImpl &resultOffsets, + SmallVectorImpl &resultSizes) const { resultOffsets = llvm::to_vector(offsets); resultSizes = llvm::to_vector(sizes); return success();