Skip to content

Commit 7cbaaed

Browse files
author
Peiming Liu
authored
[mlir][sparse] fix sparse tests that use reshape operations. (#90637)
Due to generalization introduced in #90040
1 parent 41f9c78 commit 7cbaaed

File tree

2 files changed

+40
-16
lines changed

2 files changed

+40
-16
lines changed

mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ module {
4444
%cst = arith.constant 0.000000e+00 : f32
4545
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<5x6xf32>) -> tensor<5x6xf32>
4646
%2 = linalg.matmul ins(%arg0, %collapsed : tensor<5x6xf32>, tensor<6x6xf32, #COO_2D>) outs(%1 : tensor<5x6xf32>) -> tensor<5x6xf32>
47-
%expanded = tensor.expand_shape %2 [[0], [1, 2]] : tensor<5x6xf32> into tensor<5x2x3xf32>
47+
%expanded = tensor.expand_shape %2 [[0], [1, 2]] output_shape [5,2,3]: tensor<5x6xf32> into tensor<5x2x3xf32>
4848
%ret1 = tensor.cast %expanded : tensor<5x2x3xf32> to tensor<?x?x?xf32>
4949

5050
// Note: tensor.collapse_shape is a metadata-only operation on dense tensors
@@ -60,7 +60,7 @@ module {
6060
%cst = arith.constant 0.000000e+00 : f32
6161
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<5x6xf32>) -> tensor<5x6xf32>
6262
%2 = linalg.matmul ins(%arg0, %collapsed : tensor<5x6xf32, #COO_2D>, tensor<6x6xf32, #COO_2D>) outs(%1 : tensor<5x6xf32>) -> tensor<5x6xf32>
63-
%expanded = tensor.expand_shape %2 [[0], [1, 2]] : tensor<5x6xf32> into tensor<5x2x3xf32>
63+
%expanded = tensor.expand_shape %2 [[0], [1, 2]] output_shape [5,2,3]: tensor<5x6xf32> into tensor<5x2x3xf32>
6464
%ret1 = tensor.cast %expanded : tensor<5x2x3xf32> to tensor<?x?x?xf32>
6565

6666
// Note: tensor.collapse_shape is a metadata-only operation on dense tensors
@@ -76,7 +76,7 @@ module {
7676
%cst = arith.constant 0.000000e+00 : f32
7777
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<5x6xf32>) -> tensor<5x6xf32>
7878
%2 = linalg.matmul ins(%arg0, %collapsed : tensor<5x6xf32>, tensor<6x6xf32>) outs(%1 : tensor<5x6xf32>) -> tensor<5x6xf32>
79-
%expanded = tensor.expand_shape %2 [[0], [1, 2]] : tensor<5x6xf32> into tensor<5x2x3xf32>
79+
%expanded = tensor.expand_shape %2 [[0], [1, 2]] output_shape [5,2,3]: tensor<5x6xf32> into tensor<5x2x3xf32>
8080
%ret1 = tensor.cast %expanded : tensor<5x2x3xf32> to tensor<?x?x?xf32>
8181
return %ret1 : tensor<?x?x?xf32>
8282
}
@@ -88,7 +88,7 @@ module {
8888
%cst = arith.constant 0.000000e+00 : f32
8989
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<5x6xf32>) -> tensor<5x6xf32>
9090
%2 = linalg.matmul ins(%arg0, %collapsed : tensor<5x6xf32, #COO_2D>, tensor<6x6xf32, #COO_2D>) outs(%1 : tensor<5x6xf32>) -> tensor<5x6xf32>
91-
%expanded = tensor.expand_shape %2 [[0], [1, 2]] : tensor<5x6xf32> into tensor<5x2x3xf32>
91+
%expanded = tensor.expand_shape %2 [[0], [1, 2]] output_shape [5,2,3]: tensor<5x6xf32> into tensor<5x2x3xf32>
9292
%ret1 = tensor.cast %expanded : tensor<5x2x3xf32> to tensor<?x?x?xf32>
9393

9494
// Note: tensor.collapse_shape is a metadata-only operation on dense tensors

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_expand_shape.mlir

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,62 +53,86 @@
5353
module {
5454

5555
func.func @expand_dense(%arg0: tensor<12xf64>) -> tensor<3x4xf64> {
56-
%0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64> into tensor<3x4xf64>
56+
%0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [3, 4] : tensor<12xf64> into tensor<3x4xf64>
5757
return %0 : tensor<3x4xf64>
5858
}
5959

6060
func.func @expand_from_sparse(%arg0: tensor<12xf64, #SparseVector>) -> tensor<3x4xf64> {
61-
%0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64, #SparseVector> into tensor<3x4xf64>
61+
%0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [3, 4] : tensor<12xf64, #SparseVector> into tensor<3x4xf64>
6262
return %0 : tensor<3x4xf64>
6363
}
6464

6565
func.func @expand_to_sparse(%arg0: tensor<12xf64>) -> tensor<3x4xf64, #SparseMatrix> {
66-
%0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64> into tensor<3x4xf64, #SparseMatrix>
66+
%0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [3, 4] : tensor<12xf64> into tensor<3x4xf64, #SparseMatrix>
6767
return %0 : tensor<3x4xf64, #SparseMatrix>
6868
}
6969

7070
func.func @expand_sparse2sparse(%arg0: tensor<12xf64, #SparseVector>) -> tensor<3x4xf64, #SparseMatrix> {
71-
%0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64, #SparseVector> into tensor<3x4xf64, #SparseMatrix>
71+
%0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [3, 4] : tensor<12xf64, #SparseVector> into tensor<3x4xf64, #SparseMatrix>
7272
return %0 : tensor<3x4xf64, #SparseMatrix>
7373
}
7474

7575
func.func @expand_dense_3x2x2(%arg0: tensor<3x4xf64>) -> tensor<3x2x2xf64> {
76-
%0 = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<3x4xf64> into tensor<3x2x2xf64>
76+
%0 = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [3, 2, 2] : tensor<3x4xf64> into tensor<3x2x2xf64>
7777
return %0 : tensor<3x2x2xf64>
7878
}
7979

8080
func.func @expand_from_sparse_3x2x2(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<3x2x2xf64> {
81-
%0 = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<3x4xf64, #SparseMatrix> into tensor<3x2x2xf64>
81+
%0 = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [3, 2, 2] : tensor<3x4xf64, #SparseMatrix> into tensor<3x2x2xf64>
8282
return %0 : tensor<3x2x2xf64>
8383
}
8484

8585
func.func @expand_to_sparse_3x2x2(%arg0: tensor<3x4xf64>) -> tensor<3x2x2xf64, #Sparse3dTensor> {
86-
%0 = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<3x4xf64> into tensor<3x2x2xf64, #Sparse3dTensor>
86+
%0 = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [3, 2, 2] : tensor<3x4xf64> into tensor<3x2x2xf64, #Sparse3dTensor>
8787
return %0 : tensor<3x2x2xf64, #Sparse3dTensor>
8888
}
8989

9090
func.func @expand_sparse2sparse_3x2x2(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<3x2x2xf64, #Sparse3dTensor> {
91-
%0 = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<3x4xf64, #SparseMatrix> into tensor<3x2x2xf64, #Sparse3dTensor>
91+
%0 = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [3, 2, 2] : tensor<3x4xf64, #SparseMatrix> into tensor<3x2x2xf64, #Sparse3dTensor>
9292
return %0 : tensor<3x2x2xf64, #Sparse3dTensor>
9393
}
9494

9595
func.func @expand_dense_dyn(%arg0: tensor<?x?xf64>) -> tensor<?x2x?xf64> {
96-
%0 = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<?x?xf64> into tensor<?x2x?xf64>
96+
%c0 = arith.constant 0 : index
97+
%c1 = arith.constant 1 : index
98+
%c2 = arith.constant 2 : index
99+
%d0 = tensor.dim %arg0, %c0 : tensor<?x?xf64>
100+
%d1 = tensor.dim %arg0, %c1 : tensor<?x?xf64>
101+
%d2 = arith.divui %d1, %c2 : index
102+
%0 = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [%d0, 2, %d2] : tensor<?x?xf64> into tensor<?x2x?xf64>
97103
return %0 : tensor<?x2x?xf64>
98104
}
99105

100106
func.func @expand_from_sparse_dyn(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tensor<?x2x?xf64> {
101-
%0 = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<?x?xf64, #SparseMatrix> into tensor<?x2x?xf64>
107+
%c0 = arith.constant 0 : index
108+
%c1 = arith.constant 1 : index
109+
%c2 = arith.constant 2 : index
110+
%d0 = tensor.dim %arg0, %c0 : tensor<?x?xf64, #SparseMatrix>
111+
%d1 = tensor.dim %arg0, %c1 : tensor<?x?xf64, #SparseMatrix>
112+
%d2 = arith.divui %d1, %c2 : index
113+
%0 = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [%d0, 2, %d2] : tensor<?x?xf64, #SparseMatrix> into tensor<?x2x?xf64>
102114
return %0 : tensor<?x2x?xf64>
103115
}
104116

105117
func.func @expand_to_sparse_dyn(%arg0: tensor<?x?xf64>) -> tensor<?x2x?xf64, #Sparse3dTensor> {
106-
%0 = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<?x?xf64> into tensor<?x2x?xf64, #Sparse3dTensor>
118+
%c0 = arith.constant 0 : index
119+
%c1 = arith.constant 1 : index
120+
%c2 = arith.constant 2 : index
121+
%d0 = tensor.dim %arg0, %c0 : tensor<?x?xf64>
122+
%d1 = tensor.dim %arg0, %c1 : tensor<?x?xf64>
123+
%d2 = arith.divui %d1, %c2 : index
124+
%0 = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [%d0, 2, %d2] : tensor<?x?xf64> into tensor<?x2x?xf64, #Sparse3dTensor>
107125
return %0 : tensor<?x2x?xf64, #Sparse3dTensor>
108126
}
109127

110128
func.func @expand_sparse2sparse_dyn(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tensor<?x2x?xf64, #Sparse3dTensor> {
111-
%0 = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<?x?xf64, #SparseMatrix> into tensor<?x2x?xf64, #Sparse3dTensor>
129+
%c0 = arith.constant 0 : index
130+
%c1 = arith.constant 1 : index
131+
%c2 = arith.constant 2 : index
132+
%d0 = tensor.dim %arg0, %c0 : tensor<?x?xf64, #SparseMatrix>
133+
%d1 = tensor.dim %arg0, %c1 : tensor<?x?xf64, #SparseMatrix>
134+
%d2 = arith.divui %d1, %c2 : index
135+
%0 = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [%d0, 2, %d2] : tensor<?x?xf64, #SparseMatrix> into tensor<?x2x?xf64, #Sparse3dTensor>
112136
return %0 : tensor<?x2x?xf64, #Sparse3dTensor>
113137
}
114138

0 commit comments

Comments
 (0)