diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 33f613a46bad8..96ee7111fea2c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -357,6 +357,10 @@ struct LevelType {
     return hasSparseSemantic();
   }
 
+  constexpr unsigned getNumBuffer() const {
+    return hasDenseSemantic() ? 0 : (isWithPosLT() ? 2 : 1);
+  }
+
   std::string toMLIRString() const {
     std::string lvlStr = toFormatString(getLvlFmt());
     std::string propStr = "";
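Note: `getNumBuffer()` above counts the memref buffers that materialize a level: dense-semantic levels need none, levels with a position buffer (compressed, loose_compressed) need two (positions plus coordinates), and the remaining sparse formats (singleton, n:m) need only the coordinate buffer. A minimal self-contained sketch of that mapping (illustrative enum, not the MLIR API):

    // Sketch only: mirrors the dense/pos/crd case split in getNumBuffer().
    enum class Fmt { Dense, Batch, Compressed, LooseCompressed, Singleton, NOutOfM };

    constexpr unsigned numBuffers(Fmt f) {
      switch (f) {
      case Fmt::Dense:
      case Fmt::Batch:
        return 0; // dense semantics: values are addressed directly
      case Fmt::Compressed:
      case Fmt::LooseCompressed:
        return 2; // positions + coordinates
      case Fmt::Singleton:
      case Fmt::NOutOfM:
        return 1; // coordinates only
      }
      return 0;
    }

    static_assert(numBuffers(Fmt::Compressed) == 2, "pos + crd");
    static_assert(numBuffers(Fmt::Singleton) == 1, "crd only");

The new `fromValues`/`makeSparseTensorLevel` overload further down relies on exactly this count to slice per-level buffers out of a flat value list.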
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 3043a0c4dc410..c9164e39a3a75 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -16,6 +16,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
 
 //===----------------------------------------------------------------------===//
 // Include the generated pass header (which needs some early definitions).
@@ -143,6 +144,20 @@ void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns);
 
 std::unique_ptr<Pass> createLowerForeachToSCFPass();
 
+//===----------------------------------------------------------------------===//
+// The LowerSparseIterationToSCF pass.
+//===----------------------------------------------------------------------===//
+
+/// Type converter for iter_space and iterator.
+struct SparseIterationTypeConverter : public OneToNTypeConverter {
+  SparseIterationTypeConverter();
+};
+
+void populateLowerSparseIterationToSCFPatterns(TypeConverter &converter,
+                                               RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createLowerSparseIterationToSCFPass();
+
 //===----------------------------------------------------------------------===//
 // The SparseTensorConversion pass.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 196110f55571d..b18c975105b75 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -484,7 +484,7 @@ def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
   let summary = "sparse space collapsing pass";
   let description = [{
     This pass collapses consecutive sparse spaces (extracted from the same tensor)
-    into one multi-dimensional space. The pass is not yet stablized.
+    into one multi-dimensional space. The pass is not yet stabilized.
   }];
   let constructor = "mlir::createSparseSpaceCollapsePass()";
   let dependentDialects = [
@@ -492,4 +492,19 @@ def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
   ];
 }
 
+def LowerSparseIterationToSCF : Pass<"lower-sparse-iteration-to-scf", "func::FuncOp"> {
+  let summary = "lower sparse_tensor.iterate/coiterate into scf loops";
+  let description = [{
+    This pass lowers `sparse_tensor.iterate` operations into `scf.for/while`
+    operations. The pass is not yet stabilized.
+  }];
+  let constructor = "mlir::createLowerSparseIterationToSCFPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect",
+    "scf::SCFDialect",
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 2a29ee8a7a87c..e4acfa8889e5f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseAssembler.cpp
   SparseBufferRewriting.cpp
   SparseGPUCodegen.cpp
+  SparseIterationToScf.cpp
   SparseReinterpretMap.cpp
   SparseStorageSpecifierToLLVM.cpp
   SparseSpaceCollapse.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
new file mode 100644
index 0000000000000..62887c75c872b
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -0,0 +1,78 @@
+
+#include "Utils/CodegenUtils.h"
+#include "Utils/SparseTensorIterator.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
+                             SmallVectorImpl<Type> &fields) {
+  // Position and coordinate buffer in the sparse structure.
+  if (enc.getLvlType(lvl).isWithPosLT())
+    fields.push_back(enc.getPosMemRefType());
+  if (enc.getLvlType(lvl).isWithCrdLT())
+    fields.push_back(enc.getCrdMemRefType());
+  // One index for shape bound (result from lvlOp).
+  fields.push_back(IndexType::get(enc.getContext()));
+}
+
+static std::optional<LogicalResult>
+convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
+
+  auto idxTp = IndexType::get(itSp.getContext());
+  for (Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
+    convertLevelType(itSp.getEncoding(), l, fields);
+
+  // Two indices for lower and upper bound (we only need one pair for the last
+  // iteration space).
+  fields.append({idxTp, idxTp});
+  return success();
+}
+
+namespace {
+
+/// Sparse codegen rule for the extract_iteration_space operator.
+class ExtractIterSpaceConverter
+    : public OneToNOpConversionPattern<ExtractIterSpaceOp> {
+public:
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+
+    // Construct the iteration space.
+    SparseIterationSpace space(loc, rewriter, op.getTensor(), 0,
+                               op.getLvlRange(), adaptor.getParentIter());
+
+    SmallVector<Value> result = space.toValues();
+    rewriter.replaceOp(op, result, resultMapping);
+    return success();
+  }
+};
+
+} // namespace
+
+mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
+  addConversion([](Type type) { return type; });
+  addConversion(convertIterSpaceType);
+
+  addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
+                              ValueRange inputs,
+                              Location loc) -> std::optional<Value> {
+    return builder
+        .create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs)
+        .getResult(0);
+  });
+}
+
+void mlir::populateLowerSparseIterationToSCFPatterns(
+    TypeConverter &converter, RewritePatternSet &patterns) {
+  patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext());
+}
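Note: the one-to-N conversion above flattens an `iter_space` value into plain SSA values: for each level in [loLvl, hiLvl) an optional positions memref, an optional coordinates memref, and the level size, followed by one lower/upper bound pair for the whole space. A self-contained sketch of the layout arithmetic (illustrative, not the MLIR API):

    #include <cassert>
    #include <vector>

    struct LvlFields { bool hasPos, hasCrd; };

    // Mirrors convertIterSpaceType: [pos?, crd?, size] per level + [lo, hi].
    unsigned flattenedValueCount(const std::vector<LvlFields> &lvls) {
      unsigned n = 0;
      for (const LvlFields &l : lvls)
        n += (l.hasPos ? 1u : 0u) + (l.hasCrd ? 1u : 0u) + 1u /*size*/;
      return n + 2u /*bounds*/;
    }

    int main() {
      // 2-D COO: compressed(nonunique) carries pos+crd, singleton(soa) only crd.
      assert(flattenedValueCount({{true, true}, {false, true}}) == 7);
      // A space over level 0 alone flattens to 5 values, matching the five
      // operands of the unrealized_conversion_cast in the new test below.
      assert(flattenedValueCount({{true, true}}) == 5);
      return 0;
    }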
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index b42d58634a36c..8004bdb904b8a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -26,6 +26,7 @@ namespace mlir {
 #define GEN_PASS_DEF_SPARSEREINTERPRETMAP
 #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
+#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
 #define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
 #define GEN_PASS_DEF_LOWERFOREACHTOSCF
 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
@@ -157,6 +158,29 @@ struct LowerForeachToSCFPass
   }
 };
 
+struct LowerSparseIterationToSCFPass
+    : public impl::LowerSparseIterationToSCFBase<
+          LowerSparseIterationToSCFPass> {
+  LowerSparseIterationToSCFPass() = default;
+  LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
+      default;
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    SparseIterationTypeConverter converter;
+    ConversionTarget target(*ctx);
+
+    // The actual conversion.
+    target.addIllegalOp<ExtractIterSpaceOp>();
+    populateLowerSparseIterationToSCFPatterns(converter, patterns);
+
+    if (failed(applyPartialOneToNConversion(getOperation(), converter,
+                                            std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
 struct SparseTensorConversionPass
     : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
   SparseTensorConversionPass() = default;
@@ -439,6 +463,10 @@ std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
   return std::make_unique<LowerForeachToSCFPass>();
 }
 
+std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
+  return std::make_unique<LowerSparseIterationToSCFPass>();
+}
+
 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
   return std::make_unique<SparseTensorConversionPass>();
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index dbec46d2616d9..be8e15d6ae6f4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -168,7 +168,7 @@ class CompressedLevel : public SparseLevel {
     ValueRange posRange = posRangeIf.getResults();
     return {posRange.front(), posRange.back()};
   }
-};
+}; // namespace
 
 class LooseCompressedLevel : public SparseLevel {
 public:
@@ -190,7 +190,7 @@ class LooseCompressedLevel : public SparseLevel {
     Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
     return {pLo, pHi};
   }
-};
+}; // namespace
 
 class SingletonLevel : public SparseLevel {
 public:
@@ -210,6 +210,13 @@ class SingletonLevel : public SparseLevel {
     // Use the segHi as the loop upper bound.
     return {p, segHi};
   }
+
+  ValuePair
+  collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
+                       std::pair<Value, Value> parentRange) const override {
+    // Singleton level keeps the same range after collapsing.
+    return parentRange;
+  };
 };
 
 class NOutOfMLevel : public SparseLevel {
@@ -1474,10 +1481,85 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
   return getCursor();
 }
 
+//===----------------------------------------------------------------------===//
+// SparseIterationSpace Implementation
+//===----------------------------------------------------------------------===//
+
+mlir::sparse_tensor::SparseIterationSpace::SparseIterationSpace(
+    Location l, OpBuilder &b, Value t, unsigned tid,
+    std::pair<Level, Level> lvlRange, ValueRange parentPos)
+    : lvls() {
+  auto [lvlLo, lvlHi] = lvlRange;
+
+  Value c0 = C_IDX(0);
+  if (parentPos.empty())
+    parentPos = c0;
+
+  for (Level lvl = lvlLo; lvl < lvlHi; lvl++)
+    lvls.emplace_back(makeSparseTensorLevel(b, l, t, tid, lvl));
+
+  bound = lvls.front()->peekRangeAt(b, l, /*batchPrefix=*/{}, parentPos);
+  for (auto &lvl : getLvlRef().drop_front())
+    bound = lvl->collapseRangeBetween(b, l, /*batchPrefix=*/{}, bound);
+}
+
+SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
+    IterSpaceType dstTp, ValueRange values, unsigned int tid) {
+  // Reconstruct every sparse tensor level.
+  SparseIterationSpace space;
+  for (auto [i, lt] : llvm::enumerate(dstTp.getLvlTypes())) {
+    unsigned bufferCnt = 0;
+    if (lt.isWithPosLT())
+      bufferCnt++;
+    if (lt.isWithCrdLT())
+      bufferCnt++;
+    // Sparse tensor buffers.
+    ValueRange buffers = values.take_front(bufferCnt);
+    values = values.drop_front(bufferCnt);
+
+    // Level size.
+    Value sz = values.front();
+    values = values.drop_front();
+    space.lvls.push_back(
+        makeSparseTensorLevel(lt, sz, buffers, tid, i + dstTp.getLoLvl()));
+  }
+  // Two bounds.
+  space.bound = std::make_pair(values[0], values[1]);
+  values = values.drop_front(2);
+
+  // Must have consumed all values.
+  assert(values.empty());
+  return space;
+}
+
 //===----------------------------------------------------------------------===//
 // SparseIterator factory functions.
 //===----------------------------------------------------------------------===//
 
+/// Helper function to create a TensorLevel object from given `tensor`.
+std::unique_ptr<SparseTensorLevel>
+sparse_tensor::makeSparseTensorLevel(LevelType lt, Value sz, ValueRange b,
+                                     unsigned t, Level l) {
+  assert(lt.getNumBuffer() == b.size());
+  switch (lt.getLvlFmt()) {
+  case LevelFormat::Dense:
+    return std::make_unique<DenseLevel>(t, l, sz);
+  case LevelFormat::Batch:
+    return std::make_unique<BatchLevel>(t, l, sz);
+  case LevelFormat::Compressed:
+    return std::make_unique<CompressedLevel>(t, l, lt, sz, b[0], b[1]);
+  case LevelFormat::LooseCompressed:
+    return std::make_unique<LooseCompressedLevel>(t, l, lt, sz, b[0], b[1]);
+  case LevelFormat::Singleton:
+    return std::make_unique<SingletonLevel>(t, l, lt, sz, b[0]);
+  case LevelFormat::NOutOfM:
+    return std::make_unique<NOutOfMLevel>(t, l, lt, sz, b[0]);
+  case LevelFormat::Undef:
+    llvm_unreachable("undefined level format");
+  }
+  llvm_unreachable("unrecognizable level format");
+}
+
 std::unique_ptr<SparseTensorLevel>
 sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
                                      unsigned tid, Level lvl) {
@@ -1487,33 +1569,16 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
   Value sz = stt.hasEncoding()
                  ? b.create<LvlOp>(l, t, lvl).getResult()
                  : b.create<tensor::DimOp>(l, t, lvl).getResult();
 
-  switch (lt.getLvlFmt()) {
-  case LevelFormat::Dense:
-    return std::make_unique<DenseLevel>(tid, lvl, sz);
-  case LevelFormat::Batch:
-    return std::make_unique<BatchLevel>(tid, lvl, sz);
-  case LevelFormat::Compressed: {
-    Value pos = b.create<ToPositionsOp>(l, t, lvl);
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<CompressedLevel>(tid, lvl, lt, sz, pos, crd);
-  }
-  case LevelFormat::LooseCompressed: {
+  SmallVector<Value, 2> buffers;
+  if (lt.isWithPosLT()) {
     Value pos = b.create<ToPositionsOp>(l, t, lvl);
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<LooseCompressedLevel>(tid, lvl, lt, sz, pos, crd);
-  }
-  case LevelFormat::Singleton: {
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
+    buffers.push_back(pos);
   }
-  case LevelFormat::NOutOfM: {
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
+  if (lt.isWithCrdLT()) {
+    Value pos = b.create<ToCoordinatesOp>(l, t, lvl);
+    buffers.push_back(pos);
   }
-  case LevelFormat::Undef:
-    llvm_unreachable("undefined level format");
-  }
-  llvm_unreachable("unrecognizable level format");
+  return makeSparseTensorLevel(lt, sz, buffers, tid, lvl);
 }
 
 std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
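Note: the `SparseIterationSpace` constructor above derives a single [lo, hi) bound for a multi-level space by peeking the range of the first level and letting every following level collapse that range. For COO, `SingletonLevel::collapseRangeBetween` is the identity, since a singleton level stores exactly one coordinate per parent position. A sketch of that fold (plain C++, illustrative only):

    #include <cassert>
    #include <functional>
    #include <utility>
    #include <vector>

    using Range = std::pair<int, int>;

    // Mirrors the constructor: bound = peekRangeAt(first level), then fold
    // collapseRangeBetween over the remaining levels.
    Range computeBound(Range first,
                       const std::vector<std::function<Range(Range)>> &rest) {
      Range bound = first;
      for (const auto &collapse : rest)
        bound = collapse(bound);
      return bound;
    }

    int main() {
      auto singleton = [](Range r) { return r; }; // identity collapse
      // COO: level 0 (compressed) peeks [pLo, pHi) = [0, 42); the singleton
      // level 1 keeps it, so the whole 2-D space iterates over [0, 42).
      assert(computeBound({0, 42}, {singleton}) == Range(0, 42));
      return 0;
    }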
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 120a806536f19..17636af2b2f9d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -15,6 +15,9 @@
 namespace mlir {
 namespace sparse_tensor {
 
+// Forward declaration.
+class SparseIterator;
+
 /// The base class for all types of sparse tensor levels. It provides interfaces
 /// to query the loop range (see `peekRangeAt`) and look up the coordinates (see
 /// `peekCrdAt`).
@@ -50,6 +53,12 @@ class SparseTensorLevel {
   peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
               ValueRange parentPos, Value inPadZone = nullptr) const = 0;
 
+  virtual std::pair<Value, Value>
+  collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
+                       std::pair<Value, Value> parentRange) const {
+    llvm_unreachable("Not Implemented");
+  };
+
   Level getLevel() const { return lvl; }
   LevelType getLT() const { return lt; }
   Value getSize() const { return lvlSize; }
@@ -62,7 +71,7 @@ class SparseTensorLevel {
 
 protected:
   SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize)
-      : tid(tid), lvl(lvl), lt(lt), lvlSize(lvlSize){};
+      : tid(tid), lvl(lvl), lt(lt), lvlSize(lvlSize) {};
 
 public:
   const unsigned tid, lvl;
@@ -79,6 +88,55 @@ enum class IterKind : uint8_t {
   kPad,
 };
 
+/// A `SparseIterationSpace` represents a sparse set of coordinates defined by
+/// (possibly multiple) levels of a specific sparse tensor.
+/// TODO: remove `SparseTensorLevel` and switch to SparseIterationSpace when
+/// feature complete.
+class SparseIterationSpace {
+public:
+  SparseIterationSpace() = default;
+
+  // Constructs an N-D iteration space.
+  SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid,
+                       std::pair<Level, Level> lvlRange, ValueRange parentPos);
+
+  // Constructs a 1-D iteration space.
+  SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid,
+                       Level lvl, ValueRange parentPos)
+      : SparseIterationSpace(loc, b, t, tid, {lvl, lvl + 1}, parentPos) {};
+
+  bool isUnique() const { return lvls.back()->isUnique(); }
+
+  unsigned getSpaceDim() const { return lvls.size(); }
+
+  // Reconstructs an iteration space directly from the provided ValueRange.
+  static SparseIterationSpace fromValues(IterSpaceType dstTp, ValueRange values,
+                                         unsigned tid);
+
+  // The inverse operation of `fromValues`.
+  SmallVector<Value> toValues() const {
+    SmallVector<Value> vals;
+    for (auto &stl : lvls) {
+      llvm::append_range(vals, stl->getLvlBuffers());
+      vals.push_back(stl->getSize());
+    }
+    vals.append({bound.first, bound.second});
+    return vals;
+  }
+
+  const SparseTensorLevel &getLastLvl() const { return *lvls.back(); }
+  ArrayRef<std::unique_ptr<SparseTensorLevel>> getLvlRef() const {
+    return lvls;
+  }
+
+  Value getBoundLo() const { return bound.first; }
+  Value getBoundHi() const { return bound.second; }
+
+private:
+  SmallVector<std::unique_ptr<SparseTensorLevel>> lvls;
+  std::pair<Value, Value> bound;
+};
+
 /// Helper class that generates loop conditions, etc, to traverse a
 /// sparse tensor level.
 class SparseIterator {
@@ -92,13 +150,13 @@ class SparseIterator {
                  unsigned cursorValsCnt,
                  SmallVectorImpl<Value> &cursorValStorage)
       : batchCrds(0), kind(kind), tid(tid), lvl(lvl), crd(nullptr),
-        cursorValsCnt(cursorValsCnt), cursorValsStorageRef(cursorValStorage){};
+        cursorValsCnt(cursorValsCnt), cursorValsStorageRef(cursorValStorage) {};
 
   SparseIterator(IterKind kind, unsigned cursorValsCnt,
                  SmallVectorImpl<Value> &cursorValStorage,
                  const SparseIterator &delegate)
       : SparseIterator(kind, delegate.tid, delegate.lvl, cursorValsCnt,
-                       cursorValStorage){};
+                       cursorValStorage) {};
 
   SparseIterator(IterKind kind, const SparseIterator &wrap,
                  unsigned extraCursorCnt = 0)
@@ -287,10 +345,15 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
                                                          unsigned tid,
                                                          Level lvl);
 
-/// Helper function to create a simple SparseIterator object that iterate over
-/// the SparseTensorLevel.
-std::unique_ptr<SparseIterator> makeSimpleIterator(const SparseTensorLevel &stl,
-                                                   SparseEmitStrategy strategy);
+/// Helper function to create a TensorLevel object from given `tensor`.
+std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
+                                                         ValueRange buffers,
+                                                         unsigned tid, Level l);
+/// Helper function to create a simple SparseIterator object that iterates
+/// over the SparseTensorLevel.
+std::unique_ptr<SparseIterator> makeSimpleIterator(
+    const SparseTensorLevel &stl,
+    SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional);
 
 /// Helper function to create a synthetic SparseIterator object that iterates
 /// over a dense space specified by [0,`sz`).
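Note: `toValues()` in the header and `fromValues()` above must agree on one packing order ([buffers..., size] per level, then the two bounds), since that is the contract between `ExtractIterSpaceConverter`, which flattens a space, and any consumer that rebuilds it. A self-contained round-trip sketch (plain C++, illustrative names only):

    #include <cassert>
    #include <string>
    #include <vector>

    using Val = std::string;
    struct Lvl { std::vector<Val> buffers; Val size; };
    struct Space { std::vector<Lvl> lvls; Val lo, hi; };

    std::vector<Val> toValues(const Space &s) {
      std::vector<Val> out;
      for (const Lvl &l : s.lvls) {
        out.insert(out.end(), l.buffers.begin(), l.buffers.end());
        out.push_back(l.size); // buffers first, then the level size
      }
      out.push_back(s.lo);
      out.push_back(s.hi); // the single trailing bound pair
      return out;
    }

    Space fromValues(const std::vector<Val> &vals,
                     const std::vector<unsigned> &bufCntPerLvl) {
      Space s;
      size_t i = 0;
      for (unsigned cnt : bufCntPerLvl) { // cnt plays the role of getNumBuffer()
        Lvl l;
        while (cnt--)
          l.buffers.push_back(vals[i++]);
        l.size = vals[i++];
        s.lvls.push_back(l);
      }
      s.lo = vals[i++];
      s.hi = vals[i++];
      assert(i == vals.size()); // must have consumed all values
      return s;
    }

    int main() {
      // 2-D COO space flattens to [pos0, crd0, sz0, crd1, sz1, lo, hi].
      Space coo{{{{"pos0", "crd0"}, "sz0"}, {{"crd1"}, "sz1"}}, "lo", "hi"};
      Space rt = fromValues(toValues(coo), /*bufCntPerLvl=*/{2, 1});
      assert(toValues(rt) == toValues(coo)); // layout survives the round trip
      return 0;
    }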
diff --git a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
new file mode 100644
index 0000000000000..5fcd661bb69b2
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s --lower-sparse-iteration-to-scf | FileCheck %s
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+// CHECK-LABEL:   func.func @sparse_1D_space(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>) -> !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> {
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[LVL_SIZE:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[C0]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK:           %[[POS_MEM:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK:           %[[CRD_MEM:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK:           %[[POS_LO:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK:           %[[POS_HI:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK:           %[[ITER_SPACE:.*]] = builtin.unrealized_conversion_cast %[[POS_MEM]], %[[CRD_MEM]], %[[LVL_SIZE]], %[[POS_LO]], %[[POS_HI]]
+func.func @sparse_1D_space(%sp : tensor<?x?xf32, #COO>) -> !sparse_tensor.iter_space<#COO, lvls = 0> {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
+  return %l1 : !sparse_tensor.iter_space<#COO, lvls = 0>
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
index baa6199f12bc3..b5d041273f440 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
@@ -8,30 +8,29 @@
 }>
 
 // CHECK-LABEL:   func.func @sparse_sparse_collapse(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x8xf32, #sparse>,
-// CHECK-SAME:      %[[VAL_1:.*]]: index) {
-// CHECK:           %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 to 2 : tensor<4x8xf32, #sparse>
-// CHECK:           %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]], _) iter_args(%[[VAL_7:.*]] = %[[VAL_1]])
-// CHECK:             %[[VAL_8:.*]] = "test.op"(%[[VAL_7]]) : (index) -> index
-// CHECK:             sparse_tensor.yield %[[VAL_8]] : index
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x8xf32, #sparse>) -> index {
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 to 2
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] iter_args(%[[VAL_6:.*]] = %[[VAL_1]])
+// CHECK:             %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_2]] : index
+// CHECK:             sparse_tensor.yield %[[VAL_7]] : index
 // CHECK:           }
-// CHECK:           "test.sink"(%[[VAL_4]]) : (index) -> ()
-// CHECK:           return
+// CHECK:           return %[[VAL_4]] : index
 // CHECK:         }
-func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>, %i : index) {
+func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>) -> index {
+  %i = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
   %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
-      : tensor<4x8xf32, #COO>
-     -> !sparse_tensor.iter_space<#COO, lvls = 0>
-  %r1 = sparse_tensor.iterate %it1 in %l1 at(%crd0) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+      : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
+  %r1 = sparse_tensor.iterate %it1 in %l1 iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
     %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
-        : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
-       -> !sparse_tensor.iter_space<#COO, lvls = 1>
+        : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1> -> !sparse_tensor.iter_space<#COO, lvls = 1>
     %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
-      %k ="test.op"(%inner) : (index) -> index
+      %k = arith.addi %inner, %c1 : index
       sparse_tensor.yield %k : index
     }
    sparse_tensor.yield %r2 : index
   }
-  "test.sink"(%r1) : (index) -> ()
-  return
+  return %r1 : index
 }