Skip to content

[mlir][sparse] implement lowering rules for ExtractIterSpaceOp. #89143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,10 @@ struct LevelType {
return hasSparseSemantic();
}

constexpr unsigned getNumBuffer() const {
return hasDenseSemantic() ? 0 : (isWithPosLT() ? 2 : 1);
}

std::string toMLIRString() const {
std::string lvlStr = toFormatString(getLvlFmt());
std::string propStr = "";
Expand Down
15 changes: 15 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/OneToNTypeConversion.h"

//===----------------------------------------------------------------------===//
// Include the generated pass header (which needs some early definitions).
Expand Down Expand Up @@ -143,6 +144,20 @@ void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns);

std::unique_ptr<Pass> createLowerForeachToSCFPass();

//===----------------------------------------------------------------------===//
// The LowerSparseIterationToSCF pass.
//===----------------------------------------------------------------------===//

/// Type converter for iter_space and iterator.
struct SparseIterationTypeConverter : public OneToNTypeConverter {
SparseIterationTypeConverter();
};

void populateLowerSparseIterationToSCFPatterns(TypeConverter &converter,
RewritePatternSet &patterns);

std::unique_ptr<Pass> createLowerSparseIterationToSCFPass();

//===----------------------------------------------------------------------===//
// The SparseTensorConversion pass.
//===----------------------------------------------------------------------===//
Expand Down
17 changes: 16 additions & 1 deletion mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -484,12 +484,27 @@ def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
let summary = "sparse space collapsing pass";
let description = [{
This pass collapses consecutive sparse spaces (extracted from the same tensor)
into one multi-dimensional space. The pass is not yet stablized.
into one multi-dimensional space. The pass is not yet stabilized.
}];
let constructor = "mlir::createSparseSpaceCollapsePass()";
let dependentDialects = [
"sparse_tensor::SparseTensorDialect",
];
}

def LowerSparseIterationToSCF : Pass<"lower-sparse-iteration-to-scf", "func::FuncOp"> {
let summary = "lower sparse_tensor.iterate/coiterate into scf loops";
let description = [{
This pass lowers `sparse_tensor.iterate` operations into `scf.for/while` operations.
The pass is not yet stabilized.
}];
let constructor = "mlir::createLowerSparseIterationToSCFPass()";
let dependentDialects = [
"memref::MemRefDialect",
"scf::SCFDialect",
"sparse_tensor::SparseTensorDialect",
];
}


#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
SparseAssembler.cpp
SparseBufferRewriting.cpp
SparseGPUCodegen.cpp
SparseIterationToScf.cpp
SparseReinterpretMap.cpp
SparseStorageSpecifierToLLVM.cpp
SparseSpaceCollapse.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@

#include "Utils/CodegenUtils.h"
#include "Utils/SparseTensorIterator.h"

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Transforms/OneToNTypeConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
SmallVectorImpl<Type> &fields) {
// Position and coordinate buffer in the sparse structure.
if (enc.getLvlType(lvl).isWithPosLT())
fields.push_back(enc.getPosMemRefType());
if (enc.getLvlType(lvl).isWithCrdLT())
fields.push_back(enc.getCrdMemRefType());
// One index for shape bound (result from lvlOp).
fields.push_back(IndexType::get(enc.getContext()));
}

static std::optional<LogicalResult>
convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {

auto idxTp = IndexType::get(itSp.getContext());
for (Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
convertLevelType(itSp.getEncoding(), l, fields);

// Two indices for lower and upper bound (we only need one pair for the last
// iteration space).
fields.append({idxTp, idxTp});
return success();
}

namespace {

/// Sparse codegen rule for number of entries operator.
class ExtractIterSpaceConverter
: public OneToNOpConversionPattern<ExtractIterSpaceOp> {
public:
using OneToNOpConversionPattern::OneToNOpConversionPattern;
LogicalResult
matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
OneToNPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();

// Construct the iteration space.
SparseIterationSpace space(loc, rewriter, op.getTensor(), 0,
op.getLvlRange(), adaptor.getParentIter());

SmallVector<Value> result = space.toValues();
rewriter.replaceOp(op, result, resultMapping);
return success();
}
};

} // namespace

mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
addConversion([](Type type) { return type; });
addConversion(convertIterSpaceType);

addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
ValueRange inputs,
Location loc) -> std::optional<Value> {
return builder
.create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs)
.getResult(0);
});
}

void mlir::populateLowerSparseIterationToSCFPatterns(
TypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext());
}
28 changes: 28 additions & 0 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace mlir {
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
#define GEN_PASS_DEF_LOWERFOREACHTOSCF
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
Expand Down Expand Up @@ -157,6 +158,29 @@ struct LowerForeachToSCFPass
}
};

struct LowerSparseIterationToSCFPass
: public impl::LowerSparseIterationToSCFBase<
LowerSparseIterationToSCFPass> {
LowerSparseIterationToSCFPass() = default;
LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
default;

void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
SparseIterationTypeConverter converter;
ConversionTarget target(*ctx);

// The actual conversion.
target.addIllegalOp<ExtractIterSpaceOp, IterateOp>();
populateLowerSparseIterationToSCFPatterns(converter, patterns);

if (failed(applyPartialOneToNConversion(getOperation(), converter,
std::move(patterns))))
signalPassFailure();
}
};

struct SparseTensorConversionPass
: public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
SparseTensorConversionPass() = default;
Expand Down Expand Up @@ -439,6 +463,10 @@ std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
return std::make_unique<LowerForeachToSCFPass>();
}

std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
return std::make_unique<LowerSparseIterationToSCFPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
return std::make_unique<SparseTensorConversionPass>();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
ValueRange posRange = posRangeIf.getResults();
return {posRange.front(), posRange.back()};
}
};
}; // namespace

class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
public:
Expand All @@ -190,7 +190,7 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
return {pLo, pHi};
}
};
}; // namespace

class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
public:
Expand All @@ -210,6 +210,13 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
// Use the segHi as the loop upper bound.
return {p, segHi};
}

ValuePair
collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
std::pair<Value, Value> parentRange) const override {
// Singleton level keeps the same range after collapsing.
return parentRange;
};
};

class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
Expand Down Expand Up @@ -1474,10 +1481,85 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
return getCursor();
}

//===----------------------------------------------------------------------===//
// SparseIterationSpace Implementation
//===----------------------------------------------------------------------===//

mlir::sparse_tensor::SparseIterationSpace::SparseIterationSpace(
Location l, OpBuilder &b, Value t, unsigned tid,
std::pair<Level, Level> lvlRange, ValueRange parentPos)
: lvls() {
auto [lvlLo, lvlHi] = lvlRange;

Value c0 = C_IDX(0);
if (parentPos.empty())
parentPos = c0;

for (Level lvl = lvlLo; lvl < lvlHi; lvl++)
lvls.emplace_back(makeSparseTensorLevel(b, l, t, tid, lvl));

bound = lvls.front()->peekRangeAt(b, l, /*batchPrefix=*/{}, parentPos);
for (auto &lvl : getLvlRef().drop_front())
bound = lvl->collapseRangeBetween(b, l, /*batchPrefix=*/{}, bound);
}

SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
IterSpaceType dstTp, ValueRange values, unsigned int tid) {
// Reconstruct every sparse tensor level.
SparseIterationSpace space;
for (auto [i, lt] : llvm::enumerate(dstTp.getLvlTypes())) {
unsigned bufferCnt = 0;
if (lt.isWithPosLT())
bufferCnt++;
if (lt.isWithCrdLT())
bufferCnt++;
// Sparse tensor buffers.
ValueRange buffers = values.take_front(bufferCnt);
values = values.drop_front(bufferCnt);

// Level size.
Value sz = values.front();
values = values.drop_front();
space.lvls.push_back(
makeSparseTensorLevel(lt, sz, buffers, tid, i + dstTp.getLoLvl()));
}
// Two bounds.
space.bound = std::make_pair(values[0], values[1]);
values = values.drop_front(2);

// Must have consumed all values.
assert(values.empty());
return space;
}

//===----------------------------------------------------------------------===//
// SparseIterator factory functions.
//===----------------------------------------------------------------------===//

/// Helper function to create a TensorLevel object from given `tensor`.
std::unique_ptr<SparseTensorLevel>
sparse_tensor::makeSparseTensorLevel(LevelType lt, Value sz, ValueRange b,
unsigned t, Level l) {
assert(lt.getNumBuffer() == b.size());
switch (lt.getLvlFmt()) {
case LevelFormat::Dense:
return std::make_unique<DenseLevel>(t, l, sz);
case LevelFormat::Batch:
return std::make_unique<BatchLevel>(t, l, sz);
case LevelFormat::Compressed:
return std::make_unique<CompressedLevel>(t, l, lt, sz, b[0], b[1]);
case LevelFormat::LooseCompressed:
return std::make_unique<LooseCompressedLevel>(t, l, lt, sz, b[0], b[1]);
case LevelFormat::Singleton:
return std::make_unique<SingletonLevel>(t, l, lt, sz, b[0]);
case LevelFormat::NOutOfM:
return std::make_unique<NOutOfMLevel>(t, l, lt, sz, b[0]);
case LevelFormat::Undef:
llvm_unreachable("undefined level format");
}
llvm_unreachable("unrecognizable level format");
}

std::unique_ptr<SparseTensorLevel>
sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
unsigned tid, Level lvl) {
Expand All @@ -1487,33 +1569,16 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
: b.create<tensor::DimOp>(l, t, lvl).getResult();

switch (lt.getLvlFmt()) {
case LevelFormat::Dense:
return std::make_unique<DenseLevel>(tid, lvl, sz);
case LevelFormat::Batch:
return std::make_unique<BatchLevel>(tid, lvl, sz);
case LevelFormat::Compressed: {
Value pos = b.create<ToPositionsOp>(l, t, lvl);
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
return std::make_unique<CompressedLevel>(tid, lvl, lt, sz, pos, crd);
}
case LevelFormat::LooseCompressed: {
SmallVector<Value, 2> buffers;
if (lt.isWithPosLT()) {
Value pos = b.create<ToPositionsOp>(l, t, lvl);
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
return std::make_unique<LooseCompressedLevel>(tid, lvl, lt, sz, pos, crd);
}
case LevelFormat::Singleton: {
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
buffers.push_back(pos);
}
case LevelFormat::NOutOfM: {
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
if (lt.isWithCrdLT()) {
Value pos = b.create<ToCoordinatesOp>(l, t, lvl);
buffers.push_back(pos);
}
case LevelFormat::Undef:
llvm_unreachable("undefined level format");
}
llvm_unreachable("unrecognizable level format");
return makeSparseTensorLevel(lt, sz, buffers, tid, lvl);
}

std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
Expand Down
Loading
Loading