Skip to content

Commit 836b6b8

Browse files
author
Peiming Liu
committed
rebase
1 parent 6af6167 commit 836b6b8

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
5353
ExtractIterSpaceOp curSpace) {
5454

5555
auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
56-
Value spaceVal = space.getResultSpace();
56+
Value spaceVal = space.getExtractedSpace();
5757
if (spaceVal.hasOneUse())
5858
return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
5959
return nullptr;
@@ -116,7 +116,7 @@ void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
116116
auto innermost = toCollapse.back().loop;
117117

118118
IRMapping mapper;
119-
mapper.map(leaf, collapsedSpace.getResultSpace());
119+
mapper.map(leaf, collapsedSpace.getExtractedSpace());
120120
for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
121121
mapper.map(std::get<0>(z), std::get<1>(z));
122122

mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,13 @@
1919
// CHECK: return
2020
// CHECK: }
2121
func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>, %i : index) {
22-
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
22+
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
23+
: tensor<4x8xf32, #COO>
24+
-> !sparse_tensor.iter_space<#COO, lvls = 0>
2325
%r1 = sparse_tensor.iterate %it1 in %l1 at(%crd0) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
24-
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
26+
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
27+
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
28+
-> !sparse_tensor.iter_space<#COO, lvls = 1>
2529
%r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
2630
%k ="test.op"(%inner) : (index) -> index
2731
sparse_tensor.yield %k : index

0 commit comments

Comments
 (0)