
Commit 6af6167

Author: Peiming Liu (committed)
[mlir][sparse] implement sparse space collapse pass.
1 parent: e276cf0

File tree: 5 files changed, +239 -0 lines

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h

Lines changed: 6 additions & 0 deletions
@@ -248,6 +248,12 @@ std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
     bool enableBufferInitialization, unsigned vectorLength,
     bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen);
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass();
+
 //===----------------------------------------------------------------------===//
 // Registration.
 //===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

Lines changed: 16 additions & 0 deletions
@@ -464,4 +464,20 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "M
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
+  let summary = "(experimental) sparse space collapsing pass";
+  let description = [{
+    This pass collapses consecutive sparse spaces (extracted from the same tensor)
+    into one multi-dimensional space. The pass is not yet stabilized.
+  }];
+  let constructor = "mlir::createSparseSpaceCollapsePass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
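To illustrate what the pass is meant to do, here is a hand-written before/after sketch (not part of the commit), adapted from the regression test added at the end of this diff. The SSA names (%l, %r, %acc, %it) and the exact type printed on the collapsed iterate are illustrative; the collapsed form is inferred from the test's CHECK lines, so treat it as an approximation rather than verbatim output.

// Before collapsing: one extracted space and one loop per level.
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
%r1 = sparse_tensor.iterate %it1 in %l1 at(%crd0) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
  %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
  %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
    %k = "test.op"(%inner) : (index) -> index
    sparse_tensor.yield %k : index
  }
  sparse_tensor.yield %r2 : index
}

// After collapsing: a single space and a single loop covering levels 0 to 2.
%l = sparse_tensor.extract_iteration_space %sp lvls = 0 to 2 : tensor<4x8xf32, #COO>
%r = sparse_tensor.iterate %it in %l at(%crd0, _) iter_args(%acc = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 2> -> index {
  %k = "test.op"(%acc) : (index) -> index
  sparse_tensor.yield %k : index
}

The textual pipeline name defined above is sparse-space-collapse, which is how the regression test below invokes the pass through mlir-opt.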

mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseGPUCodegen.cpp
   SparseReinterpretMap.cpp
   SparseStorageSpecifierToLLVM.cpp
+  SparseSpaceCollapse.cpp
   SparseTensorCodegen.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp (new file)

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/Passes.h"

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"

namespace mlir {

#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"

namespace sparse_tensor {

struct CollapseSpaceInfo {
  ExtractIterSpaceOp space;
  IterateOp loop;
};

bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
  auto pIterArgs = parent.getRegionIterArgs();
  auto nInitArgs = node.getInits();
  if (pIterArgs.size() != nInitArgs.size())
    return false;

  // Two loops are collapsable if they are perfectly nested.
  auto pYields = parent.getYieldedValues();
  auto nResult = node.getLoopResults().value();

  bool yieldEq =
      llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) {
        return std::get<0>(zipped) == std::get<1>(zipped);
      });

  // Parent iter_args should be passed directly to the node's init_args.
  bool iterArgEq =
      llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
        return std::get<0>(zipped) == std::get<1>(zipped);
      });

  return yieldEq && iterArgEq;
}

bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
                     ExtractIterSpaceOp curSpace) {

  auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
    Value spaceVal = space.getResultSpace();
    if (spaceVal.hasOneUse())
      return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
    return nullptr;
  };

  if (toCollapse.empty()) {
    // Collapse root.
    if (auto itOp = getIterateOpOverSpace(curSpace)) {
      CollapseSpaceInfo &info = toCollapse.emplace_back();
      info.space = curSpace;
      info.loop = itOp;
      return true;
    }
    return false;
  }

  auto parent = toCollapse.back().space;
  auto pItOp = toCollapse.back().loop;
  auto nItOp = getIterateOpOverSpace(curSpace);

  // Can only collapse spaces extracted from the same tensor.
  if (parent.getTensor() != curSpace.getTensor())
    return false;

  // Can only collapse consecutive simple iteration on one tensor (i.e., no
  // coiteration).
  if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
      pItOp.getIterator() != curSpace.getParentIter() ||
      curSpace->getParentOp() != pItOp.getOperation())
    return false;

  if (pItOp && !isCollapsableLoops(pItOp, nItOp))
    return false;

  CollapseSpaceInfo &info = toCollapse.emplace_back();
  info.space = curSpace;
  info.loop = nItOp;
  return true;
}

void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
  if (toCollapse.size() < 2)
    return;

  ExtractIterSpaceOp root = toCollapse.front().space;
  ExtractIterSpaceOp leaf = toCollapse.back().space;
  Location loc = root.getLoc();

  assert(root->hasOneUse() && leaf->hasOneUse());

  // Insert collapsed operation at the same scope as root operation.
  OpBuilder builder(root);

  // Construct the collapsed iteration space.
  auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
      loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
      leaf.getHiLvl());

  auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
  auto innermost = toCollapse.back().loop;

  IRMapping mapper;
  mapper.map(leaf, collapsedSpace.getResultSpace());
  for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
    mapper.map(std::get<0>(z), std::get<1>(z));

  auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
  builder.setInsertionPointToStart(cloned.getBody());

  LevelSet crdUsedLvls;
  unsigned shift = 0, argIdx = 1;
  for (auto info : toCollapse.drop_back()) {
    LevelSet set = info.loop.getCrdUsedLvls();
    crdUsedLvls |= set.lshift(shift);
    shift += info.loop.getSpaceDim();
    for (BlockArgument crd : info.loop.getCrds()) {
      BlockArgument collapsedCrd = cloned.getBody()->insertArgument(
          argIdx++, builder.getIndexType(), crd.getLoc());
      crd.replaceAllUsesWith(collapsedCrd);
    }
  }
  crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift);
  cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
  cloned.setCrdUsedLvls(crdUsedLvls);

  rItOp.replaceAllUsesWith(cloned.getResults());
  // Erase collapsed loops.
  rItOp.erase();
  root.erase();
}

struct SparseSpaceCollapsePass
    : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
  SparseSpaceCollapsePass() = default;

  void runOnOperation() override {
    func::FuncOp func = getOperation();

    // A naive (experimental) implementation to collapse consecutive sparse
    // spaces. It does NOT handle complex cases where multiple spaces are
    // extracted in the same basic block. E.g.,
    //
    // %space1 = extract_space %t1 ...
    // %space2 = extract_space %t2 ...
    // sparse_tensor.iterate(%sp1) ...
    //
    SmallVector<CollapseSpaceInfo> toCollapse;
    func->walk([&](ExtractIterSpaceOp op) {
      if (!legalToCollapse(toCollapse, op)) {
        // if not legal to collapse one more space, collapse the existing ones
        // and clear.
        collapseSparseSpace(toCollapse);
        toCollapse.clear();
      }
    });

    collapseSparseSpace(toCollapse);
  }
};

} // namespace sparse_tensor

std::unique_ptr<Pass> createSparseSpaceCollapsePass() {
  return std::make_unique<sparse_tensor::SparseSpaceCollapsePass>();
}

} // namespace mlir
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
// RUN: mlir-opt %s --sparse-space-collapse | FileCheck %s

#COO = #sparse_tensor.encoding<{
  map = (i, j) -> (
    i : compressed(nonunique),
    j : singleton(soa)
  )
}>

// CHECK-LABEL: func.func @sparse_sparse_collapse(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf32, #sparse>,
// CHECK-SAME: %[[VAL_1:.*]]: index) {
// CHECK: %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 to 2 : tensor<4x8xf32, #sparse>
// CHECK: %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]], _) iter_args(%[[VAL_7:.*]] = %[[VAL_1]])
// CHECK: %[[VAL_8:.*]] = "test.op"(%[[VAL_7]]) : (index) -> index
// CHECK: sparse_tensor.yield %[[VAL_8]] : index
// CHECK: }
// CHECK: "test.sink"(%[[VAL_4]]) : (index) -> ()
// CHECK: return
// CHECK: }
func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>, %i : index) {
  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
  %r1 = sparse_tensor.iterate %it1 in %l1 at(%crd0) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
    %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
    %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
      %k = "test.op"(%inner) : (index) -> index
      sparse_tensor.yield %k : index
    }
    sparse_tensor.yield %r2 : index
  }
  "test.sink"(%r1) : (index) -> ()
  return
}
