Commit a905cb2
Author: Peiming Liu

[mlir][sparse] fold sparse convert into producer generic operation.

1 parent d6c4ebb
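
In effect, the rewrite retypes the producer's init and result instead of converting afterwards. A simplified before/after sketch (distilled from the new regression test below; the kernel body and #trait are elided):

// Before: a dense linalg.generic whose only use is a sparse conversion.
%0 = tensor.empty() : tensor<128x32x32x1xf32>
%1 = linalg.generic #trait ins(...) outs(%0 : tensor<128x32x32x1xf32>) {
  ...
} -> tensor<128x32x32x1xf32>
%2 = sparse_tensor.convert %1
    : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #sparse>

// After: the convert is folded away; the init and the generic now yield
// the sparse type directly.
%0 = tensor.empty() : tensor<128x32x32x1xf32, #sparse>
%1 = linalg.generic #trait ins(...) outs(%0 : tensor<128x32x32x1xf32, #sparse>) {
  ...
} -> tensor<128x32x32x1xf32, #sparse>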

File tree

5 files changed: +121 -26 lines changed


mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 9 additions & 6 deletions
@@ -90,17 +90,20 @@ inline MemRefType getMemRefType(T &&t) {
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
 
 /// Returns true iff MLIR operand has any sparse operand.
-inline bool hasAnySparseOperand(Operation *op) {
-  return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
-    return getSparseTensorEncoding(t) != nullptr;
+inline bool hasAnySparseType(TypeRange types) {
+  return llvm::any_of(types, [](Type type) {
+    return getSparseTensorEncoding(type) != nullptr;
   });
 }
 
+/// Returns true iff MLIR operand has any sparse operand.
+inline bool hasAnySparseOperand(Operation *op) {
+  return hasAnySparseType(op->getOperands().getTypes());
+}
+
 /// Returns true iff MLIR operand has any sparse result.
 inline bool hasAnySparseResult(Operation *op) {
-  return llvm::any_of(op->getResults().getTypes(), [](Type t) {
-    return getSparseTensorEncoding(t) != nullptr;
-  });
+  return hasAnySparseType(op->getResults().getTypes());
 }
 
 /// Returns true iff MLIR operand has any sparse operand or result.
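
The refactor folds both predicates into a single type-range helper. A one-line sketch of the new call-site pattern this enables (variable names illustrative; the real use appears in the Sparsification.cpp diff below):

// Detect an all-dense -> sparse kernel: none of the generic op's inputs
// is sparse, even though its output may be.
bool allDenseInputs = !hasAnySparseType(env.op().getInputs().getTypes());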

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 35 additions & 3 deletions
@@ -289,6 +289,37 @@ struct FuseExtractSliceWithConcat
   }
 };
 
+/// Rewriting rule that fuses sparse_tensor.convert into its producer.
+struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConvertOp op,
+                                PatternRewriter &rewriter) const override {
+    auto producer = op.getSource().getDefiningOp<GenericOp>();
+    if (!producer || producer.getDpsInits().size() != 1 ||
+        !isMaterializing(producer.getDpsInitOperand(0), false) ||
+        !producer.getResult(0).hasOneUse()) {
+      return failure();
+    }
+    rewriter.modifyOpInPlace(producer, [&]() {
+      producer.getResult(0).setType(op.getResult().getType());
+    });
+
+    Operation *materializeOp =
+        producer.getDpsInitOperand(0)->get().getDefiningOp();
+
+    rewriter.modifyOpInPlace(materializeOp, [&]() {
+      materializeOp->getResult(0).setType(op.getResult().getType());
+    });
+
+    rewriter.replaceAllOpUsesWith(op, producer);
+    op->erase();
+
+    return success();
+  }
+};
+
 /// Rewriting rule that converts direct yield of zero with initial allocation.
 struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
 public:
@@ -1506,9 +1537,10 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
 //===---------------------------------------------------------------------===//
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
-  patterns.add<FuseExtractSliceWithConcat, FoldInvariantYield,
-               FuseSparseMultiplyOverAdd, FuseTensorCast, GenSemiRingReduction,
-               GenSemiRingSelect, PrintRewriter>(patterns.getContext());
+  patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
+               FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
+               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
+      patterns.getContext());
 }
 
 void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
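
Since the new pattern is registered in populatePreSparsificationRewriting, the fold can be exercised with the same pipeline the new test below uses (input.mlir is a placeholder file name):

mlir-opt input.mlir --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification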

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 29 additions & 15 deletions
@@ -403,6 +403,22 @@ static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
   return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
 }
 
+static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
+                                  Value sparseOut, ValueRange ivs, Value v) {
+  scf::IfOp condInsert =
+      builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, true);
+  // True branch.
+  builder.setInsertionPointToStart(condInsert.thenBlock());
+  Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs);
+  builder.create<scf::YieldOp>(loc, res);
+  // False branch.
+  builder.setInsertionPointToStart(condInsert.elseBlock());
+  builder.create<scf::YieldOp>(loc, sparseOut);
+  // Value assignment.
+  builder.setInsertionPointAfter(condInsert);
+  return condInsert.getResult(0);
+}
+
 /// Generates insertion code to implement dynamic tensor store.
 static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                               Value rhs) {
@@ -423,23 +439,21 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
     //   return updated chain
     // else
     //   return unmodified chain
-    scf::IfOp ifValidLexInsert = builder.create<scf::IfOp>(
-        loc, chain.getType(), env.getValidLexInsert(),
-        /*else=*/true);
-    // True branch.
-    builder.setInsertionPointToStart(ifValidLexInsert.thenBlock());
-    Value res = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
-    builder.create<scf::YieldOp>(loc, res);
-    // False branch.
-    builder.setInsertionPointToStart(ifValidLexInsert.elseBlock());
-    builder.create<scf::YieldOp>(loc, chain);
-    // Value assignment.
-    builder.setInsertionPointAfter(ifValidLexInsert);
-    env.updateInsertionChain(ifValidLexInsert.getResult(0));
+    Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
+                                     chain, ivs, rhs);
+    env.updateInsertionChain(out);
   } else {
+    Value sparseOut;
+    if (!hasAnySparseType(env.op().getInputs().getTypes())) {
+      // This is an all-dense -> sparse kernel, test rhs != 0 before
+      // insertion.
+      Value nz = genIsNonzero(builder, loc, rhs);
+      sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
+    } else {
+      sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
+    }
     // Generates regular insertion chain.
-    env.updateInsertionChain(
-        builder.create<tensor::InsertOp>(loc, rhs, chain, ivs));
+    env.updateInsertionChain(sparseOut);
   }
   return;
 }
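
An all-dense -> sparse kernel has no sparse input to drive the iteration, so the generated loops visit every output element; guarding the insertion on rhs != 0 keeps explicit zeros out of the sparse result. The IR emitted by genConditionalInsert then has roughly this shape (a sketch with illustrative value names, mirroring the CHECK lines in the new test below):

%updated = scf.if %nz -> (tensor<128x32x32x1xf32, #sparse>) {
  %ins = tensor.insert %rhs into %chain[%i, %j, %k, %l]
      : tensor<128x32x32x1xf32, #sparse>
  scf.yield %ins : tensor<128x32x32x1xf32, #sparse>
} else {
  scf.yield %chain : tensor<128x32x32x1xf32, #sparse>
}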
New test file

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification | FileCheck %s
+
+#trait = {
+  indexing_maps = [
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+  ],
+  iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+}
+
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>
+
+// CHECK-LABEL: func.func @test(
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.if
+// CHECK-NEXT: tensor.insert
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: else
+// CHECK-NEXT: scf.yield
+// CHECK: scf.yield
+// CHECK: scf.yield
+// CHECK: scf.yield
+// CHECK: sparse_tensor.load
+func.func @test(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #sparse> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %cst_0 = arith.constant 1.000000e+00 : f32
+  %cst_1 = arith.constant 1.000000e+00 : f32
+  %0 = tensor.empty() : tensor<128x32x32x1xf32>
+  %1 = linalg.generic #trait
+      ins(%arg0, %arg1, %arg2 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
+      outs(%0 : tensor<128x32x32x1xf32>) {
+  ^bb0(%in: f32, %in_2: f32, %in_3: f32, %out: f32):
+    %3 = arith.subf %cst_0, %in_2 : f32
+    %4 = arith.mulf %in, %3 : f32
+    %5 = arith.mulf %4, %cst_1 : f32
+    %6 = arith.addf %5, %in_3 : f32
+    %7 = arith.subf %6, %cst_0 : f32
+    %8 = arith.cmpf uge, %7, %cst : f32
+    %9 = arith.uitofp %8 : i1 to f32
+    linalg.yield %9 : f32
+  } -> tensor<128x32x32x1xf32>
+  %2 = sparse_tensor.convert %1 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #sparse>
+  return %2 : tensor<128x32x32x1xf32, #sparse>
+}

mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir

Lines changed: 0 additions & 2 deletions
@@ -24,7 +24,6 @@ module {
   // CHECK: arith.constant
   // CHECK: tensor.empty()
   // CHECK: linalg.generic
-  // CHECK: sparse_tensor.convert
   // CHECK: return
   //
   func.func @avoid_fold(%0: tensor<10x20x30xf64, #sparse>) -> tensor<10x20x30xf64, #sparse> {
@@ -44,4 +43,3 @@ module {
     return %cast : tensor<10x20x30xf64, #sparse>
   }
 }
-
