diff --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td index 70b127fd063ca..2969b4238dd67 100644 --- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td +++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td @@ -63,4 +63,35 @@ def SimplifyBoundedAffineOpsOp }]; } +def SimplifyMinMaxAffineOpsOp : + Op, + DeclareOpInterfaceMethods + ]> { + let description = [{ + Simplify the targeted `affine.min` / `affine.max` ops using the + `mlir::affine::simplifyAffineMinMaxOps` transform. + + Example: + ``` + %0 = transform.structured.match ops{["affine.max"]} in %arg1 + transform.affine.simplify_min_max_affine_ops %0 : !transform.any_op + ``` + + #### Return modes + + This transform consumes the target handle and does not produce any results. + This transform definitely fails if any of the targeted operations is not an + `affine.min` or `affine.max` operation, or if the canonicalization patterns + failed to converge. + This transform silently fails if none of the operations were simplified. + Otherwise, it succeeds. + }]; + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs); + let assemblyFormat = [{ + $target attr-dict `:` type($target) + }]; +} + #endif // Affine_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h index 5c538d28c1835..272054448374e 100644 --- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h @@ -34,6 +34,8 @@ namespace affine { class AffineApplyOp; class AffineDelinearizeIndexOp; class AffineLinearizeIndexOp; +class AffineMaxOp; +class AffineMinOp; /// Lowers `affine.delinearize_index` into a sequence of division and remainder /// operations. 
@@ -127,6 +129,37 @@ OpFoldResult materializeComputedBound( OpBuilder &b, Location loc, AffineMap boundMap, ArrayRef>> mapOperands); +/// This transform tries to simplify the affine min operation `op`, by finding a +/// common lower bound for a set of expressions in the affine map results. It +/// returns whether the transform updated `op`'s affine map. +/// +/// In concrete terms, given an operation like: +/// `affine.min affine_map<(d0)[s0, s1] -> (d0, s1, s0, 128)>(%i)[%s0, %s1]` +/// If `d0 < 128` and `128 < s1 < s0`, the transform will update `op` to: +/// `affine.min affine_map<(d0)[s0, s1] -> (d0, 128)>(%i)[%s0, %s1]`. +bool simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op); + +/// This transform tries to simplify the affine max operation `op`, by finding a +/// common upper bound for a set of expressions in the affine map results. It +/// returns whether the transform updated `op`'s affine map. +/// +/// In concrete terms, given an operation like: +/// `affine.max affine_map<(d0)[s0, s1] -> (d0, s1, s0, 128)>(%i)[%s0, %s1]` +/// If `d0 > 128` and `s0 > s1 > 128`, the transform will update `op` to: +/// `affine.max affine_map<(d0)[s0, s1] -> (d0, s0)>(%i)[%s0, %s1]`. +bool simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op); + +/// This transform applies `simplifyAffineMinOp` and `simplifyAffineMaxOp` to +/// all the `affine.min` or `affine.max` operations in `ops`. After +/// simplification, it invokes the `affine.min/max` canonicalization patterns on +/// `ops`. +/// +/// This transform returns failure if the greedy pattern rewriter failed to +/// converge during canonicalization, otherwise it returns success. If provided, +/// `modified` is set to `true` if the IR was modified in any way. 
+LogicalResult simplifyAffineMinMaxOps(RewriterBase &rewriter, + ArrayRef ops, + bool *modified = nullptr); } // namespace affine } // namespace mlir diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h index 337314143c80c..d168735f50598 100644 --- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h @@ -135,10 +135,17 @@ class ValueBoundsConstraintSet /// Construct a variable for a map and its operands. Variable(AffineMap map, ArrayRef mapOperands); - Variable(AffineMap map, ArrayRef mapOperands); + Variable(AffineMap map, ValueRange mapOperands); MLIRContext *getContext() const { return map.getContext(); } + /// Returns the affine map. + AffineMap getMap() const { return map; } + + /// Returns the map operands. + ValueDimList &getOperands() { return mapOperands; } + const ValueDimList &getOperands() const { return mapOperands; } + private: friend class ValueBoundsConstraintSet; AffineMap map; @@ -254,6 +261,12 @@ class ValueBoundsConstraintSet /// prove the relation or until it ran out of IR. static bool compare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs); + /// This function is similar to `ValueBoundsConstraintSet::compare`, except + /// that it returns false if `!(lhs cmp rhs)`, and `failure` if neither the + /// relation nor its inverse relation could be proven. + static llvm::FailureOr strongCompare(const Variable &lhs, + ComparisonOperator cmp, + const Variable &rhs); /// Compute whether the given variables are equal. Return "failure" if /// equality could not be determined. @@ -327,6 +340,16 @@ class ValueBoundsConstraintSet /// constraints. bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos); + /// Return "true" if, based on the current state of the constraint system, + /// "lhs cmp rhs" was proven to hold. It returns "false" if "!(lhs cmp rhs)" + /// can be proven. 
Otherwise, it returns `failure` if neither the relation nor + /// its inverse relation could be proven. + /// + /// This function does not analyze any IR and does not populate any additional + /// constraints. + llvm::FailureOr strongComparePos(int64_t lhsPos, ComparisonOperator cmp, + int64_t rhsPos); + /// Given an affine map with a single result (and map operands), add a new /// column to the constraint set that represents the result of the map. /// Traverse additional IR starting from the map operands as needed (as long diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp index c9fe4474a68fa..b1e40d9b289ec 100644 --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/Affine/LoopUtils.h" +#include "mlir/Dialect/Affine/Transforms/Transforms.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -112,7 +113,7 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter, } if (boundedOps.contains(target)) { auto diag = emitDefiniteFailure() - << "target op result must not be constrainted"; + << "target op result must not be constrained"; diag.attachNote(target->getLoc()) << "target/constrained op"; return diag; } @@ -148,6 +149,42 @@ void SimplifyBoundedAffineOpsOp::getEffects( modifiesPayload(effects); } +//===----------------------------------------------------------------------===// +// SimplifyMinMaxAffineOpsOp +//===----------------------------------------------------------------------===// +DiagnosedSilenceableFailure +SimplifyMinMaxAffineOpsOp::apply(transform::TransformRewriter &rewriter, + TransformResults &results, + 
TransformState &state) { + SmallVector targets; + for (Operation *target : state.getPayloadOps(getTarget())) { + if (!isa(target)) { + auto diag = emitDefiniteFailure() + << "target must be affine.min or affine.max"; + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + targets.push_back(target); + } + bool modified = false; + if (failed(mlir::affine::simplifyAffineMinMaxOps(rewriter, targets, + &modified))) { + return emitDefiniteFailure() + << "affine.min/max simplification did not converge"; + } + if (!modified) { + return emitSilenceableError() + << "the transform failed to simplify any of the target operations"; + } + return DiagnosedSilenceableFailure::success(); +} + +void SimplifyMinMaxAffineOpsOp::getEffects( + SmallVectorImpl &effects) { + consumesHandle(getTargetMutable(), effects); + modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt index 1c82822b2bd7f..c792200f4a49a 100644 --- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt @@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRAffineTransforms ReifyValueBounds.cpp SuperVectorize.cpp SimplifyAffineStructures.cpp + SimplifyAffineMinMax.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp new file mode 100644 index 0000000000000..c992badcfa493 --- /dev/null +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp @@ -0,0 +1,174 @@ +//===- SimplifyAffineMinMax.cpp - Simplify affine min/max ops -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a transform to simplify min/max affine operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/ValueBoundsOpInterface.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/IntEqClasses.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "affine-min-max" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") + +using namespace mlir; +using namespace mlir::affine; + +/// Simplifies an affine min/max operation by proving there's a lower or upper +/// bound. +template +static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { + using Variable = ValueBoundsConstraintSet::Variable; + using ComparisonOperator = ValueBoundsConstraintSet::ComparisonOperator; + + AffineMap affineMap = affineOp.getMap(); + ValueRange operands = affineOp.getOperands(); + static constexpr bool isMin = std::is_same_v; + + LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; }); + + // Create a `Variable` list with values corresponding to each of the results + // in the affine map. + SmallVector variables = llvm::map_to_vector( + llvm::iota_range(0u, affineMap.getNumResults(), false), + [&](unsigned i) { + return Variable(affineMap.getSliceMap(i, 1), operands); + }); + + // Get the comparison operation. + ComparisonOperator cmpOp = + isMin ? ComparisonOperator::LT : ComparisonOperator::GT; + + // Find disjoint sets bounded by a common value. 
+ llvm::IntEqClasses boundedClasses(variables.size()); + DenseMap bounds; + for (auto &&[i, v] : llvm::enumerate(variables)) { + unsigned eqClass = boundedClasses.findLeader(i); + + // If the class already has a bound continue. + if (bounds.contains(eqClass)) + continue; + + // Initialize the bound. + Variable *bound = &v; + + LLVM_DEBUG({ + DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap() + << "`\n"; + }); + + // Check against the other variables. + for (size_t j = i + 1; j < variables.size(); ++j) { + unsigned jEqClass = boundedClasses.findLeader(j); + // Skip if the class is the same. + if (jEqClass == eqClass) + continue; + + // Get the bound of the equivalence class or itself. + Variable *nv = bounds.lookup_or(jEqClass, &variables[j]); + + LLVM_DEBUG({ + DBGS() << "- comparing with variable: #" << jEqClass + << ", with map: " << nv->getMap() << "\n"; + }); + + // Compare the variables. + FailureOr cmpResult = + ValueBoundsConstraintSet::strongCompare(*bound, cmpOp, *nv); + + // The variables cannot be compared. + if (failed(cmpResult)) { + LLVM_DEBUG({ + DBGS() << "-- classes: #" << i << ", #" << jEqClass + << " cannot be merged\n"; + }); + continue; + } + + // Join the equivalent classes and update the bound if necessary. + LLVM_DEBUG({ + DBGS() << "-- merging classes: #" << i << ", #" << jEqClass + << ", is cmp(lhs, rhs): " << *cmpResult << "`\n"; + }); + if (*cmpResult) { + boundedClasses.join(eqClass, jEqClass); + } else { + // In this case we have lhs > rhs if isMin == true, or lhs < rhs if + // isMin == false. + bound = nv; + boundedClasses.join(eqClass, jEqClass); + } + } + bounds[boundedClasses.findLeader(i)] = bound; + } + + // Return if there's no simplification. + if (bounds.size() >= affineMap.getNumResults()) { + LLVM_DEBUG( + { DBGS() << "- the affine operation couldn't get simplified\n"; }); + return false; + } + + // Construct the new affine affineMap. 
+ SmallVector results; + results.reserve(bounds.size()); + for (auto [k, bound] : bounds) + results.push_back(bound->getMap().getResult(0)); + + affineMap = AffineMap::get(affineMap.getNumDims(), affineMap.getNumSymbols(), + results, rewriter.getContext()); + + // Update the affine op. + rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); }); + LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n"; }); + return true; +} + +bool mlir::affine::simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op) { + return simplifyAffineMinMaxOp(rewriter, op); +} + +bool mlir::affine::simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op) { + return simplifyAffineMinMaxOp(rewriter, op); +} + +LogicalResult mlir::affine::simplifyAffineMinMaxOps(RewriterBase &rewriter, + ArrayRef ops, + bool *modified) { + bool changed = false; + for (Operation *op : ops) { + if (auto minOp = dyn_cast(op)) + changed = simplifyAffineMinOp(rewriter, minOp) || changed; + else if (auto maxOp = cast(op)) + changed = simplifyAffineMaxOp(rewriter, maxOp) || changed; + } + RewritePatternSet patterns(rewriter.getContext()); + AffineMaxOp::getCanonicalizationPatterns(patterns, rewriter.getContext()); + AffineMinOp::getCanonicalizationPatterns(patterns, rewriter.getContext()); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + if (modified) + *modified = changed; + // Canonicalize to a fixpoint. 
+ if (failed(applyOpPatternsGreedily( + ops, frozenPatterns, + GreedyRewriteConfig() + .setListener( + static_cast(rewriter.getListener())) + .setStrictness(GreedyRewriteStrictness::ExistingAndNewOps), + &changed))) { + return failure(); + } + if (modified) + *modified = changed; + return success(); +} diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp index 87f883c2e6485..c9481fb5d9406 100644 --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -146,7 +146,7 @@ ValueBoundsConstraintSet::Variable::Variable(AffineMap map, } ValueBoundsConstraintSet::Variable::Variable(AffineMap map, - ArrayRef mapOperands) + ValueRange mapOperands) : Variable(map, llvm::map_to_vector(mapOperands, [](Value v) { return Variable(v); })) {} @@ -736,6 +736,44 @@ bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos, return isEmpty; } +FailureOr ValueBoundsConstraintSet::strongComparePos( + int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos) { + auto strongCmp = [&](ComparisonOperator cmp, + ComparisonOperator negCmp) -> FailureOr { + if (comparePos(lhsPos, cmp, rhsPos)) + return true; + if (comparePos(lhsPos, negCmp, rhsPos)) + return false; + return failure(); + }; + switch (cmp) { + case ComparisonOperator::LT: + return strongCmp(ComparisonOperator::LT, ComparisonOperator::GE); + case ComparisonOperator::LE: + return strongCmp(ComparisonOperator::LE, ComparisonOperator::GT); + case ComparisonOperator::GT: + return strongCmp(ComparisonOperator::GT, ComparisonOperator::LE); + case ComparisonOperator::GE: + return strongCmp(ComparisonOperator::GE, ComparisonOperator::LT); + case ComparisonOperator::EQ: { + std::optional le = + strongComparePos(lhsPos, ComparisonOperator::LE, rhsPos); + if (!le) + return failure(); + if (!*le) + return false; + std::optional ge = + strongComparePos(lhsPos, ComparisonOperator::GE, rhsPos); + if (!ge) + return failure(); + if (!*ge) + 
return false; + return true; + } + } + llvm_unreachable("invalid comparison operator"); +} + bool ValueBoundsConstraintSet::populateAndCompare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs) { @@ -763,14 +801,29 @@ bool ValueBoundsConstraintSet::compare(const Variable &lhs, return cstr.comparePos(lhsPos, cmp, rhsPos); } +FailureOr ValueBoundsConstraintSet::strongCompare(const Variable &lhs, + ComparisonOperator cmp, + const Variable &rhs) { + int64_t lhsPos = -1, rhsPos = -1; + auto stopCondition = [&](Value v, std::optional dim, + ValueBoundsConstraintSet &cstr) { + // Keep processing as long as lhs/rhs were not processed. + if (size_t(lhsPos) >= cstr.positionToValueDim.size() || + size_t(rhsPos) >= cstr.positionToValueDim.size()) + return false; + // Keep processing as long as the strong relation cannot be proven. + FailureOr ordered = cstr.strongComparePos(lhsPos, cmp, rhsPos); + return failed(ordered) ? true : false; + }; + ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition); + lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands); + rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands); + return cstr.strongComparePos(lhsPos, cmp, rhsPos); +} + FailureOr ValueBoundsConstraintSet::areEqual(const Variable &var1, const Variable &var2) { - if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2)) - return true; - if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::LT, var2) || - ValueBoundsConstraintSet::compare(var1, ComparisonOperator::GT, var2)) - return false; - return failure(); + return strongCompare(var1, ComparisonOperator::EQ, var2); } FailureOr diff --git a/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir b/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir new file mode 100644 index 0000000000000..948f434f3fa5e --- /dev/null +++ b/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir @@ -0,0 +1,68 @@ +// RUN: mlir-opt %s --transform-interpreter | 
FileCheck %s + +// CHECK-DAG: #[[MAP_0:.*]] = affine_map<()[s0] -> (32, s0)> +// CHECK-DAG: #[[MAP_1:.*]] = affine_map<()[s0, s1] -> (s1, s0)> +// CHECK-DAG: #[[MAP_2:.*]] = affine_map<()[s0] -> (256, s0)> + +// CHECK: @min_max_full_simplify +func.func @min_max_full_simplify() -> (index, index) { + %0 = test.value_with_bounds {max = 128 : index, min = 0 : index} + %1 = test.value_with_bounds {max = 512 : index, min = 256 : index} + // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index} + // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index} + // CHECK-NOT: affine.min + // CHECK-NOT: affine.max + // CHECK: return %[[V0]], %[[V1]] + %r0 = affine.min affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1] + %r1 = affine.max affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1] + return %r0, %r1 : index, index +} + +// CHECK: @min_only_simplify +func.func @min_only_simplify() -> (index, index) { + // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index} + // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index} + // CHECK: affine.min #[[MAP_0]]()[%[[V0]]] + // CHECK: affine.max #[[MAP_1]]()[%[[V0]], %[[V1]]] + %0 = test.value_with_bounds {max = 512 : index, min = 0 : index} + %1 = test.value_with_bounds {max = 512 : index, min = 256 : index} + %r0 = affine.min affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1] + %r1 = affine.max affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1] + return %r0, %r1 : index, index +} + +// CHECK: @max_only_simplify +func.func @max_only_simplify() -> (index, index) { + // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index} + // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index} + // CHECK: affine.min #[[MAP_1]]()[%[[V0]], %[[V1]]] + // CHECK: affine.max #[[MAP_2]]()[%[[V1]]] + %0 = test.value_with_bounds {max = 128 : index, min = 0 : index} + %1 = test.value_with_bounds {max = 512 : 
index, min = 0 : index} + %r0 = affine.min affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1] + %r1 = affine.max affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1] + return %r0, %r1 : index, index +} + +// CHECK: @overlapping_constraints +func.func @overlapping_constraints() -> (index, index) { + %0 = test.value_with_bounds {max = 192 : index, min = 0 : index} + %1 = test.value_with_bounds {max = 384 : index, min = 128 : index} + %2 = test.value_with_bounds {max = 512 : index, min = 256 : index} + // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 192 : index, min = 0 : index} + // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 384 : index, min = 128 : index} + // CHECK: %[[V2:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index} + // CHECK: affine.min #[[MAP_1]]()[%[[V0]], %[[V1]]] + // CHECK: affine.max #[[MAP_1]]()[%[[V1]], %[[V2]]] + %r0 = affine.min affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2] + %r1 = affine.max affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2] + return %r0, %r1 : index, index +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["affine.min", "affine.max"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.affine.simplify_min_max_affine_ops %0 : !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir index bc684b53c9b61..f91eb9c30a51a 100644 --- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir @@ -454,3 +454,38 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +// This test checks that by using `simplify_min_max_affine_ops` after padding +// and tiling, it's possible to recover static tiled slices. 
+ +// CHECK-LABEL: @dyn_pad_tiling +// CHECK: %[[LHS:.*]] = tensor.pad +// CHECK: %[[RHS:.*]] = tensor.pad +// CHECK: scf.for +// CHECK-DAG: tensor.extract_slice %[[LHS]][0, %{{.*}}] [%{{.*}}, 32] +// CHECK-DAG: tensor.extract_slice %[[RHS]][0, %{{.*}}] [%{{.*}}, 32] +func.func @dyn_pad_tiling(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) -> tensor + return %0 : tensor +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.matmul_transpose_b"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %padded, %pad, %copy = transform.structured.pad %0 pad_to_multiple_of [32] use_prescribed_tensor_shapes {padding_dimensions = [2], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %tiled_linalg_op, %loops = transform.structured.tile_using_for %padded tile_sizes [0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %1 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %2 = transform.apply_registered_pass "resolve-shaped-type-result-dims" to %1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %2 { + transform.apply_patterns.canonicalization + } {apply_cse} : !transform.any_op + %3 = transform.structured.match ops{["affine.min", "affine.max"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.affine.simplify_min_max_affine_ops %3 : !transform.any_op + transform.apply_patterns to %2 { + transform.apply_patterns.canonicalization + } {apply_cse} : !transform.any_op + transform.yield + } +} + diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index 
78e44c6ec7a9b..6c1a5d3441530 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -836,6 +836,16 @@ void ConversionFuncOp::print(OpAsmPrinter &p) { getArgAttrsAttrName(), getResAttrsAttrName()); } +//===----------------------------------------------------------------------===// +// TestValueWithBoundsOp +//===----------------------------------------------------------------------===// + +void TestValueWithBoundsOp::populateBoundsForIndexValue( + Value v, ValueBoundsConstraintSet &cstr) { + cstr.bound(v) >= getMin().getSExtValue(); + cstr.bound(v) <= getMax().getSExtValue(); +} + //===----------------------------------------------------------------------===// // ReifyBoundOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 30234698bc8dd..8a4981a90831f 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -31,6 +31,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/MemorySlotInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/ValueBoundsOpInterface.td" include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" // Include the attribute definitions. @@ -2375,6 +2376,24 @@ def ForwardBufferOp : TEST_Op<"forward_buffer", [Pure]> { // Test ValueBoundsOpInterface //===----------------------------------------------------------------------===// +def TestValueWithBoundsOp : TEST_Op<"value_with_bounds", [ + DeclareOpInterfaceMethods + ]> { + let description = [{ + Creates a value with specified [min, max] range for value bounds analysis. 
+ + Example: + + ```mlir + %0 = test.value_with_bounds { min = 4 : index, max = 5 : index} + ``` + }]; + let arguments = (ins IndexAttr:$min, IndexAttr:$max); + let results = (outs Index:$result); + let assemblyFormat = "attr-dict"; +} + + def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> { let description = [{ Reify a bound for the given index-typed value or dimension size of a shaped