Commit c689fbb

Author: Peiming Liu (committed)
Commit message: address comments
1 parent: a905cb2

4 files changed: +37 -52 lines changed


mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ inline MemRefType getMemRefType(T &&t) {
 /// Returns null-attribute for any type without an encoding.
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
 
-/// Returns true iff MLIR operand has any sparse operand.
+/// Returns true iff the type range has any sparse tensor type.
 inline bool hasAnySparseType(TypeRange types) {
   return llvm::any_of(types, [](Type type) {
     return getSparseTensorEncoding(type) != nullptr;
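
For context, a minimal usage sketch of the helper whose comment is fixed above. It is not part of this commit; the wrapper function opTouchesSparseTensors is hypothetical, while hasAnySparseType and the MLIR calls it uses are existing APIs.

// Illustrative sketch only: check whether any operand or result of an
// operation carries a sparse tensor encoding.
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Operation.h"

static bool opTouchesSparseTensors(mlir::Operation *op) {
  using namespace mlir::sparse_tensor;
  // hasAnySparseType returns true iff some type in the range has a
  // #sparse_tensor.encoding attached.
  return hasAnySparseType(op->getOperands().getTypes()) ||
         hasAnySparseType(op->getResults().getTypes());
}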

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 1 addition & 1 deletion
@@ -289,7 +289,7 @@ struct FuseExtractSliceWithConcat
   }
 };
 
-/// Rewriting rule that converts direct yield of zero with initial allocation.
+/// Rewriting rule that fuses sparse_tensor.convert into producer.
 struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
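
A short sketch, not from this commit, of how a rewrite pattern like the one renamed above is typically registered and driven with MLIR's greedy rewriter. In-tree the pattern is exercised through the --pre-sparsification-rewrite pass (see the test below), and the struct is local to this file, so the snippet assumes FoldConvertIntoProducer is in scope; the driver function name runConvertFolding is hypothetical.

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Assumes FoldConvertIntoProducer (defined above) is visible here.
static mlir::LogicalResult runConvertFolding(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  patterns.add<FoldConvertIntoProducer>(patterns.getContext());
  // Greedily apply the pattern until no more convert ops can be fused.
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}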

mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir

Lines changed: 35 additions & 5 deletions
@@ -1,3 +1,4 @@
+// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map | FileCheck %s --check-prefix=CHECK-FOLD
 // RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification | FileCheck %s
 
 #trait = {
@@ -10,9 +11,12 @@
   iterator_types = ["parallel", "parallel", "parallel", "parallel"]
 }
 
-#sparse = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 
-// CHECK-LABEL: func.func @test(
+#COO = #sparse_tensor.encoding<{map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa))}>
+#CCCD = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>
+
+// CHECK-LABEL: func.func @fold_convert(
 // CHECK: scf.for
 // CHECK: scf.for
 // CHECK: scf.for
@@ -25,7 +29,10 @@
 // CHECK: scf.yield
 // CHECK: scf.yield
 // CHECK: sparse_tensor.load
-func.func @test(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #sparse> {
+
+// CHECK-FOLD-LABEL: func.func @fold_convert(
+// CHECK-FOLD-NOT: sparse_tensor.convert
+func.func @fold_convert(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #CCCD> {
   %cst = arith.constant 0.000000e+00 : f32
   %cst_0 = arith.constant 1.000000e+00 : f32
   %cst_1 = arith.constant 1.000000e+00 : f32
@@ -43,6 +50,29 @@ func.func @test(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>,
     %9 = arith.uitofp %8 : i1 to f32
     linalg.yield %9 : f32
   } -> tensor<128x32x32x1xf32>
-  %2 = sparse_tensor.convert %1 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #sparse>
-  return %2 : tensor<128x32x32x1xf32, #sparse>
+  %2 = sparse_tensor.convert %1 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #CCCD>
+  return %2 : tensor<128x32x32x1xf32, #CCCD>
+}
+
+
+// FIXME: The following kernel is not sparsifiable because `arith.select`
+// operations are not handled by the sparse compiler at the moment.
+//
+// CHECK-FOLD-LABEL: func.func @fold_cast(
+// CHECK-FOLD-NOT: sparse_tensor.convert
+func.func @fold_cast(%0: tensor<10x20x30xf64, #COO>) -> tensor<10x20x30xf64, #COO> {
+  %cst = arith.constant 0.000000e+00 : f64
+  %1 = tensor.empty() : tensor<10x20x30xf64>
+  %2 = linalg.generic { indexing_maps = [#map, #map],
+                        iterator_types = ["parallel", "parallel", "parallel"]
+                      }
+      ins (%0 : tensor<10x20x30xf64, #COO>)
+      outs(%1 : tensor<10x20x30xf64>) {
+  ^bb0(%in: f64, %out: f64):
+    %4 = arith.cmpf ugt, %in, %cst : f64
+    %5 = arith.select %4, %in, %cst : f64
+    linalg.yield %5 : f64
+  } -> tensor<10x20x30xf64>
+  %cast = tensor.cast %2 : tensor<10x20x30xf64> to tensor<10x20x30xf64, #COO>
+  return %cast : tensor<10x20x30xf64, #COO>
 }

mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir

Lines changed: 0 additions & 45 deletions
This file was deleted.
