-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][sparse] implement lowering rules for ExtractIterSpaceOp. #89143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir Author: Peiming Liu (PeimingLiu) ChangesDO NOT MERGE until #89003 Patch is 62.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89143.diff 19 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 33f613a46bad84..96ee7111fea2cf 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -357,6 +357,10 @@ struct LevelType {
return hasSparseSemantic();
}
+ constexpr unsigned getNumBuffer() const {
+ return hasDenseSemantic() ? 0 : (isWithPosLT() ? 2 : 1);
+ }
+
std::string toMLIRString() const {
std::string lvlStr = toFormatString(getLvlFmt());
std::string propStr = "";
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 5e523ec428aefb..081a9b8cad8d62 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -17,9 +17,13 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/ADT/bit.h"
+
//===----------------------------------------------------------------------===//
//
// Type aliases to help code be more self-documenting. Unfortunately
@@ -41,6 +45,40 @@ using Level = uint64_t;
/// including the value `ShapedType::kDynamic` (for shapes).
using Size = int64_t;
+/// A simple wrapper to encode a bitset of defined (at most 64) levels.
+class LevelSet {
+ uint64_t bits = 0;
+
+public:
+ LevelSet() = default;
+ explicit LevelSet(uint64_t bits) : bits(bits) {}
+ operator uint64_t() const { return bits; }
+
+ LevelSet &set(unsigned i) {
+ assert(i < 64);
+ bits |= 1 << i;
+ return *this;
+ }
+
+ LevelSet &operator|=(LevelSet lhs) {
+ bits |= static_cast<uint64_t>(lhs);
+ return *this;
+ }
+
+ LevelSet &lshift(unsigned offset) {
+ bits = bits << offset;
+ return *this;
+ }
+
+ bool operator[](unsigned i) const {
+ assert(i < 64);
+ return (bits & (1 << i)) != 0;
+ }
+
+ unsigned count() const { return llvm::popcount(bits); }
+ bool empty() const { return bits == 0; }
+};
+
} // namespace sparse_tensor
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 4a9b9169ae4b86..d5398a98f5b171 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -19,6 +19,21 @@ class SparseTensor_Attr<string name,
list<Trait> traits = []>
: AttrDef<SparseTensor_Dialect, name, traits>;
+//===----------------------------------------------------------------------===//
+// A simple bitset attribute wrapped over a single int64_t to encode a set of
+// sparse tensor levels.
+//===----------------------------------------------------------------------===//
+
+def LevelSetAttr :
+ TypedAttrBase<
+ I64, "IntegerAttr",
+ And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
+ CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
+ "LevelSet attribute"> {
+ let returnType = [{::mlir::sparse_tensor::LevelSet}];
+ let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
+}
+
//===----------------------------------------------------------------------===//
// These attributes are just like `IndexAttr` except that they clarify whether
// the index refers to a dimension (an axis of the semantic tensor) or a level
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 4e4441c640ed95..08140b9d2b6192 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
//===----------------------------------------------------------------------===//
// Base class.
@@ -1304,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
- "ForeachOp"]>]> {
+ "ForeachOp", "IterateOp"]>]> {
let summary = "Yield from sparse_tensor set-like operations";
let description = [{
Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1513,6 +1515,103 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
let hasVerifier = 1;
}
+def IterateOp : SparseTensor_Op<"iterate",
+ [RecursiveMemoryEffects, RecursivelySpeculatable,
+ DeclareOpInterfaceMethods<LoopLikeOpInterface,
+ ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
+ "getYieldedValuesMutable"]>,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getEntrySuccessorOperands"]>,
+ SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
+
+ let arguments = (ins AnySparseIterSpace:$iterSpace,
+ Variadic<AnyType>:$initArgs,
+ LevelSetAttr:$crdUsedLvls);
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region SizedRegion<1>:$region);
+
+ let summary = "Iterate over a sparse iteration space";
+ let description = [{
+ The `sparse_tensor.iterate` operations represents a loop over the
+ provided iteration space extracted from a specific sparse tensor.
+ The operation defines an SSA value for a sparse iterator that points
+ to the current stored element in the sparse tensor and SSA values
+ for coordinates of the stored element. The coordinates are always
+ converted to `index` type despite of the underlying sparse tensor
+ storage. When coordinates are not used, the SSA values can be skipped
+ by `_` symbols, which usually leads to simpler generated code after
+ sparsification. For example:
+
+ ```mlir
+ // The coordinate for level 0 is not used when iterating over a 2-D
+ // iteration space.
+ %sparse_tensor.iterate %iterator in %space at(_, %crd_1)
+ : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
+ ```
+
+ `sparse_tensor.iterate` can also operate on loop-carried variables
+ and returns the final values after loop termination.
+ The initial values of the variables are passed as additional SSA operands
+ to the iterator SSA value and used coordinate SSA values mentioned
+ above. The operation region has an argument for the iterator, variadic
+ arguments for specified (used) coordiates and followed by one argument
+ for each loop-carried variable, representing the value of the variable
+ at the current iteration.
+ The body region must contain exactly one block that terminates with
+ `sparse_tensor.yield`.
+
+ `sparse_tensor.iterate` results hold the final values after the last
+ iteration. If the `sparse_tensor.iterate` defines any values, a yield
+ must be explicitly present.
+ The number and types of the `sparse_tensor.iterate` results must match
+ the initial values in the iter_args binding and the yield operands.
+
+
+ A nested `sparse_tensor.iterate` example that prints all the coordinates
+ stored in the sparse input:
+
+ ```mlir
+ func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
+ // Iterates over the first level of %sp
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+ %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd0)
+ : !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
+ // Iterates over the second level of %sp
+ %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+ : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+ %r2 = sparse_tensor.iterate %it2 in %l2 at (crd1)
+ : !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
+ vector.print %crd0 : index
+ vector.print %crd1 : index
+ }
+ }
+ }
+
+ ```
+ }];
+
+ let extraClassDeclaration = [{
+ unsigned getSpaceDim() {
+ return getIterSpace().getType().getSpaceDim();
+ }
+ BlockArgument getIterator() {
+ return getRegion().getArguments().front();
+ }
+ Block::BlockArgListType getCrds() {
+ // The first block argument is iterator, the remaining arguments are
+ // referenced coordinates.
+ return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+ }
+ unsigned getNumRegionIterArgs() {
+ return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
+ }
+ }];
+
+ let hasVerifier = 1;
+ let hasRegionVerifier = 1;
+ let hasCustomAssemblyFormat = 1;
+}
+
//===----------------------------------------------------------------------===//
// Sparse Tensor Debugging and Test-Only Operations.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index d6d038ef65bdf4..c9164e39a3a75e 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -16,6 +16,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
//===----------------------------------------------------------------------===//
// Include the generated pass header (which needs some early definitions).
@@ -143,6 +144,20 @@ void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns);
std::unique_ptr<Pass> createLowerForeachToSCFPass();
+//===----------------------------------------------------------------------===//
+// The LowerSparseIterationToSCF pass.
+//===----------------------------------------------------------------------===//
+
+/// Type converter for iter_space and iterator.
+struct SparseIterationTypeConverter : public OneToNTypeConverter {
+ SparseIterationTypeConverter();
+};
+
+void populateLowerSparseIterationToSCFPatterns(TypeConverter &converter,
+ RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createLowerSparseIterationToSCFPass();
+
//===----------------------------------------------------------------------===//
// The SparseTensorConversion pass.
//===----------------------------------------------------------------------===//
@@ -248,6 +263,12 @@ std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
bool enableBufferInitialization, unsigned vectorLength,
bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen);
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass();
+
//===----------------------------------------------------------------------===//
// Registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 4706d5ba2f218c..6f252b99d02c84 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -463,4 +463,35 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "M
];
}
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
+ let summary = "(experimental) sparse space collpasing pass";
+ let description = [{
+ This pass collapse consecutive sparse spaces (extracted from the same tensor)
+ into one multi-dimensional space.
+ }];
+ let constructor = "mlir::createSparseSpaceCollapsePass()";
+ let dependentDialects = [
+ "sparse_tensor::SparseTensorDialect",
+ ];
+}
+
+def LowerSparseIterationToSCF : Pass<"lower-sparse-iteration-to-scf", "func::FuncOp"> {
+ let summary = "lower sparse_tensor.iterate/coiterate into scf loops";
+ let description = [{
+ This pass lowers `sparse_tensor.iterate` operations into `scf.for/while` operations.
+ The pass is not yet stablized.
+ }];
+ let constructor = "mlir::createLowerSparseIterationToSCFPass()";
+ let dependentDialects = [
+ "memref::MemRefDialect",
+ "scf::SCFDialect",
+ "sparse_tensor::SparseTensorDialect",
+ ];
+}
+
+
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 516b0943bdcfac..36908def09f403 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2027,6 +2027,106 @@ static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
printLevelRange(p, lo, hi);
}
+static ParseResult
+parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
+ SmallVectorImpl<OpAsmParser::Argument> &iterators,
+ SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
+ SmallVector<OpAsmParser::UnresolvedOperand> spaces;
+ SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
+
+ // Parses "%iters, ... in %spaces, ..."
+ if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
+ parser.parseOperandList(spaces))
+ return failure();
+
+ if (iterators.size() != spaces.size())
+ return parser.emitError(
+ parser.getNameLoc(),
+ "mismatch in number of sparse iterators and sparse spaces");
+
+ // Parse "at(%crd0, _, ...)"
+ LevelSet crdUsedLvlSet;
+ bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at"));
+ unsigned lvlCrdCnt = 0;
+ if (hasUsedCrds) {
+ ParseResult crdList = parser.parseCommaSeparatedList(
+ OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
+ if (parser.parseOptionalKeyword("_")) {
+ if (parser.parseArgument(iterArgs.emplace_back()))
+ return failure();
+ // Always use IndexType for the coordinate.
+ crdUsedLvlSet.set(lvlCrdCnt);
+ iterArgs.back().type = parser.getBuilder().getIndexType();
+ }
+ lvlCrdCnt += 1;
+ return success();
+ });
+ if (failed(crdList)) {
+ return parser.emitError(
+ parser.getNameLoc(),
+ "expecting SSA value or \"_\" for level coordinates");
+ }
+ }
+ // Set the CrdUsedLvl bitset.
+ state.addAttribute("crdUsedLvls",
+ parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
+
+ // Parse "iter_args(%arg = %init, ...)"
+ bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+ if (hasIterArgs)
+ if (parser.parseAssignmentList(iterArgs, initArgs))
+ return failure();
+
+ SmallVector<Type> iterSpaceTps;
+ // parse ": sparse_tensor.iter_space -> ret"
+ if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
+ return failure();
+ if (iterSpaceTps.size() != spaces.size())
+ return parser.emitError(parser.getNameLoc(),
+ "mismatch in number of iteration space operands "
+ "and iteration space types");
+
+ for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
+ IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
+ if (!spaceTp)
+ return parser.emitError(parser.getNameLoc(),
+ "expected sparse_tensor.iter_space type for "
+ "iteration space operands");
+ if (hasUsedCrds && spaceTp.getSpaceDim() != lvlCrdCnt)
+ return parser.emitError(parser.getNameLoc(),
+ "mismatch in number of iteration space dimension "
+ "and specified coordinates");
+ it.type = spaceTp.getIteratorType();
+ }
+
+ if (hasIterArgs)
+ if (parser.parseArrowTypeList(state.types))
+ return failure();
+
+ // Resolves input operands.
+ if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
+ state.operands))
+ return failure();
+
+ if (hasIterArgs) {
+ unsigned numCrds = crdUsedLvlSet.count();
+ // Strip off leading args that used for coordinates.
+ MutableArrayRef args = MutableArrayRef(iterArgs).drop_front(numCrds);
+ if (args.size() != initArgs.size() || args.size() != state.types.size()) {
+ return parser.emitError(
+ parser.getNameLoc(),
+ "mismatch in number of iteration arguments and return values");
+ }
+
+ for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
+ it.type = tp;
+ if (parser.resolveOperand(init, tp, state.operands))
+ return failure();
+ }
+ }
+ return success();
+}
+
LogicalResult ExtractIterSpaceOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
@@ -2063,6 +2163,161 @@ LogicalResult ExtractIterSpaceOp::verify() {
return success();
}
+ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::Argument iterator;
+ OpAsmParser::UnresolvedOperand iterSpace;
+
+ SmallVector<OpAsmParser::Argument> iters, iterArgs;
+ if (parseSparseSpaceLoop(parser, result, iters, iterArgs))
+ return failure();
+ if (iters.size() != 1)
+ return parser.emitError(parser.getNameLoc(),
+ "expected only one iterator/iteration space");
+
+ iters.append(iterArgs);
+ Region *body = result.addRegion();
+ if (parser.parseRegion(*body, iters))
+ return failure();
+
+ IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
+
+ // Parse the optional attribute list.
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+
+ return success();
+}
+
+/// Prints the initialization list in the form of
+/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
+/// where 'inner' values are assumed to be region arguments and 'outer' values
+/// are regular SSA values.
+static void printInitializationList(OpAsmPrinter &p,
+ Block::BlockArgListType blocksArgs,
+ ValueRange initializers,
+ StringRef prefix = "") {
+ assert(blocksArgs.size() == initializers.size() &&
+ "expected same length of arguments and initializers");
+ if (initializers.empty())
+ return;
+
+ p << prefix << '(';
+ llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
+ p << std::get<0>(it) << " = " << std::get<1>(it);
+ });
+ p << ")";
+}
+
+static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim,
+ Block::BlockArgListType blocksArgs,
+ LevelSet crdUsedLvls) {
+ if (crdUsedLvls.empty())
+ return;
+
+ p << " at(";
+ for (unsigned i = 0; i < spaceDim; i++) {
+ if (crdUsedLvls[i]) {
+ p << blocksArgs.front();
+ blocksArgs = blocksArgs.drop_front();
+ } else {
+ p << "_";
+ }
+ if (i != spaceDim - 1)
+ p << ", ";
+ }
+ assert(blocksArgs.empty());
+ p << ")";
+}
+
+void IterateOp::print(OpAsmPrinter &p) {
+ p << " " << getIterator() << " in " << getIterSpace();
+ printUsedCrdsList(p, getSpaceDim(), getC...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good in general but I haven't understood all of the details. This particular op is probably good to start implementing because it is essentially a no-op, so you can concentrate on bringing all the glue code in place, but I am lacking some familiarity with the remainder of the code to really understand that.
mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
Outdated
Show resolved
Hide resolved
c5c7bb1
to
bfc234a
Compare
bfc234a
to
9e21f22
Compare
mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
Outdated
Show resolved
Hide resolved
eb7be23
to
90e5f1f
Compare
5ebd020
to
38b3625
Compare
DO NOT MERGE until #89003