Skip to content

Commit 55270f5

Browse files
author
Peiming Liu
committed
[mlir][sparse] fix a bug in unpack op that used wrong compare predicate.
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D145603
1 parent ffdd5a3 commit 55270f5

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,8 @@ static Value reallocOrSubView(OpBuilder &builder, Location loc, int64_t len,
578578

579579
Value targetLen = constantIndex(builder, loc, len);
580580
Value bufferLen = linalg::createOrFoldDimOp(builder, loc, buffer, 0);
581-
Value reallocP = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
581+
// Reallocates if target length is greater than the actual buffer len.
582+
Value reallocP = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
582583
targetLen, bufferLen);
583584
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, retTp, reallocP, true);
584585
// If targetLen > bufferLen, reallocate to get enough sparse to return.

mlir/test/Dialect/SparseTensor/sparse_pack.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ func.func @sparse_pack(%values: tensor<6xf64>, %coordinates: tensor<6x2xi32>)
4343
// CHECK: %[[VAL_4:.*]] = arith.constant 6 : index
4444
// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
4545
// CHECK: %[[VAL_6:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : memref<?xf64>
46-
// CHECK: %[[VAL_7:.*]] = arith.cmpi ult, %[[VAL_4]], %[[VAL_6]] : index
46+
// CHECK: %[[VAL_7:.*]] = arith.cmpi ugt, %[[VAL_4]], %[[VAL_6]] : index
4747
// CHECK: %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (memref<6xf64>) {
4848
// CHECK: %[[VAL_9:.*]] = memref.realloc %[[VAL_2]] : memref<?xf64> to memref<6xf64>
4949
// CHECK: scf.yield %[[VAL_9]] : memref<6xf64>
@@ -53,7 +53,7 @@ func.func @sparse_pack(%values: tensor<6xf64>, %coordinates: tensor<6x2xi32>)
5353
// CHECK: }
5454
// CHECK: %[[VAL_11:.*]] = arith.constant 12 : index
5555
// CHECK: %[[VAL_12:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?xi32>
56-
// CHECK: %[[VAL_13:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_12]] : index
56+
// CHECK: %[[VAL_13:.*]] = arith.cmpi ugt, %[[VAL_11]], %[[VAL_12]] : index
5757
// CHECK: %[[VAL_14:.*]] = scf.if %[[VAL_13]] -> (memref<12xi32>) {
5858
// CHECK: %[[VAL_15:.*]] = memref.realloc %[[VAL_1]] : memref<?xi32> to memref<12xi32>
5959
// CHECK: scf.yield %[[VAL_15]] : memref<12xi32>

0 commit comments

Comments
 (0)