diff --git a/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h b/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h new file mode 100644 index 0000000000000..31e19ff1ad39f --- /dev/null +++ b/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h @@ -0,0 +1,104 @@ +//===- ScalableValueBoundsConstraintSet.h - Scalable Value Bounds ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H +#define MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H + +#include "mlir/Analysis/Presburger/IntegerRelation.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Interfaces/ValueBoundsOpInterface.h" + +namespace mlir::vector { + +namespace detail { + +/// Parent class for the value bounds RTTIExtends. Uses protected inheritance to +/// hide all ValueBoundsConstraintSet methods by default (as some do not use the +/// ScalableValueBoundsConstraintSet, so may produce unexpected results). +struct ValueBoundsConstraintSet : protected ::mlir::ValueBoundsConstraintSet { + using ::mlir::ValueBoundsConstraintSet::ValueBoundsConstraintSet; +}; +} // namespace detail + +/// A version of `ValueBoundsConstraintSet` that can solve for scalable bounds. +struct ScalableValueBoundsConstraintSet + : public llvm::RTTIExtends { + ScalableValueBoundsConstraintSet(MLIRContext *context, unsigned vscaleMin, + unsigned vscaleMax) + : RTTIExtends(context), vscaleMin(vscaleMin), vscaleMax(vscaleMax){}; + + using RTTIExtends::bound; + using RTTIExtends::StopConditionFn; + + /// A thin wrapper over an `AffineMap` which can represent a constant bound, + /// or a scalable bound (in terms of vscale). The `AffineMap` will always + /// take at most one parameter, vscale, and returns a single result, which is + /// the bound of value. + struct ConstantOrScalableBound { + AffineMap map; + + struct BoundSize { + int64_t baseSize{0}; + bool scalable{false}; + }; + + /// Get the (possibly) scalable size of the bound, returns failure if + /// the bound cannot be represented as a single quantity. + FailureOr getSize() const; + }; + + /// Computes a (possibly) scalable bound for a given value. This is + /// similar to `ValueBoundsConstraintSet::computeConstantBound()`, but + /// uses knowledge of the range of vscale to compute either a constant + /// bound, an expression in terms of vscale, or failure if no bound can + /// be computed. + /// + /// The resulting `AffineMap` will always take at most one parameter, + /// vscale, and return a single result, which is the bound of `value`. + /// + /// Note: `vscaleMin` must be `<=` to `vscaleMax`. If `vscaleMin` == + /// `vscaleMax`, the resulting bound (if found), will be constant. + static FailureOr + computeScalableBound(Value value, std::optional dim, + unsigned vscaleMin, unsigned vscaleMax, + presburger::BoundType boundType, bool closedUB = true, + StopConditionFn stopCondition = nullptr); + + /// Get the value of vscale. Returns `nullptr` vscale as not been encountered. + Value getVscaleValue() const { return vscale; } + + /// Sets the value of vscale. Asserts if vscale has already been set. + void setVscale(vector::VectorScaleOp vscaleOp) { + assert(!vscale && "expected vscale to be unset"); + vscale = vscaleOp.getResult(); + } + + /// The minimum possible value of vscale. + unsigned getVscaleMin() const { return vscaleMin; } + + /// The maximum possible value of vscale. + unsigned getVscaleMax() const { return vscaleMax; } + + static char ID; + +private: + const unsigned vscaleMin; + const unsigned vscaleMax; + + // This will be set when the first `vector.vscale` operation is found within + // the `ValueBoundsOpInterface` implementation then reused from there on. + Value vscale = nullptr; +}; + +using ConstantOrScalableBound = + ScalableValueBoundsConstraintSet::ConstantOrScalableBound; + +} // namespace mlir::vector + +#endif // MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H diff --git a/mlir/include/mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h new file mode 100644 index 0000000000000..4794bc9016c6f --- /dev/null +++ b/mlir/include/mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h @@ -0,0 +1,20 @@ +//===- ValueBoundsOpInterfaceImpl.h - Impl. of ValueBoundsOpInterface -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_VECTOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H +#define MLIR_DIALECT_VECTOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace vector { +void registerValueBoundsOpInterfaceExternalModels(DialectRegistry ®istry); +} // namespace vector +} // namespace mlir + +#endif // MLIR_DIALECT_VECTOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index 21775e11e0714..9bbf12d132540 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -82,6 +82,7 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" #include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h" @@ -174,6 +175,7 @@ inline void registerAllDialects(DialectRegistry ®istry) { tosa::registerShardingInterfaceExternalModels(registry); vector::registerBufferizableOpInterfaceExternalModels(registry); vector::registerSubsetOpInterfaceExternalModels(registry); + vector::registerValueBoundsOpInterfaceExternalModels(registry); NVVM::registerNVVMTargetInterfaceExternalModels(registry); ROCDL::registerROCDLTargetInterfaceExternalModels(registry); spirv::registerSPIRVTargetInterfaceExternalModels(registry); diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h index 28dadfb9ecf86..b4ed0967e63f1 100644 --- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h @@ -15,6 +15,7 @@ #include "mlir/IR/Value.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "llvm/ADT/SetVector.h" +#include "llvm/Support/ExtensibleRTTI.h" #include @@ -63,7 +64,8 @@ using ValueDimList = SmallVector>>; /// /// Note: Any modification of existing IR invalides the data stored in this /// class. Adding new operations is allowed. -class ValueBoundsConstraintSet { +class ValueBoundsConstraintSet + : public llvm::RTTIExtends { protected: /// Helper class that builds a bound for a shaped value dimension or /// index-typed value. @@ -107,6 +109,8 @@ class ValueBoundsConstraintSet { }; public: + static char ID; + /// The stop condition when traversing the backward slice of a shaped value/ /// index-type value. The traversal continues until the stop condition /// evaluates to "true" for a value. @@ -265,6 +269,16 @@ class ValueBoundsConstraintSet { ValueBoundsConstraintSet(MLIRContext *ctx); + /// Populates the constraint set for a value/map without actually computing + /// the bound. Returns the position for the value/map (via the return value + /// and `posOut` output parameter). + int64_t populateConstraintsSet(Value value, + std::optional dim = std::nullopt, + StopConditionFn stopCondition = nullptr); + int64_t populateConstraintsSet(AffineMap map, ValueDimList mapOperands, + StopConditionFn stopCondition = nullptr, + int64_t *posOut = nullptr); + /// Iteratively process all elements on the worklist until an index-typed /// value or shaped value meets `stopCondition`. Such values are not processed /// any further. diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt index 70f3fa8c297d4..204462ffd047c 100644 --- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt @@ -1,5 +1,7 @@ add_mlir_dialect_library(MLIRVectorDialect VectorOps.cpp + ValueBoundsOpInterfaceImpl.cpp + ScalableValueBoundsConstraintSet.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/IR diff --git a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp new file mode 100644 index 0000000000000..6d7e3bc70f59d --- /dev/null +++ b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp @@ -0,0 +1,103 @@ +//===- ScalableValueBoundsConstraintSet.cpp - Scalable Value Bounds -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h" + +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +namespace mlir::vector { + +FailureOr +ConstantOrScalableBound::getSize() const { + if (map.isSingleConstant()) + return BoundSize{map.getSingleConstantResult(), /*scalable=*/false}; + if (map.getNumResults() != 1 || map.getNumInputs() != 1) + return failure(); + auto binop = dyn_cast(map.getResult(0)); + if (!binop || binop.getKind() != AffineExprKind::Mul) + return failure(); + auto matchConstant = [&](AffineExpr expr, int64_t &constant) -> bool { + if (auto cst = dyn_cast(expr)) { + constant = cst.getValue(); + return true; + } + return false; + }; + // Match `s0 * cst` or `cst * s0`: + int64_t cst = 0; + auto lhs = binop.getLHS(); + auto rhs = binop.getRHS(); + if ((matchConstant(lhs, cst) && isa(rhs)) || + (matchConstant(rhs, cst) && isa(lhs))) { + return BoundSize{cst, /*scalable=*/true}; + } + return failure(); +} + +char ScalableValueBoundsConstraintSet::ID = 0; + +FailureOr +ScalableValueBoundsConstraintSet::computeScalableBound( + Value value, std::optional dim, unsigned vscaleMin, + unsigned vscaleMax, presburger::BoundType boundType, bool closedUB, + StopConditionFn stopCondition) { + using namespace presburger; + + assert(vscaleMin <= vscaleMax); + ScalableValueBoundsConstraintSet scalableCstr(value.getContext(), vscaleMin, + vscaleMax); + + int64_t pos = scalableCstr.populateConstraintsSet(value, dim, stopCondition); + + // Project out all variables apart from vscale. + // This should result in constraints in terms of vscale only. + scalableCstr.projectOut( + [&](ValueDim p) { return p.first != scalableCstr.getVscaleValue(); }); + + assert(scalableCstr.cstr.getNumDimAndSymbolVars() == + scalableCstr.positionToValueDim.size() && + "inconsistent mapping state"); + + // Check that the only symbols left are vscale. + for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) { + if (i == pos) + continue; + if (scalableCstr.positionToValueDim[i] != + ValueDim(scalableCstr.getVscaleValue(), + ValueBoundsConstraintSet::kIndexValue)) { + return failure(); + } + } + + SmallVector lowerBound(1), upperBound(1); + scalableCstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lowerBound, + &upperBound, closedUB); + + auto invalidBound = [](auto &bound) { + return !bound[0] || bound[0].getNumResults() != 1; + }; + + AffineMap bound = [&] { + if (boundType == BoundType::EQ && !invalidBound(lowerBound) && + lowerBound[0] == lowerBound[0]) { + return lowerBound[0]; + } else if (boundType == BoundType::LB && !invalidBound(lowerBound)) { + return lowerBound[0]; + } else if (boundType == BoundType::UB && !invalidBound(upperBound)) { + return upperBound[0]; + } + return AffineMap{}; + }(); + + if (!bound) + return failure(); + + return ConstantOrScalableBound{bound}; +} + +} // namespace mlir::vector diff --git a/mlir/lib/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.cpp new file mode 100644 index 0000000000000..ca95072d9bb0f --- /dev/null +++ b/mlir/lib/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.cpp @@ -0,0 +1,51 @@ +//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h" + +#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Interfaces/ValueBoundsOpInterface.h" + +using namespace mlir; + +namespace mlir::vector { +namespace { + +struct VectorScaleOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForIndexValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { + auto *scalableCstr = dyn_cast(&cstr); + if (!scalableCstr) + return; + auto vscaleOp = cast(op); + assert(value == vscaleOp.getResult() && "invalid value"); + if (auto vscale = scalableCstr->getVscaleValue()) { + // All copies of vscale are equivalent. + scalableCstr->bound(value) == cstr.getExpr(vscale); + } else { + // We know vscale is confined to [vscaleMin, vscaleMax]. + scalableCstr->bound(value) >= scalableCstr->getVscaleMin(); + scalableCstr->bound(value) <= scalableCstr->getVscaleMax(); + scalableCstr->setVscale(vscaleOp); + } + } +}; + +} // namespace +} // namespace mlir::vector + +void mlir::vector::registerValueBoundsOpInterfaceExternalModels( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) { + vector::VectorScaleOp::attachInterface( + *ctx); + }); +} diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp index 85abc2df89479..06ec3f4e135e9 100644 --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -70,6 +70,8 @@ static std::optional getConstantIntValue(OpFoldResult ofr) { ValueBoundsConstraintSet::ValueBoundsConstraintSet(MLIRContext *ctx) : builder(ctx) {} +char ValueBoundsConstraintSet::ID = 0; + #ifndef NDEBUG static void assertValidValueDim(Value value, std::optional dim) { if (value.getType().isIndex()) { @@ -471,55 +473,87 @@ FailureOr ValueBoundsConstraintSet::computeConstantBound( closedUB); } +FailureOr ValueBoundsConstraintSet::computeConstantBound( + presburger::BoundType type, AffineMap map, ArrayRef operands, + StopConditionFn stopCondition, bool closedUB) { + ValueDimList valueDims; + for (Value v : operands) { + assert(v.getType().isIndex() && "expected index type"); + valueDims.emplace_back(v, std::nullopt); + } + return computeConstantBound(type, map, valueDims, stopCondition, closedUB); +} + FailureOr ValueBoundsConstraintSet::computeConstantBound( presburger::BoundType type, AffineMap map, ValueDimList operands, StopConditionFn stopCondition, bool closedUB) { assert(map.getNumResults() == 1 && "expected affine map with one result"); ValueBoundsConstraintSet cstr(map.getContext()); - int64_t pos = cstr.insert(/*isSymbol=*/false); + + int64_t pos = 0; + if (stopCondition) { + cstr.populateConstraintsSet(map, operands, stopCondition, &pos); + } else { + // No stop condition specified: Keep adding constraints until a bound could + // be computed. + cstr.populateConstraintsSet( + map, operands, + [&](Value v, std::optional dim) { + return cstr.cstr.getConstantBound64(type, pos).has_value(); + }, + &pos); + } + // Compute constant bound for `valueDim`. + int64_t ubAdjustment = closedUB ? 0 : 1; + if (auto bound = cstr.cstr.getConstantBound64(type, pos)) + return type == BoundType::UB ? *bound + ubAdjustment : *bound; + return failure(); +} + +int64_t ValueBoundsConstraintSet::populateConstraintsSet( + Value value, std::optional dim, StopConditionFn stopCondition) { +#ifndef NDEBUG + assertValidValueDim(value, dim); +#endif // NDEBUG + + AffineMap map = + AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, + Builder(value.getContext()).getAffineDimExpr(0)); + return populateConstraintsSet(map, {{value, dim}}, stopCondition); +} + +int64_t ValueBoundsConstraintSet::populateConstraintsSet( + AffineMap map, ValueDimList operands, StopConditionFn stopCondition, + int64_t *posOut) { + assert(map.getNumResults() == 1 && "expected affine map with one result"); + int64_t pos = insert(/*isSymbol=*/false); + if (posOut) + *posOut = pos; // Add map and operands to the constraint set. Dimensions are converted to // symbols. All operands are added to the worklist. auto mapper = [&](std::pair> v) { - return cstr.getExpr(v.first, v.second); + return getExpr(v.first, v.second); }; SmallVector dimReplacements = llvm::to_vector( llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper)); SmallVector symReplacements = llvm::to_vector( llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper)); - cstr.addBound( + addBound( presburger::BoundType::EQ, pos, map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements)); // Process the backward slice of `operands` (i.e., reverse use-def chain) // until `stopCondition` is met. if (stopCondition) { - cstr.processWorklist(stopCondition); + processWorklist(stopCondition); } else { - // No stop condition specified: Keep adding constraints until a bound could - // be computed. - cstr.processWorklist( - /*stopCondition=*/[&](Value v, std::optional dim) { - return cstr.cstr.getConstantBound64(type, pos).has_value(); - }); + // No stop condition specified: Keep adding constraints until the worklist + // is empty. + processWorklist([](Value v, std::optional dim) { return false; }); } - // Compute constant bound for `valueDim`. - int64_t ubAdjustment = closedUB ? 0 : 1; - if (auto bound = cstr.cstr.getConstantBound64(type, pos)) - return type == BoundType::UB ? *bound + ubAdjustment : *bound; - return failure(); -} - -FailureOr ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType type, AffineMap map, ArrayRef operands, - StopConditionFn stopCondition, bool closedUB) { - ValueDimList valueDims; - for (Value v : operands) { - assert(v.getType().isIndex() && "expected index type"); - valueDims.emplace_back(v, std::nullopt); - } - return computeConstantBound(type, map, valueDims, stopCondition, closedUB); + return pos; } FailureOr diff --git a/mlir/test/Dialect/Vector/test-scalable-bounds.mlir b/mlir/test/Dialect/Vector/test-scalable-bounds.mlir new file mode 100644 index 0000000000000..245a6f5c13ac3 --- /dev/null +++ b/mlir/test/Dialect/Vector/test-scalable-bounds.mlir @@ -0,0 +1,161 @@ +// RUN: mlir-opt %s -test-affine-reify-value-bounds -cse -verify-diagnostics \ +// RUN: -verify-diagnostics -split-input-file | FileCheck %s + +#map_dim_i = affine_map<(d0)[s0] -> (-d0 + 32400, s0)> +#map_dim_j = affine_map<(d0)[s0] -> (-d0 + 16, s0)> + +// Here the upper bound for min_i is 4 x vscale, as we know 4 x vscale is +// always less than 32400. The bound for min_j is 16, as 16 is always less +// 4 x vscale_max (vscale_max is the UB for vscale). + +// CHECK: #[[$SCALABLE_BOUND_MAP_0:.*]] = affine_map<()[s0] -> (s0 * 4)> + +// CHECK-LABEL: @fixed_size_loop_nest +// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale +// CHECK-DAG: %[[UB_i:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_0]]()[%[[VSCALE]]] +// CHECK-DAG: %[[UB_j:.*]] = arith.constant 16 : index +// CHECK: "test.some_use"(%[[UB_i]], %[[UB_j]]) : (index, index) -> () +func.func @fixed_size_loop_nest() { + %c16 = arith.constant 16 : index + %c32400 = arith.constant 32400 : index + %c4 = arith.constant 4 : index + %c0 = arith.constant 0 : index + %vscale = vector.vscale + %c4_vscale = arith.muli %vscale, %c4 : index + scf.for %i = %c0 to %c32400 step %c4_vscale { + %min_i = affine.min #map_dim_i(%i)[%c4_vscale] + scf.for %j = %c0 to %c16 step %c4_vscale { + %min_j = affine.min #map_dim_j(%j)[%c4_vscale] + %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index + %bound_j = "test.reify_scalable_bound"(%min_j) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index + "test.some_use"(%bound_i, %bound_j) : (index, index) -> () + } + } + return +} + +// ----- + +#map_dynamic_dim = affine_map<(d0)[s0, s1] -> (-d0 + s1, s0)> + +// Here upper bounds for both min_i and min_j are both (conservatively) +// 4 x vscale, as we know that is always the largest value they could take. As +// if `dim < 4 x vscale` then 4 x vscale is an overestimate, and if +// `dim > 4 x vscale` then the min will be clamped to 4 x vscale. + +// CHECK: #[[$SCALABLE_BOUND_MAP_1:.*]] = affine_map<()[s0] -> (s0 * 4)> + +// CHECK-LABEL: @dynamic_size_loop_nest +// CHECK: %[[VSCALE:.*]] = vector.vscale +// CHECK: %[[UB_ij:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_1]]()[%[[VSCALE]]] +// CHECK: "test.some_use"(%[[UB_ij]], %[[UB_ij]]) : (index, index) -> () +func.func @dynamic_size_loop_nest(%dim0: index, %dim1: index) { + %c4 = arith.constant 4 : index + %c0 = arith.constant 0 : index + %vscale = vector.vscale + %c4_vscale = arith.muli %vscale, %c4 : index + scf.for %i = %c0 to %dim0 step %c4_vscale { + %min_i = affine.min #map_dynamic_dim(%i)[%c4_vscale, %dim0] + scf.for %j = %c0 to %dim1 step %c4_vscale { + %min_j = affine.min #map_dynamic_dim(%j)[%c4_vscale, %dim1] + %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index + %bound_j = "test.reify_scalable_bound"(%min_j) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index + "test.some_use"(%bound_i, %bound_j) : (index, index) -> () + } + } + return +} + +// ----- + +// Here the bound is just a value + a constant. + +// CHECK: #[[$SCALABLE_BOUND_MAP_2:.*]] = affine_map<()[s0] -> (s0 + 8)> + +// CHECK-LABEL: @add_to_vscale +// CHECK: %[[VSCALE:.*]] = vector.vscale +// CHECK: %[[SCALABLE_BOUND:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_2]]()[%[[VSCALE]]] +// CHECK: "test.some_use"(%[[SCALABLE_BOUND]]) : (index) -> () +func.func @add_to_vscale() { + %vscale = vector.vscale + %c8 = arith.constant 8 : index + %vscale_plus_c8 = arith.addi %vscale, %c8 : index + %bound = "test.reify_scalable_bound"(%vscale_plus_c8) {type = "EQ", vscale_min = 1, vscale_max = 16} : (index) -> index + "test.some_use"(%bound) : (index) -> () + return +} + +// ----- + +// Here we know vscale is always 2 so we get a constant bound. + +// CHECK-LABEL: @vscale_fixed_size +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: "test.some_use"(%[[C2]]) : (index) -> () +func.func @vscale_fixed_size() { + %vscale = vector.vscale + %bound = "test.reify_scalable_bound"(%vscale) {type = "EQ", vscale_min = 2, vscale_max = 2} : (index) -> index + "test.some_use"(%bound) : (index) -> () + return +} + +// ----- + +// Here we don't know the upper bound (%a is underspecified) + +func.func @unknown_bound(%a: index) { + %vscale = vector.vscale + %vscale_plus_a = arith.muli %vscale, %a : index + // expected-error @below{{could not reify bound}} + %bound = "test.reify_scalable_bound"(%vscale_plus_a) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index + "test.some_use"(%bound) : (index) -> () + return +} + +// ----- + +// Here we have two vscale values (that have not been CSE'd), but they should +// still be treated as equivalent. + +// CHECK: #[[$SCALABLE_BOUND_MAP_3:.*]] = affine_map<()[s0] -> (s0 * 6)> + +// CHECK-LABEL: @duplicate_vscale_values +// CHECK: %[[VSCALE:.*]] = vector.vscale +// CHECK: %[[SCALABLE_BOUND:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_3]]()[%[[VSCALE]]] +// CHECK: "test.some_use"(%[[SCALABLE_BOUND]]) : (index) -> () +func.func @duplicate_vscale_values() { + %c4 = arith.constant 4 : index + %vscale_0 = vector.vscale + + %c2 = arith.constant 2 : index + %vscale_1 = vector.vscale + + %c4_vscale = arith.muli %vscale_0, %c4 : index + %c2_vscale = arith.muli %vscale_1, %c2 : index + %add = arith.addi %c2_vscale, %c4_vscale : index + + %bound = "test.reify_scalable_bound"(%add) {type = "EQ", vscale_min = 1, vscale_max = 16} : (index) -> index + "test.some_use"(%bound) : (index) -> () + return +} + +// ----- + +// Test some non-scalable code to ensure that works too: + +#map_dim_i = affine_map<(d0)[s0] -> (-d0 + 1024, s0)> + +// CHECK-LABEL: @non_scalable_code +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: "test.some_use"(%[[C4]]) : (index) -> () +func.func @non_scalable_code() { + %c1024 = arith.constant 1024 : index + %c4 = arith.constant 4 : index + %c0 = arith.constant 0 : index + scf.for %i = %c0 to %c1024 step %c4 { + %min_i = affine.min #map_dim_i(%i)[%c4] + %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index + "test.some_use"(%bound_i) : (index) -> () + } + return +} diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp index 39671a930f2e2..5e160b720db62 100644 --- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp +++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Pass/Pass.h" @@ -75,7 +76,8 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp, WalkResult result = funcOp.walk([&](Operation *op) { // Look for test.reify_bound ops. if (op->getName().getStringRef() == "test.reify_bound" || - op->getName().getStringRef() == "test.reify_constant_bound") { + op->getName().getStringRef() == "test.reify_constant_bound" || + op->getName().getStringRef() == "test.reify_scalable_bound") { if (op->getNumOperands() != 1 || op->getNumResults() != 1 || !op->getResultTypes()[0].isIndex()) { op->emitOpError("invalid op"); @@ -110,6 +112,9 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp, bool constant = op->getName().getStringRef() == "test.reify_constant_bound"; + bool scalable = !constant && op->getName().getStringRef() == + "test.reify_scalable_bound"; + // Prepare stop condition. By default, reify in terms of the op's // operands. No stop condition is used when a constant was requested. std::function)> stopCondition = @@ -137,6 +142,37 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp, if (succeeded(reifiedConst)) reified = FailureOr(rewriter.getIndexAttr(*reifiedConst)); + } else if (scalable) { + unsigned vscaleMin = 0; + unsigned vscaleMax = 0; + if (auto attr = "vscale_min"; op->hasAttrOfType(attr)) { + vscaleMin = unsigned(op->getAttrOfType(attr).getInt()); + } else { + op->emitOpError("expected `vscale_min` to be provided"); + return WalkResult::skip(); + } + if (auto attr = "vscale_max"; op->hasAttrOfType(attr)) { + vscaleMax = unsigned(op->getAttrOfType(attr).getInt()); + } else { + op->emitOpError("expected `vscale_max` to be provided"); + return WalkResult::skip(); + } + + auto loc = op->getLoc(); + auto reifiedScalable = + vector::ScalableValueBoundsConstraintSet::computeScalableBound( + value, dim, vscaleMin, vscaleMax, *boundType); + if (succeeded(reifiedScalable)) { + SmallVector>, 1> + vscaleOperand; + if (reifiedScalable->map.getNumInputs() == 1) { + // The only possible input to the bound is vscale. + vscaleOperand.push_back(std::make_pair( + rewriter.create(loc), std::nullopt)); + } + reified = affine::materializeComputedBound( + rewriter, loc, reifiedScalable->map, vscaleOperand); + } } else { if (dim) { if (useArithOps) {