+ // RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map | FileCheck %s --check-prefix=CHECK-FOLD
// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification | FileCheck %s
#trait = {
iterator_types = [" parallel" , " parallel" , " parallel" , " parallel" ]
}
- #sparse = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>
+ #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
- // CHECK-LABEL: func.func @test(
+ #COO = #sparse_tensor.encoding<{map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa))}>
+ #CCCD = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>
+
+ // CHECK-LABEL: func.func @fold_convert(
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.yield
// CHECK: scf.yield
// CHECK: sparse_tensor.load
- func.func @test(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #sparse> {
+
+ // CHECK-FOLD-LABEL: func.func @fold_convert(
+ // CHECK-FOLD-NOT: sparse_tensor.convert
+ func.func @fold_convert(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #CCCD> {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 1.000000e+00 : f32
%cst_1 = arith.constant 1.000000e+00 : f32
@@ -43,6 +50,29 @@ func.func @test(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>,
%9 = arith.uitofp %8 : i1 to f32
linalg.yield %9 : f32
} -> tensor<128x32x32x1xf32>
- %2 = sparse_tensor.convert %1 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #sparse>
- return %2 : tensor<128x32x32x1xf32, #sparse>
+ %2 = sparse_tensor.convert %1 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #CCCD>
+ return %2 : tensor<128x32x32x1xf32, #CCCD>
+ }
+
+
+ // FIXME: The following kernel is not sparsifiable because `arith.select`
+ // operations are not handled by the sparse compiler at the moment.
+ //
+ // CHECK-FOLD-LABEL: func.func @fold_cast(
+ // CHECK-FOLD-NOT: sparse_tensor.convert
+ func.func @fold_cast(%0: tensor<10x20x30xf64, #COO>) -> tensor<10x20x30xf64, #COO> {
+ %cst = arith.constant 0.000000e+00 : f64
+ %1 = tensor.empty() : tensor<10x20x30xf64>
+ %2 = linalg.generic { indexing_maps = [#map, #map],
+ iterator_types = [" parallel" , " parallel" , " parallel" ]
+ }
+ ins(%0 : tensor<10x20x30xf64, #COO>)
+ outs(%1 : tensor<10x20x30xf64>) {
+ ^bb0(%in: f64, %out: f64):
+ %4 = arith.cmpf ugt, %in, %cst : f64
+ %5 = arith.select %4, %in, %cst : f64
+ linalg.yield %5 : f64
+ } -> tensor<10x20x30xf64>
+ %cast = tensor.cast %2 : tensor<10x20x30xf64> to tensor<10x20x30xf64, #COO>
+ return %cast : tensor<10x20x30xf64, #COO>
}