diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h index 6fdb72c370e6d..2091faa6b0b02 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h @@ -410,9 +410,11 @@ void canonicalizeSetAndOperands(IntegerSet *set, /// other AffineApplyOps supplying those operands. The operands of the resulting /// AffineApplyOp do not change the length of AffineApplyOp chains. AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, - ArrayRef operands); + ArrayRef operands, + bool composeAffineMin = false); AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e, - ArrayRef operands); + ArrayRef operands, + bool composeAffineMin = false); /// Constructs an AffineApplyOp that applies `map` to `operands` after composing /// the map with the maps of any other AffineApplyOp supplying the operands, @@ -421,16 +423,19 @@ AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e, /// map. OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, - ArrayRef operands); + ArrayRef operands, + bool composeAffineMin = false); /// Variant of `makeComposedFoldedAffineApply` that applies to an expression. OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineExpr expr, - ArrayRef operands); + ArrayRef operands, + bool composeAffineMin = false); /// Variant of `makeComposedFoldedAffineApply` suitable for multi-result maps. /// Note that this may create as many affine.apply operations as the map has /// results given that affine.apply must be single-result. SmallVector makeComposedFoldedMultiResultAffineApply( - OpBuilder &b, Location loc, AffineMap map, ArrayRef operands); + OpBuilder &b, Location loc, AffineMap map, ArrayRef operands, + bool composeAffineMin = false); /// Returns an AffineMinOp obtained by composing `map` and `operands` with /// AffineApplyOps supplying those operands. @@ -459,7 +464,8 @@ OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, /// terminal symbol, i.e., a symbol defined at the top level or a block/function /// argument. void fullyComposeAffineMapAndOperands(AffineMap *map, - SmallVectorImpl *operands); + SmallVectorImpl *operands, + bool composeAffineMin = false); } // namespace affine } // namespace mlir diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 48770d4f4ff7b..3b4d51d914d86 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -11,12 +11,14 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" #include "mlir/Interfaces/ShapedOpInterfaces.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Transforms/InliningUtils.h" @@ -26,7 +28,9 @@ #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MathExtras.h" +#include #include #include @@ -1042,6 +1046,62 @@ simplifyMapWithOperands(AffineMap &map, ArrayRef operands) { map.getContext()); } +/// Assuming `dimOrSym` is a quantity in `map` that is defined by `minOp`. +/// Assuming that the quantity is of the form: +/// `affine_min(f(x, y), symbolic_cst)`. +/// This function checks that `0 < affine_min(f(x, y), symbolic_cst)` and +/// proceeds with replacing the patterns: +/// ``` +/// dimOrSym.ceildiv(symbolic_cst) +/// (dimOrSym + symbolic_cst - 1).floordiv(symbolic_cst) +/// ``` +/// by `1`. +/// +/// Additionally, allows the caller to pass `affineMinKnownToBeNonNegative` to +/// inject static information that may not be statically discoverable. +/// +/// Warning: ValueBoundsConstraintSet::computeConstantBound is needed to check +/// for the nonnegative case, if `affineMinKnownToBeNonNegative` is false. +static LogicalResult replaceAffineMinBoundingBoxExpression( + AffineMinOp minOp, AffineExpr dimOrSym, AffineMap *map, + bool affineMinKnownToBeNonNegative = false) { + auto affineMinMap = minOp.getAffineMap(); + if (!affineMinKnownToBeNonNegative) { + ValueRange values = minOp->getOperands(); + for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) { + AffineMap row = affineMinMap.getSubMap(ArrayRef{i}); + FailureOr lowerBound = + ValueBoundsConstraintSet::computeConstantBound( + presburger::BoundType::LB, {row, values}, + /*stopCondition=*/nullptr, + /*closedUB=*/true); + if (failed(lowerBound) || lowerBound.value() <= 0) + return failure(); + } + } + + AffineMap initialMap = *map; + for (unsigned i = 0, e = affineMinMap.getNumResults(); i != e; ++i) { + auto m = affineMinMap.getSubMap(ArrayRef{i}); + AffineExpr expr = m.getResult(0); + if (!expr.isSymbolicOrConstant()) + continue; + + DenseMap repl; + // dimOrSym.ceilDiv(expr) -> 1 + repl[dimOrSym.ceilDiv(expr)] = getAffineConstantExpr(1, minOp.getContext()); + // (dimOrSym + expr - 1).floorDiv(expr) -> 1 + repl[(dimOrSym + expr - 1).floorDiv(expr)] = + getAffineConstantExpr(1, minOp.getContext()); + auto newMap = map->replace(repl); + if (newMap == *map) + continue; + *map = newMap; + } + + return success(*map != initialMap); +} + /// Replace all occurrences of AffineExpr at position `pos` in `map` by the /// defining AffineApplyOp expression and operands. /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced. @@ -1052,10 +1112,13 @@ simplifyMapWithOperands(AffineMap &map, ArrayRef operands) { /// 2. `map` dim and symbols are gradually shifted to higher positions. /// 3. Old `dim` and `sym` entries are replaced by nullptr /// This avoids the need for any bookkeeping. +/// If `replaceAffineMin` is set to true, additionally triggers more expensive +/// replacements involving affine_min operations. static LogicalResult replaceDimOrSym(AffineMap *map, unsigned dimOrSymbolPosition, SmallVectorImpl &dims, - SmallVectorImpl &syms) { + SmallVectorImpl &syms, + bool replaceAffineMin) { MLIRContext *ctx = map->getContext(); bool isDimReplacement = (dimOrSymbolPosition < dims.size()); unsigned pos = isDimReplacement ? dimOrSymbolPosition @@ -1064,6 +1127,13 @@ static LogicalResult replaceDimOrSym(AffineMap *map, if (!v) return failure(); + auto minOp = v.getDefiningOp(); + if (minOp && replaceAffineMin) { + AffineExpr dimOrSym = isDimReplacement ? getAffineDimExpr(pos, ctx) + : getAffineSymbolExpr(pos, ctx); + return replaceAffineMinBoundingBoxExpression(minOp, dimOrSym, map); + } + auto affineApply = v.getDefiningOp(); if (!affineApply) return failure(); @@ -1101,7 +1171,8 @@ static LogicalResult replaceDimOrSym(AffineMap *map, /// iteratively. Perform canonicalization of map and operands as well as /// AffineMap simplification. `map` and `operands` are mutated in place. static void composeAffineMapAndOperands(AffineMap *map, - SmallVectorImpl *operands) { + SmallVectorImpl *operands, + bool composeAffineMin = false) { if (map->getNumResults() == 0) { canonicalizeMapAndOperands(map, operands); *map = simplifyAffineMap(*map); @@ -1122,7 +1193,8 @@ static void composeAffineMapAndOperands(AffineMap *map, while (true) { bool changed = false; for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos) - if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms)))) + if ((changed |= + succeeded(replaceDimOrSym(map, pos, dims, syms, composeAffineMin)))) break; if (!changed) break; @@ -1163,38 +1235,41 @@ static void composeAffineMapAndOperands(AffineMap *map, } void mlir::affine::fullyComposeAffineMapAndOperands( - AffineMap *map, SmallVectorImpl *operands) { + AffineMap *map, SmallVectorImpl *operands, bool composeAffineMin) { while (llvm::any_of(*operands, [](Value v) { return isa_and_nonnull(v.getDefiningOp()); })) { - composeAffineMapAndOperands(map, operands); + composeAffineMapAndOperands(map, operands, composeAffineMin); } } AffineApplyOp mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, - ArrayRef operands) { + ArrayRef operands, + bool composeAffineMin) { SmallVector valueOperands; map = foldAttributesIntoMap(b, map, operands, valueOperands); - composeAffineMapAndOperands(&map, &valueOperands); + composeAffineMapAndOperands(&map, &valueOperands, composeAffineMin); assert(map); return b.create(loc, map, valueOperands); } AffineApplyOp mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e, - ArrayRef operands) { + ArrayRef operands, + bool composeAffineMin) { return makeComposedAffineApply( b, loc, AffineMap::inferFromExprList(ArrayRef{e}, b.getContext()) .front(), - operands); + operands, composeAffineMin); } /// Composes the given affine map with the given list of operands, pulling in /// the maps from any affine.apply operations that supply the operands. static void composeMultiResultAffineMap(AffineMap &map, - SmallVectorImpl &operands) { + SmallVectorImpl &operands, + bool composeAffineMin = false) { // Compose and canonicalize each expression in the map individually because // composition only applies to single-result maps, collecting potentially // duplicate operands in a single list with shifted dimensions and symbols. @@ -1203,7 +1278,8 @@ static void composeMultiResultAffineMap(AffineMap &map, for (unsigned i : llvm::seq(0, map.getNumResults())) { SmallVector submapOperands(operands.begin(), operands.end()); AffineMap submap = map.getSubMap({i}); - fullyComposeAffineMapAndOperands(&submap, &submapOperands); + fullyComposeAffineMapAndOperands(&submap, &submapOperands, + composeAffineMin); canonicalizeMapAndOperands(&submap, &submapOperands); unsigned numNewDims = submap.getNumDims(); submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size()); @@ -1221,10 +1297,9 @@ static void composeMultiResultAffineMap(AffineMap &map, canonicalizeMapAndOperands(&map, &operands); } -OpFoldResult -mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc, - AffineMap map, - ArrayRef operands) { +OpFoldResult mlir::affine::makeComposedFoldedAffineApply( + OpBuilder &b, Location loc, AffineMap map, ArrayRef operands, + bool composeAffineMin) { assert(map.getNumResults() == 1 && "building affine.apply with !=1 result"); // Create new builder without a listener, so that no notification is @@ -1236,7 +1311,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc, // Create op. AffineApplyOp applyOp = - makeComposedAffineApply(newBuilder, loc, map, operands); + makeComposedAffineApply(newBuilder, loc, map, operands, composeAffineMin); // Get constant operands. SmallVector constOperands(applyOp->getNumOperands()); @@ -1256,26 +1331,25 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc, return llvm::getSingleElement(foldResults); } -OpFoldResult -mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc, - AffineExpr expr, - ArrayRef operands) { +OpFoldResult mlir::affine::makeComposedFoldedAffineApply( + OpBuilder &b, Location loc, AffineExpr expr, + ArrayRef operands, bool composeAffineMin) { return makeComposedFoldedAffineApply( b, loc, AffineMap::inferFromExprList(ArrayRef{expr}, b.getContext()) .front(), - operands); + operands, composeAffineMin); } SmallVector mlir::affine::makeComposedFoldedMultiResultAffineApply( - OpBuilder &b, Location loc, AffineMap map, - ArrayRef operands) { - return llvm::map_to_vector(llvm::seq(0, map.getNumResults()), - [&](unsigned i) { - return makeComposedFoldedAffineApply( - b, loc, map.getSubMap({i}), operands); - }); + OpBuilder &b, Location loc, AffineMap map, ArrayRef operands, + bool composeAffineMin) { + return llvm::map_to_vector( + llvm::seq(0, map.getNumResults()), [&](unsigned i) { + return makeComposedFoldedAffineApply(b, loc, map.getSubMap({i}), + operands, composeAffineMin); + }); } template @@ -3024,7 +3098,8 @@ void AffineIfOp::build(OpBuilder &builder, OperationState &result, /// `set` by composing the maps of such affine.apply ops with the integer /// set constraints. static void composeSetAndOperands(IntegerSet &set, - SmallVectorImpl &operands) { + SmallVectorImpl &operands, + bool composeAffineMin = false) { // We will simply reuse the API of the map composition by viewing the LHSs of // the equalities and inequalities of `set` as the affine exprs of an affine // map. Convert to equivalent map, compose, and convert back to set. @@ -3035,7 +3110,7 @@ static void composeSetAndOperands(IntegerSet &set, [](Value v) { return v.getDefiningOp(); })) return; - composeAffineMapAndOperands(&map, &operands); + composeAffineMapAndOperands(&map, &operands, composeAffineMin); set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(), set.getEqFlags()); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp index 5383ae48aeb3a..42dac0776bace 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp @@ -84,7 +84,7 @@ SmallVector linalg::computePaddedShape( getDimsToSize(rewriter, indexingSizes, options); // For each dimension in the operand's shape, iterate over indexingSizes and - // add + // add the various term contributions. for (const auto &enResults : enumerate(indexingMap.getResults())) { int64_t resultIndex = enResults.index(); AffineMap partialIndexingMap = indexingMap.getSubMap( @@ -122,7 +122,8 @@ SmallVector linalg::computePaddedShape( AffineMap composedMap = projectedMap.compose(ceilMap); OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply( rewriter, loc, composedMap, - {indexingSizes[paddingDim], paddingSize}); + {indexingSizes[paddingDim], paddingSize}, + /*composeAffineMin=*/true); terms.push_back(paddingDimOfr); } else { // Otherwise just set to paddingSize. diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir index 5ac35c14be3fb..845fe25193019 100644 --- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir @@ -19,7 +19,7 @@ func.func @pad_lhs( // CHECK: : tensor to tensor // CHECK: linalg.matmul ins(%{{.*}}, %{{.*}} : tensor, tensor<12x25xf32>) outs(%{{.*}} : tensor) -> tensor - + // CHECK: tensor.extract_slice %{{.*}}[0, 0] [%{{.*}}, 25] [1, 1] // CHECK: : tensor to tensor // CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, 0] [%{{.*}}, 25] [1, 1] @@ -29,8 +29,8 @@ func.func @pad_lhs( } module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op : (!transform.any_op) -> !transform.any_op // Tile to 5 then pad to 8 (supposedly to better hit vector ops). @@ -71,13 +71,13 @@ module { return %0 : tensor<7x11x12xf32> } module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %module_op : (!transform.any_op) -> !transform.any_op %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [3, 5] pad_to_multiple_of { - padding_dimensions = [0, 2], + padding_dimensions = [0, 2], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32] } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield + transform.yield } } } @@ -126,13 +126,155 @@ module { return %0 : tensor } module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %module_op : (!transform.any_op) -> !transform.any_op %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [3, 5] pad_to_multiple_of { - padding_dimensions = [0, 2], + padding_dimensions = [0, 2], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32] } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield + transform.yield } } } + +// ----- + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 16) * 16)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> ((s0 ceildiv 16) * 16)> +// CHECK-LABEL: pad_lhs +func.func @pad_lhs( + %arg0: tensor<24x?xf32>, %arg1: tensor, %arg2: tensor<24x25xf32>) + -> tensor<24x25xf32> +{ + // CHECK: %[[D0_0:.*]] = tensor.dim + // CHECK: %[[H0:.*]] = affine.apply #[[$MAP0]]()[%[[D0_0]]] + // CHECK: tensor.pad %{{.*}} low[0, 0] high[0, %[[H0]]] + // CHECK: : tensor<24x?xf32> to tensor<24x?xf32> + + // CHECK: %[[D0_2:.*]] = tensor.dim + // CHECK: %[[H1:.*]] = affine.apply #[[$MAP1]]()[%[[D0_0]], %[[D0_2]]] + // CHECK: tensor.pad %{{.*}} low[0, 0] high[%[[H1]], 0] + // CHECK: : tensor to tensor + // CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>) + + // CHECK: linalg.matmul ins(%{{.*}}, %{{.*}}: tensor<8x16xf32>, tensor<16x25xf32>) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32> + + // CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, 0] [8, 25] [1, 1] + // CHECK-SAME: : tensor<8x25xf32> into tensor<24x25xf32> + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x?xf32>, tensor) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> + func.return %0 : tensor<24x25xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op + : (!transform.any_op) -> !transform.any_op + + // Pad then tile should produce static shapes. + %matmul_padded, %_ = transform.structured.pad_tiling_interface %matmul to padding_sizes [8, 16] pad_to_multiple_of { + padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32], + padding_dimensions=[0, 2] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + %m, %l0, %l1 = transform.structured.tile_using_for %matmul_padded tile_sizes [8, 0, 16] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + %func = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + %func2 = transform.apply_registered_pass "resolve-shaped-type-result-dims" to %func + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func2 { + transform.apply_patterns.canonicalization + } {apply_cse} : !transform.any_op + %minmax = transform.structured.match ops{["affine.min", "affine.max"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.affine.simplify_min_max_affine_ops %minmax : !transform.any_op + transform.apply_patterns to %func2 { + transform.apply_patterns.canonicalization + } {apply_cse} : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (-d0 + 16)> + +// CHECK-LABEL: pad_lhs +func.func @pad_lhs( + %arg0: tensor<24x?xf32>, %arg1: tensor, %arg2: tensor<24x25xf32>) + -> tensor<24x25xf32> +{ + // CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>) + // CHECK: %[[MIN:.*]] = affine.min #[[$MAP0]](%{{.*}}) + // CHECK: %[[H0:.*]] = affine.apply #[[$MAP1]](%[[MIN]]) + // CHECK: tensor.pad %{{.*}} low[0, 0] high[0, %[[H0]]] + // CHECK: : tensor<8x?xf32> to tensor<8x16xf32> + + // CHECK: %[[H1:.*]] = affine.apply #[[$MAP1]](%[[MIN]]) + // CHECK: tensor.pad %{{.*}} low[0, 0] high[%[[H1]], 0] + // CHECK: : tensor to tensor<16x25xf32> + + // CHECK: linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<8x16xf32>, tensor<16x25xf32>) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32> + + // CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, 0] [8, 25] [1, 1] + // CHECK-SAME: : tensor<8x25xf32> into tensor<24x25xf32> + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x?xf32>, tensor) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> + func.return %0 : tensor<24x25xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op + : (!transform.any_op) -> !transform.any_op + + // Tile then pad should produce static shapes. + %m, %l0, %l1 = transform.structured.tile_using_for %matmul tile_sizes [8, 0, 16] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + %matmul_padded, %_ = transform.structured.pad_tiling_interface %m to padding_sizes [8, 16] pad_to_multiple_of { + padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32], + padding_dimensions=[0, 2] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + transform.yield + } +} + +// ----- + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (-d0 + 20, 8)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (-d0 + 8)> +// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (-d0 + 16)> + +// CHECK-LABEL: pad_lhs +func.func @pad_lhs( + %arg0: tensor<20x?xf32>, %arg1: tensor, %arg2: tensor<20x25xf32>) + -> tensor<20x25xf32> +{ + // CHECK: linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<8x16xf32>, tensor<16x25xf32>) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32> + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<20x?xf32>, tensor) outs(%arg2 : tensor<20x25xf32>) -> tensor<20x25xf32> + func.return %0 : tensor<20x25xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op + : (!transform.any_op) -> !transform.any_op + + // Tile then pad should produce static shapes. + %m, %l0, %l1 = transform.structured.tile_using_for %matmul tile_sizes [8, 0, 16] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + %matmul_padded, %_ = transform.structured.pad_tiling_interface %m to padding_sizes [8, 16] pad_to_multiple_of { + padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32], + padding_dimensions=[0, 2] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + transform.yield + } +} +