
Commit bfa3bc4

[mlir][sparse] unifies sparse_tensor.sort_coo/sort into one operation. (#66722)
The use cases of the two operations largely overlap, so this change simplifies the dialect by keeping only one of them.
1 parent 74338bf commit bfa3bc4

14 files changed, 269 insertions(+), 680 deletions(-)

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 13 additions & 62 deletions
@@ -762,81 +762,32 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
 // Sparse Tensor Sorting Operations.
 //===----------------------------------------------------------------------===//

-def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
-    Arguments<(ins Index:$n,
-               Variadic<StridedMemRefRankOf<[AnyInteger, Index], [1]>>:$xs,
-               Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
-               SparseTensorSortKindAttr:$algorithm)> {
-  string summary = "Sorts the arrays in xs and ys lexicographically on the "
-                   "integral values found in the xs list";
-  string description = [{
-    Lexicographically sort the first `n` values in `xs` along with the values in
-    `ys`. Conceptually, the values being sorted are tuples produced by
-    `zip(zip(xs), zip(ys))`. In particular, values in `ys` needed to be sorted
-    along with values in `xs`, but values in `ys` don't affect the
-    lexicographical order. The order in which arrays appear in `xs` affects the
-    sorting result. The operator updates `xs` and `ys` in place with the result
-    of the sorting.
-
-    For example, assume x1=[4, 3], x2=[1, 2], y1=[10, 5], then the output of
-    "sort 2, x1, x2 jointly y1" are x1=[3, 4], x2=[2, 1], y1=[5, 10] while the
-    output of "sort 2, x2, x1, jointly y1" are x2=[1, 2], x1=[4, 3], y1=[10, 5].
-
-    Buffers in `xs` needs to have the same integral element type while buffers
-    in `ys` can have different numeric element types. All buffers in `xs` and
-    `ys` should have a dimension not less than `n`. The behavior of the operator
-    is undefined if this condition is not met. The operator requires at least
-    one buffer in `xs` while `ys` can be empty.
-
-    The enum attribute `algorithm` indicates the sorting algorithm used to
-    implement the operator: hybrid_quick_sort, insertion_sort_stable,
-    quick_sort, or heap_sort.
-
-    Note that this operation is "impure" in the sense that its behavior is
-    solely defined by side-effects and not SSA values.
-
-    Example:
-
-    ```mlir
-    sparse_tensor.sort insertion_sort_stable %n, %x1, %x2 jointly y1, %y2
-      : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
-    ```
-
-    ```mlir
-    sparse_tensor.sort hybrid_quick_sort %n, %x1, %x2 jointly y1, %y2
-      { alg=1 : index}
-      : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
-    ```
-  }];
-  let assemblyFormat = "$algorithm $n `,` $xs (`jointly` $ys^)? attr-dict"
-                       "`:` type($xs) (`jointly` type($ys)^)?";
-  let hasVerifier = 1;
-}
-
 def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
     Arguments<(ins Index:$n, StridedMemRefRankOf<[AnyInteger, Index], [1]>:$xy,
                Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
-               OptionalAttr<IndexAttr>:$nx, OptionalAttr<IndexAttr>:$ny,
+               AffineMapAttr:$perm_map, OptionalAttr<IndexAttr>:$ny,
                SparseTensorSortKindAttr:$algorithm)> {
   let summary = "Sorts the arrays in xs and ys lexicographically on the "
                 "integral values found in the xs list";
   let description = [{
-    Sparse_tensor.sort_coo is similar to sparse_tensor.sort, except that all the
-    `xs` values and some `ys` values are put in the linear buffer `xy`. The
-    optional index attribute `nx` provides the number of `xs` values in `xy`.
-    When `nx` is not explicitly specified, its value is 1. The optional index
-    attribute `ny` provides the number of `ys` values in `xy`. When `ny` is not
-    explicitly specified, its value is 0. This instruction supports a more
-    efficient way to store the COO definition in sparse tensor type.
-
-    The buffer xy should have a dimension not less than n * (nx + ny) while the
+    Sparse_tensor.sort_coo sort the `xs` values along with some `ys` values
+    that are put in a single linear buffer `xy`.
+    The affine map attribute `perm_map` specifies the permutation to be applied on
+    the `xs` before comparison, the rank of the permutation map
+    also specifies the number of `xs` values in `xy`.
+    The optional index attribute `ny` provides the number of `ys` values in `xy`.
+    When `ny` is not explicitly specified, its value is 0.
+    This instruction supports a more efficient way to store the COO definition
+    in sparse tensor type.
+
+    The buffer xy should have a dimension not less than n * (rank(perm_map) + ny) while the
     buffers in `ys` should have a dimension not less than `n`. The behavior of
     the operator is undefined if this condition is not met.

     Example:

     ```mlir
-    sparse_tensor.sort_coo insertion_sort_stable %n, %x { nx = 2 : index}
+    sparse_tensor.sort_coo insertion_sort_stable %n, %x { perm_map = affine_map<(i,j) -> (j,i)> }
       : memref<?xindex>
     ```
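To make the new `perm_map` semantics concrete, the following is a minimal reference sketch in plain C++, not the MLIR lowering. It assumes that `xy` stores `n` tuples of `rank(perm_map) + ny` values each, with the `xs` coordinates first, and that `perm[k]` names the coordinate compared at position `k` (so `{1, 0}` models `affine_map<(i,j) -> (j,i)>`). The function name `sortCooModel`, the single `float` buffer `y`, and the use of `std::stable_sort` in place of the op's `algorithm` attribute are all illustrative assumptions.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical reference model of sort_coo (not the MLIR implementation).
// xy holds n tuples of (perm.size() + ny) values; the first perm.size()
// entries of each tuple are the coordinates ("xs"), the rest are values.
// perm[k] is the coordinate compared at position k (assumed encoding).
// y is one separate buffer that is reordered jointly with the tuples.
void sortCooModel(std::size_t n, std::vector<std::int64_t> &xy,
                  const std::vector<std::size_t> &perm, std::size_t ny,
                  std::vector<float> &y) {
  const std::size_t nx = perm.size();
  const std::size_t stride = nx + ny;

  // Sort tuple indices instead of the tuples themselves.
  std::vector<std::size_t> order(n);
  std::iota(order.begin(), order.end(), std::size_t{0});
  std::stable_sort(order.begin(), order.end(),
                   [&](std::size_t a, std::size_t b) {
                     // Lexicographic comparison on the permuted coordinates.
                     for (std::size_t k = 0; k < nx; ++k) {
                       const std::int64_t va = xy[a * stride + perm[k]];
                       const std::int64_t vb = xy[b * stride + perm[k]];
                       if (va != vb)
                         return va < vb;
                     }
                     return false;
                   });

  // Apply the resulting order to xy and y using copies of the old contents.
  const std::vector<std::int64_t> oldXy = xy;
  const std::vector<float> oldY = y;
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t k = 0; k < stride; ++k)
      xy[i * stride + k] = oldXy[order[i] * stride + k];
    y[i] = oldY[order[i]];
  }
}
```

With the data from the removed `sparse_tensor.sort` description (x1=[4, 3], x2=[1, 2], y1=[10, 5]) packed as `xy = {4, 1, 3, 2}`, the identity permutation `{0, 1}` reorders the tuples to `{3, 2, 4, 1}` with `y = {5, 10}`, while `{1, 0}` leaves both buffers unchanged, matching the two orderings described there.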

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 10 additions & 35 deletions
@@ -1353,48 +1353,22 @@ LogicalResult SelectOp::verify() {
   return success();
 }

-LogicalResult SortOp::verify() {
-  if (getXs().empty())
-    return emitError("need at least one xs buffer.");
-
-  std::optional<int64_t> n = getConstantIntValue(getN());
-
-  Type xtp = getMemRefType(getXs().front()).getElementType();
-  auto checkTypes = [&](ValueRange operands,
-                        bool checkEleType = true) -> LogicalResult {
-    for (Value opnd : operands) {
-      auto mtp = getMemRefType(opnd);
-      const DynSize sh = mtp.getShape()[0];
-      // We can't check the size of dynamic dimension at compile-time, but all
-      // xs and ys should have a dimension not less than n at runtime.
-      if (n && !ShapedType::isDynamic(sh) && sh < n.value())
-        return emitError(llvm::formatv("xs and ys need to have a dimension >= n"
-                                       ": {0} < {1}",
-                                       sh, n.value()));
-
-      if (checkEleType && xtp != mtp.getElementType())
-        return emitError("mismatch xs element types");
-    }
-    return success();
-  };
-  RETURN_FAILURE_IF_FAILED(checkTypes(getXs()))
-  return n ? checkTypes(getYs(), false) : success();
-}
-
 LogicalResult SortCooOp::verify() {
+  AffineMap xPerm = getPermMap();
+  uint64_t nx = xPerm.getNumDims();
+  if (nx < 1)
+    emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
+
+  if (!xPerm.isPermutation())
+    emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
+
   std::optional<int64_t> cn = getConstantIntValue(getN());
   // We can't check the size of the buffers when n or buffer dimensions aren't
   // compile-time constants.
   if (!cn)
     return success();

   uint64_t n = cn.value();
-  uint64_t nx = 1;
-  if (auto nxAttr = getNxAttr()) {
-    nx = nxAttr.getInt();
-    if (nx < 1)
-      emitError(llvm::formatv("Expected nx > 1, got {0}", nx));
-  }
   uint64_t ny = 0;
   if (auto nyAttr = getNyAttr()) {
     ny = nyAttr.getInt();
@@ -1409,7 +1383,8 @@ LogicalResult SortCooOp::verify() {
       emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
   };

-  checkDim(getXy(), n * (nx + ny), "Expected dimension(xy) >= n * (nx + ny)");
+  checkDim(getXy(), n * (nx + ny),
+           "Expected dimension(xy) >= n * (rank(perm_map) + ny)");

   for (Value opnd : getYs()) {
     checkDim(opnd, n, "Expected dimension(y) >= n");
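The size constraints enforced by the updated verifier can be summarized in a small standalone sketch. This is only a model under stated assumptions: it uses plain C++ containers instead of MLIR types, represents `perm_map` as a list of result positions, and returns the first violation as a string instead of emitting diagnostics. The names `isPermutationModel` and `checkSortCooModel` are hypothetical and do not exist in MLIR.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a permutation check: every position
// 0..rank-1 must appear exactly once in perm.
bool isPermutationModel(const std::vector<std::size_t> &perm) {
  std::vector<bool> seen(perm.size(), false);
  for (std::size_t p : perm) {
    if (p >= perm.size() || seen[p])
      return false;
    seen[p] = true;
  }
  return true;
}

// Returns a message describing the first violated constraint, or
// std::nullopt when all statically known sizes are acceptable.
std::optional<std::string>
checkSortCooModel(std::uint64_t n, const std::vector<std::size_t> &perm,
                  std::uint64_t ny, std::uint64_t xySize,
                  const std::vector<std::uint64_t> &ySizes) {
  const std::uint64_t nx = perm.size(); // rank(perm_map)
  if (nx < 1)
    return std::string("expected rank(perm_map) >= 1");
  if (!isPermutationModel(perm))
    return std::string("expected a permutation map");
  // xy must hold n tuples of nx coordinates plus ny values.
  if (xySize < n * (nx + ny))
    return std::string("expected dimension(xy) >= n * (rank(perm_map) + ny)");
  // Every separate ys buffer must hold at least n elements.
  for (std::uint64_t size : ySizes)
    if (size < n)
      return std::string("expected dimension(y) >= n");
  return std::nullopt;
}
```

For example, with `n = 2`, a rank-2 permutation, and `ny = 1`, the `xy` buffer needs at least 2 * (2 + 1) = 6 elements, so a buffer of size 4 would be rejected by this model.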
