diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 94301dbcd9f7b..59815fc755ee5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -762,81 +762,32 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
 // Sparse Tensor Sorting Operations.
 //===----------------------------------------------------------------------===//
 
-def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
-    Arguments<(ins Index:$n,
-               Variadic<StridedMemRefRankOf<[AnyInteger, Index], [1]>>:$xs,
-               Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
-               SparseTensorSortKindAttr:$algorithm)>  {
-  string summary = "Sorts the arrays in xs and ys lexicographically on the "
-                   "integral values found in the xs list";
-  string description = [{
-    Lexicographically sort the first `n` values in `xs` along with the values in
-    `ys`. Conceptually, the values being sorted are tuples produced by
-    `zip(zip(xs), zip(ys))`. In particular, values in `ys` needed to be sorted
-    along with values in `xs`, but values in `ys` don't affect the
-    lexicographical order. The order in which arrays appear in `xs` affects the
-    sorting result. The operator updates `xs` and `ys` in place with the result
-    of the sorting.
-
-    For example, assume x1=[4, 3], x2=[1, 2], y1=[10, 5], then the output of
-    "sort 2, x1, x2 jointly y1" are x1=[3, 4], x2=[2, 1], y1=[5, 10] while the
-    output of "sort 2, x2, x1, jointly y1" are x2=[1, 2], x1=[4, 3], y1=[10, 5].
-
-    Buffers in `xs` needs to have the same integral element type while buffers
-    in `ys` can have different numeric element types. All buffers in `xs` and
-    `ys` should have a dimension not less than `n`. The behavior of the operator
-    is undefined if this condition is not met. The operator requires at least
-    one buffer in `xs` while `ys` can be empty.
-
-    The enum attribute `algorithm` indicates the sorting algorithm used to
-    implement the operator: hybrid_quick_sort, insertion_sort_stable,
-    quick_sort, or heap_sort.
-
-    Note that this operation is "impure" in the sense that its behavior is
-    solely defined by side-effects and not SSA values.
-
-    Example:
-
-    ```mlir
-    sparse_tensor.sort insertion_sort_stable %n, %x1, %x2 jointly y1, %y2
-      : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
-    ```
-
-    ```mlir
-    sparse_tensor.sort hybrid_quick_sort %n, %x1, %x2 jointly y1, %y2
-      { alg=1 : index}
-      : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
-    ```
-  }];
-  let assemblyFormat = "$algorithm $n `,` $xs (`jointly` $ys^)? attr-dict"
-                       "`:` type($xs) (`jointly` type($ys)^)?";
-  let hasVerifier = 1;
-}
-
 def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
     Arguments<(ins Index:$n,
                StridedMemRefRankOf<[AnyInteger, Index], [1]>:$xy,
                Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
-               OptionalAttr<IndexAttr>:$nx, OptionalAttr<IndexAttr>:$ny,
+               AffineMapAttr:$perm_map, OptionalAttr<IndexAttr>:$ny,
                SparseTensorSortKindAttr:$algorithm)>  {
   let summary = "Sorts the arrays in xs and ys lexicographically on the "
                 "integral values found in the xs list";
   let description = [{
-    Sparse_tensor.sort_coo is similar to sparse_tensor.sort, except that all the
-    `xs` values and some `ys` values are put in the linear buffer `xy`. The
-    optional index attribute `nx` provides the number of `xs` values in `xy`.
-    When `nx` is not explicitly specified, its value is 1. The optional index
-    attribute `ny` provides the number of `ys` values in `xy`. When `ny` is not
-    explicitly specified, its value is 0. This instruction supports a more
-    efficient way to store the COO definition in sparse tensor type.
-
-    The buffer xy should have a dimension not less than n * (nx + ny) while the
+    Sparse_tensor.sort_coo sorts the `xs` values along with some `ys` values
+    that are put in a single linear buffer `xy`.
+    The affine map attribute `perm_map` specifies the permutation to be applied
+    on the `xs` before comparison; the rank of the permutation map also
+    specifies the number of `xs` values in `xy`.
+    The optional index attribute `ny` provides the number of `ys` values in
+    `xy`. When `ny` is not explicitly specified, its value is 0.
+    This instruction supports a more efficient way to store the COO definition
+    in the sparse tensor type.
+
+    The buffer `xy` should have a dimension not less than n * (rank(perm_map) + ny) while the
     buffers in `ys` should have a dimension not less than `n`. The behavior of
     the operator is undefined if this condition is not met.
 
     Example:
 
     ```mlir
-    sparse_tensor.sort_coo insertion_sort_stable %n, %x { nx = 2 : index}
+    sparse_tensor.sort_coo insertion_sort_stable %n, %x { perm_map = affine_map<(i,j) -> (j,i)> }
       : memref<?xindex>
     ```
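Editorial aside (not part of the patch): a hand-written sketch of the new attribute in use, with hypothetical function and buffer names. With `perm_map = (i,j) -> (j,i)`, each tuple in `%xy` is stored as (i, j) plus `ny` trailing values but compared in (j, i) order; with `n = 20`, `rank(perm_map) = 2`, and `ny = 1`, the verifier below requires `dimension(xy) >= 20 * 3 = 60`, matching the `50 < 60` diagnostic exercised in invalid.mlir.

```mlir
#JI = affine_map<(i, j) -> (j, i)>

func.func @sort_coo_ji(%n: index, %xy: memref<?xindex>, %vals: memref<?xf32>) {
  // Tuples of (i, j, y0) in %xy are reordered so that (j, i) is ascending;
  // %vals is permuted along with them but never compared.
  sparse_tensor.sort_coo hybrid_quick_sort %n, %xy jointly %vals
      {perm_map = #JI, ny = 1 : index} : memref<?xindex> jointly memref<?xf32>
  return
}
```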
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index e71d2a8dd623a..9675a61109477 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1353,35 +1353,15 @@ LogicalResult SelectOp::verify() {
   return success();
 }
 
-LogicalResult SortOp::verify() {
-  if (getXs().empty())
-    return emitError("need at least one xs buffer.");
-
-  std::optional<int64_t> n = getConstantIntValue(getN());
-
-  Type xtp = getMemRefType(getXs().front()).getElementType();
-  auto checkTypes = [&](ValueRange operands,
-                        bool checkEleType = true) -> LogicalResult {
-    for (Value opnd : operands) {
-      auto mtp = getMemRefType(opnd);
-      const DynSize sh = mtp.getShape()[0];
-      // We can't check the size of dynamic dimension at compile-time, but all
-      // xs and ys should have a dimension not less than n at runtime.
-      if (n && !ShapedType::isDynamic(sh) && sh < n.value())
-        return emitError(llvm::formatv("xs and ys need to have a dimension >= n"
-                                       ": {0} < {1}",
-                                       sh, n.value()));
-
-      if (checkEleType && xtp != mtp.getElementType())
-        return emitError("mismatch xs element types");
-    }
-    return success();
-  };
-  RETURN_FAILURE_IF_FAILED(checkTypes(getXs()))
-  return n ? checkTypes(getYs(), false) : success();
-}
-
 LogicalResult SortCooOp::verify() {
+  AffineMap xPerm = getPermMap();
+  uint64_t nx = xPerm.getNumDims();
+  if (nx < 1)
+    return emitError(llvm::formatv("Expected rank(perm_map) >= 1, got {0}", nx));
+
+  if (!xPerm.isPermutation())
+    return emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
+
   std::optional<int64_t> cn = getConstantIntValue(getN());
   // We can't check the size of the buffers when n or buffer dimensions aren't
   // compile-time constants.
@@ -1389,12 +1369,6 @@ LogicalResult SortCooOp::verify() {
     return success();
 
   uint64_t n = cn.value();
-  uint64_t nx = 1;
-  if (auto nxAttr = getNxAttr()) {
-    nx = nxAttr.getInt();
-    if (nx < 1)
-      emitError(llvm::formatv("Expected nx > 1, got {0}", nx));
-  }
   uint64_t ny = 0;
   if (auto nyAttr = getNyAttr()) {
     ny = nyAttr.getInt();
@@ -1409,7 +1383,8 @@ LogicalResult SortCooOp::verify() {
       emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
   };
 
-  checkDim(getXy(), n * (nx + ny), "Expected dimension(xy) >= n * (nx + ny)");
+  checkDim(getXy(), n * (nx + ny),
+           "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
 
   for (Value opnd : getYs()) {
     checkDim(opnd, n, "Expected dimension(y) >= n");
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 029ecb0708941..3181395a474cf 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -45,46 +45,43 @@ static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
 static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
 static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
 
-using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
-                                            uint64_t, uint64_t, bool, uint32_t)>;
+using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
+                                            AffineMap, uint64_t, uint32_t)>;
 
 /// Constructs a function name with this format to facilitate quick sort:
-///   <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
-///   <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
+///   <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort
+///   <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
-                                         StringRef namePrefix, uint64_t nx,
-                                         uint64_t ny, bool isCoo,
-                                         ValueRange operands) {
-  nameOstream << namePrefix << nx << "_"
-              << getMemRefType(operands[xStartIdx]).getElementType();
+                                         StringRef namePrefix, AffineMap xPerm,
+                                         uint64_t ny, ValueRange operands) {
+  nameOstream << namePrefix;
+  for (auto res : xPerm.getResults())
+    nameOstream << res.cast<AffineDimExpr>().getPosition() << "_";
 
-  if (isCoo)
-    nameOstream << "_coo_" << ny;
+  nameOstream << getMemRefType(operands[xStartIdx]).getElementType();
+  nameOstream << "_coo_" << ny;
 
-  uint64_t yBufferOffset = isCoo ? 1 : nx;
+  constexpr uint64_t yBufferOffset = 1;
   for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
     nameOstream << "_" << getMemRefType(v).getElementType();
 }
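Editorial aside: to make the new mangling concrete, here is an inferred example (the name is hypothetical, derived from the scheme above rather than taken from the tests). For `xPerm = (d0, d1) -> (d1, d0)`, an `index` xy buffer, `ny = 1`, and one `f32` y buffer, the quick-sort helper would be declared roughly as:

```mlir
// "_sparse_qsort_" + "1_" + "0_" + "index" + "_coo_1" + "_f32"
// Parameters are (lo, hi, xy buffer, y buffers...).
func.func private @_sparse_qsort_1_0_index_coo_1_f32(index, index,
                                                     memref<?xindex>,
                                                     memref<?xf32>)
```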
 /// Looks up a function that is appropriate for the given operands being
 /// sorted, and creates such a function if it doesn't exist yet. The
-/// parameters `nx` and `ny` tell the number of x and y values provided
-/// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
-/// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
+/// parameter `xPerm` tells the permutation and the number of x values, and
+/// `ny` the number of y values, provided by the buffer in xStartIdx.
 //
 // All sorting function generators take (lo, hi, xs, ys) in `operands` as
 // parameters for the sorting functions. Other parameters, such as the recursive
 // call depth, are appended to the end of the parameter list as
 // "trailing parameters".
-static FlatSymbolRefAttr
-getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
-                         TypeRange resultTypes, StringRef namePrefix,
-                         uint64_t nx, uint64_t ny, bool isCoo,
-                         ValueRange operands, FuncGeneratorType createFunc,
-                         uint32_t nTrailingP = 0) {
+static FlatSymbolRefAttr getMangledSortHelperFunc(
+    OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes,
+    StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands,
+    FuncGeneratorType createFunc, uint32_t nTrailingP = 0) {
   SmallString<32> nameBuffer;
   llvm::raw_svector_ostream nameOstream(nameBuffer);
-  getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
+  getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny,
                                operands.drop_back(nTrailingP));
 
   ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
@@ -101,7 +98,7 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
         loc, nameOstream.str(),
         FunctionType::get(context, operands.getTypes(), resultTypes));
     func.setPrivate();
-    createFunc(builder, module, func, nx, ny, isCoo, nTrailingP);
+    createFunc(builder, module, func, xPerm, ny, nTrailingP);
   }
 
   return result;
@@ -110,27 +107,19 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
 /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
 /// The code to process the value pairs is generated by `bodyBuilder`.
 static void forEachIJPairInXs(
-    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
-    bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
-  Value iOffset, jOffset;
-  if (isCoo) {
-    Value cstep = constantIndex(builder, loc, nx + ny);
-    iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
-    jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
-  }
-  for (uint64_t k = 0; k < nx; k++) {
-    scf::IfOp ifOp;
-    Value i, j, buffer;
-    if (isCoo) {
-      Value ck = constantIndex(builder, loc, k);
-      i = builder.create<arith::AddIOp>(loc, ck, iOffset);
-      j = builder.create<arith::AddIOp>(loc, ck, jOffset);
-      buffer = args[xStartIdx];
-    } else {
-      i = args[0];
-      j = args[1];
-      buffer = args[xStartIdx + k];
-    }
+    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
+    uint64_t ny,
+    function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
+  Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
+  Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
+  Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
+  for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) {
+    unsigned actualK = xPerm.getResult(k).cast<AffineDimExpr>().getPosition();
+    Value ak = constantIndex(builder, loc, actualK);
+    Value i = builder.create<arith::AddIOp>(loc, ak, iOffset);
+    Value j = builder.create<arith::AddIOp>(loc, ak, jOffset);
+    Value buffer = args[xStartIdx];
+    bodyBuilder(k, i, j, buffer);
  }
 }
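Editorial aside: a hand-written sketch (hypothetical SSA names) of the IR this helper emits. For `xPerm = (d0, d1) -> (d1, d0)` and `ny = 1`, `cstep` is 3, and the k = 0 iteration addresses the coordinate at permuted position 1 of each tuple:

```mlir
%c3 = arith.constant 3 : index     // rank(perm_map) + ny
%io = arith.muli %i, %c3 : index   // offset of tuple %i in %xy
%jo = arith.muli %j, %c3 : index
%c1 = arith.constant 1 : index     // actualK for k = 0
%ii = arith.addi %c1, %io : index
%jj = arith.addi %c1, %jo : index
// bodyBuilder then loads and compares %xy[%ii] and %xy[%jj].
```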
 /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
 /// The code to process the value pairs is generated by `bodyBuilder`.
 static void forEachIJPairInAllBuffers(
-    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
-    bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
-
-  // Create code for the first (nx + ny) buffers. When isCoo==true, these
-  // logical buffers are all from the xy buffer of the sort_coo operator.
-  forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder);
+    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
+    uint64_t ny,
+    function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
+
+  // Create code for the first (rank(xPerm) + ny) buffers.
+  SmallVector<AffineExpr> exps(xPerm.getResults().begin(),
+                               xPerm.getResults().end());
+  for (unsigned y = 0; y < ny; y++) {
+    exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults()));
+  }
+  AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext());
+  assert(xyPerm.isPermutation());
 
-  uint64_t numHandledBuffers = isCoo ? 1 : nx + ny;
+  forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder);
 
+  constexpr uint64_t numHandledBuffers = 1;
   // Create code for the remaining buffers.
   Value i = args[0];
   Value j = args[1];
   for (const auto &arg :
        llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
-    bodyBuilder(arg.index() + nx + ny, i, j, arg.value());
+    bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value());
   }
 }
 
@@ -168,7 +164,7 @@ static void forEachIJPairInAllBuffers(
 //     ...
 //     swap(yn[i], yn[j]);
 static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
-                       uint64_t nx, uint64_t ny, bool isCoo) {
+                       AffineMap xPerm, uint64_t ny) {
   auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
     Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
     Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
@@ -176,20 +172,20 @@ static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
     builder.create<memref::StoreOp>(loc, vi, buffer, j);
   };
 
-  forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair);
+  forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair);
 }
 
 /// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
 /// each pair is create via `compareBuilder`.
 static Value createInlinedCompareImplementation(
-    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
-    bool isCoo,
+    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
+    uint64_t ny,
     function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
        compareBuilder) {
   Value result;
   auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
     bool isFirstDim = (k == 0);
-    bool isLastDim = (k == nx - 1);
+    bool isLastDim = (k == xPerm.getNumResults() - 1);
     Value val =
         compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
     if (isFirstDim) {
@@ -202,7 +198,7 @@ static Value createInlinedCompareImplementation(
     }
   };
 
-  forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder);
+  forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder);
 
   builder.setInsertionPointAfterValue(result);
   return result;
@@ -252,12 +248,12 @@ static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
 //   else if (x2[2] != x2[j]))
 //     and so on ...
 static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
-                                    ValueRange args, uint64_t nx, uint64_t ny,
-                                    bool isCoo, uint32_t nTrailingP = 0) {
+                                    ValueRange args, AffineMap xPerm,
+                                    uint64_t ny, uint32_t nTrailingP = 0) {
   // Compare functions don't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
-  return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
+  return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
                                             createEqCompare);
 }
 
@@ -306,12 +302,12 @@ static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
 //       else
 //         and so on ...
 static Value createInlinedLessThan(OpBuilder &builder, Location loc,
-                                   ValueRange args, uint64_t nx, uint64_t ny,
-                                   bool isCoo, uint32_t nTrailingP = 0) {
+                                   ValueRange args, AffineMap xPerm,
+                                   uint64_t ny, uint32_t nTrailingP = 0) {
   // Compare functions don't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
-  return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
+  return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
                                             createLessThanCompare);
 }
 
@@ -329,8 +325,8 @@ static Value createInlinedLessThan(OpBuilder &builder, Location loc,
 //     return lo;
 //
 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
-                                   func::FuncOp func, uint64_t nx, uint64_t ny,
-                                   bool isCoo, uint32_t nTrailingP = 0) {
+                                   func::FuncOp func, AffineMap xPerm,
+                                   uint64_t ny, uint32_t nTrailingP = 0) {
   // Binary search doesn't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
@@ -368,11 +364,10 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
 
     // Compare xs[p] < xs[mid].
     SmallVector<Value> compareOperands{p, mid};
-    uint64_t numXBuffers = isCoo ? 1 : nx;
+    constexpr uint64_t numXBuffers = 1;
     compareOperands.append(args.begin() + xStartIdx,
                            args.begin() + xStartIdx + numXBuffers);
-    Value cond2 =
-        createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
+    Value cond2 = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
     // Update lo and hi for the WhileOp as follows:
     //   if (xs[p] < xs[mid]))
     //     hi = mid;
@@ -392,10 +387,11 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
 ///     while (xs[i] > xs[p]) i += step (step < 0)
 /// The routine returns i as well as a boolean value to indicate whether
 /// xs[i] == xs[p].
-static std::pair<Value, Value>
-createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
-               ValueRange xs, Value i, Value p, uint64_t nx, uint64_t ny,
-               bool isCoo, int step) {
+static std::pair<Value, Value> createScanLoop(OpBuilder &builder,
+                                              ModuleOp module,
+                                              func::FuncOp func, ValueRange xs,
+                                              Value i, Value p, AffineMap xPerm,
+                                              uint64_t ny, int step) {
   Location loc = func.getLoc();
   scf::WhileOp whileOp =
       builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
@@ -413,8 +409,7 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
     compareOperands.push_back(before->getArgument(0));
   }
   compareOperands.append(xs.begin(), xs.end());
-  Value cond =
-      createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
+  Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
 
   Block *after =
@@ -429,7 +424,7 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
   compareOperands[0] = i;
   compareOperands[1] = p;
   Value compareEq =
-      createInlinedEqCompare(builder, loc, compareOperands, nx, ny, isCoo);
+      createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny);
 
   return std::make_pair(whileOp.getResult(0), compareEq);
 }
@@ -438,67 +433,63 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
 /// if compareFunc(data[b], data[a]) returns true. The new insertion point is
 /// right after the swap instructions.
 static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
-                                       uint64_t nx, uint64_t ny, bool isCoo,
+                                       AffineMap xPerm, uint64_t ny,
                                        SmallVectorImpl<Value> &swapOperands,
                                        SmallVectorImpl<Value> &compareOperands,
                                        Value a, Value b) {
   // Compare(data[b], data[a]).
   compareOperands[0] = b;
   compareOperands[1] = a;
-  Value cond =
-      createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
+  Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
   swapOperands[0] = b;
   swapOperands[1] = a;
-  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+  createSwap(builder, loc, swapOperands, xPerm, ny);
   return ifOp;
 }
 
 /// Creates code to insert the 3rd element to a list of two sorted elements.
-static void createInsert3rd(OpBuilder &builder, Location loc, uint64_t nx,
-                            uint64_t ny, bool isCoo,
-                            SmallVectorImpl<Value> &swapOperands,
+static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm,
+                            uint64_t ny, SmallVectorImpl<Value> &swapOperands,
                             SmallVectorImpl<Value> &compareOperands, Value v0,
                             Value v1, Value v2) {
-  scf::IfOp ifOp = createCompareThenSwap(builder, loc, nx, ny, isCoo,
-                                         swapOperands, compareOperands, v1, v2);
-  createCompareThenSwap(builder, loc, nx, ny, isCoo, swapOperands,
-                        compareOperands, v0, v1);
+  scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
+                                         compareOperands, v1, v2);
+  createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands,
+                        v0, v1);
   builder.setInsertionPointAfter(ifOp);
 }
 
 /// Creates code to sort 3 elements.
-static void createSort3(OpBuilder &builder, Location loc, uint64_t nx,
-                        uint64_t ny, bool isCoo,
-                        SmallVectorImpl<Value> &swapOperands,
+static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm,
+                        uint64_t ny, SmallVectorImpl<Value> &swapOperands,
                         SmallVectorImpl<Value> &compareOperands, Value v0,
                        Value v1, Value v2) {
   // Sort the first 2 elements.
-  scf::IfOp ifOp1 = createCompareThenSwap(
-      builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v0, v1);
+  scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
+                                          compareOperands, v0, v1);
   builder.setInsertionPointAfter(ifOp1);
 
   // Insert the 3th element.
-  createInsert3rd(builder, loc, nx, ny, isCoo, swapOperands, compareOperands,
-                  v0, v1, v2);
+  createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
+                  v1, v2);
 }
 
 /// Creates code to sort 5 elements.
-static void createSort5(OpBuilder &builder, Location loc, uint64_t nx,
-                        uint64_t ny, bool isCoo,
-                        SmallVectorImpl<Value> &swapOperands,
+static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
+                        uint64_t ny, SmallVectorImpl<Value> &swapOperands,
                         SmallVectorImpl<Value> &compareOperands, Value v0,
                         Value v1, Value v2, Value v3, Value v4) {
   // Sort the first 3 elements.
-  createSort3(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v0,
-              v1, v2);
+  createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1,
+              v2);
 
   auto insert4th = [&]() {
     scf::IfOp ifOp = createCompareThenSwap(
-        builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v2, v3);
-    createInsert3rd(builder, loc, nx, ny, isCoo, swapOperands, compareOperands,
-                    v0, v1, v2);
+        builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
+    createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
+                    v1, v2);
     builder.setInsertionPointAfter(ifOp);
   };
 
@@ -506,8 +497,8 @@ static void createSort5(OpBuilder &builder, Location loc, uint64_t nx,
   insert4th();
 
   // Insert the 5th element.
-  scf::IfOp ifOp = createCompareThenSwap(builder, loc, nx, ny, isCoo,
-                                         swapOperands, compareOperands, v3, v4);
+  scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
+                                         compareOperands, v3, v4);
   insert4th();
   builder.setInsertionPointAfter(ifOp);
 }
 
@@ -517,11 +508,10 @@ static void createSort5(OpBuilder &builder, Location loc, uint64_t nx,
 /// the number of values in range [lo, hi) is more than a threshold, we also
 /// include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
 static void createChoosePivot(OpBuilder &builder, ModuleOp module,
-                              func::FuncOp func, uint64_t nx, uint64_t ny,
-                              bool isCoo, Value lo, Value hi, Value mi,
-                              ValueRange args) {
+                              func::FuncOp func, AffineMap xPerm, uint64_t ny,
+                              Value lo, Value hi, Value mi, ValueRange args) {
   SmallVector<Value> compareOperands{mi, lo};
-  uint64_t numXBuffers = isCoo ? 1 : nx;
+  constexpr uint64_t numXBuffers = 1;
   compareOperands.append(args.begin() + xStartIdx,
                          args.begin() + xStartIdx + numXBuffers);
   SmallVector<Value> swapOperands{mi, lo};
@@ -537,8 +527,8 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
 
   // When len < 1000, choose pivot from median of 3 values.
   builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
-  createSort3(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, lo,
-              mi, hi);
+  createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi,
+              hi);
 
   // When len >= 1000, choose pivot from median of 5 values.
   builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
@@ -549,8 +539,8 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
   Value b = builder.create<arith::AddIOp>(loc, mi, hiP1);
   // Value b is the middle between [mi, hi].
   b = builder.create<arith::ShRUIOp>(loc, b, c1);
-  createSort5(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, lo, a,
-              mi, b, hi);
+  createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi,
+              b, hi);
 
   builder.setInsertionPointAfter(lenIf);
 }
@@ -586,8 +576,8 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
 //   }
 // }
 static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
-                                func::FuncOp func, uint64_t nx, uint64_t ny,
-                                bool isCoo, uint32_t nTrailingP = 0) {
+                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
+                                uint32_t nTrailingP = 0) {
   // Quick sort partition doesn't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
@@ -606,7 +596,7 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
   Value i = lo;
   Value j = builder.create<arith::SubIOp>(loc, hi, c1);
-  createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args);
+  createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args);
   Value trueVal = constantI1(builder, loc, true); // The value for while (true)
   SmallVector<Value> operands{i, j, p, trueVal};  // Exactly four values.
   SmallVector<Type> types{i.getType(), j.getType(), p.getType(),
@@ -628,14 +618,14 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
   j = after->getArgument(1);
   p = after->getArgument(2);
 
-  uint64_t numXBuffers = isCoo ? 1 : nx;
+  constexpr uint64_t numXBuffers = 1;
   auto [iresult, iCompareEq] =
       createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
-                     i, p, nx, ny, isCoo, 1);
+                     i, p, xPerm, ny, 1);
   i = iresult;
   auto [jresult, jCompareEq] =
       createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
-                     j, p, nx, ny, isCoo, -1);
+                     j, p, xPerm, ny, -1);
   j = jresult;
 
   // If i < j:
@@ -645,7 +635,7 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
   SmallVector<Value> swapOperands{i, j};
   swapOperands.append(args.begin() + xStartIdx, args.end());
-  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+  createSwap(builder, loc, swapOperands, xPerm, ny);
   // If the pivot is moved, update p with the new pivot.
   Value icond =
       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
@@ -737,8 +727,8 @@ static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
 //   }
 //
 static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
-                                func::FuncOp func, uint64_t nx, uint64_t ny,
-                                bool isCoo, uint32_t nTrailingP) {
+                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
+                                uint32_t nTrailingP) {
   // The value n is passed in as a trailing parameter.
   assert(nTrailingP == 1);
   OpBuilder::InsertionGuard insertionGuard(builder);
@@ -768,7 +758,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
   builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
   Value c1 = constantIndex(builder, loc, 1);
   SmallVector<Value> compareOperands{start, start};
-  uint64_t numXBuffers = isCoo ? 1 : nx;
+  constexpr uint64_t numXBuffers = 1;
   compareOperands.append(args.begin() + xStartIdx,
                          args.begin() + xStartIdx + numXBuffers);
 
@@ -794,7 +784,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
     compareOperands[0] = lChildIdx;
     compareOperands[1] = rChildIdx;
     Value cond2 =
-        createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
+        createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
     scf::IfOp if2 =
         builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
     builder.setInsertionPointToStart(&if2.getThenRegion().front());
@@ -825,8 +815,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
   childIdx = before->getArgument(2);
   compareOperands[0] = start;
   compareOperands[1] = childIdx;
-  Value cond =
-      createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
+  Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
 
   // The after-region of the WhileOp.
@@ -836,7 +825,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
   childIdx = after->getArgument(2);
   SmallVector<Value> swapOperands{start, childIdx};
   swapOperands.append(args.begin() + xStartIdx, args.end());
-  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+  createSwap(builder, loc, swapOperands, xPerm, ny);
   start = childIdx;
   Value cond2 =
       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
@@ -869,8 +858,8 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
 //     shiftdown(lo, lo, l-1)
 //   }
 static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
-                               func::FuncOp func, uint64_t nx, uint64_t ny,
-                               bool isCoo, uint32_t nTrailingP) {
+                               func::FuncOp func, AffineMap xPerm, uint64_t ny,
+                               uint32_t nTrailingP) {
   // Heap sort function doesn't have trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
 
@@ -897,7 +886,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
   shiftDownOperands.append(args.begin() + xStartIdx, args.end());
   shiftDownOperands.push_back(n);
   FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
-      builder, func, TypeRange(), kShiftDownFuncNamePrefix, nx, ny, isCoo,
+      builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny,
       shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
   builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
                                shiftDownOperands);
@@ -912,7 +901,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
   loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
   SmallVector<Value> swapOperands{lo, loplm1};
   swapOperands.append(args.begin() + xStartIdx, args.end());
-  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+  createSwap(builder, loc, swapOperands, xPerm, ny);
   shiftDownOperands[1] = lo;
   shiftDownOperands[shiftDownOperands.size() - 1] =
       builder.create<arith::SubIOp>(loc, l, c1);
@@ -928,7 +917,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
 /// the bigger partition to be processed by the enclosed while-loop.
 static std::pair<Value, Value>
 createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
-                ValueRange args, uint64_t nx, uint64_t ny, bool isCoo,
+                ValueRange args, AffineMap xPerm, uint64_t ny,
                 uint32_t nTrailingP) {
   MLIRContext *context = module.getContext();
   Location loc = func.getLoc();
@@ -937,8 +926,8 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
   SmallVector<Type> types(2, lo.getType()); // Only two types.
 
   FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
-      builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
-      ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
+      builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm,
+      ny, args.drop_back(nTrailingP), createPartitionFunc);
   Value p = builder
                 .create<func::CallOp>(loc, partitionFunc,
                                       TypeRange{IndexType::get(context)},
@@ -1008,8 +997,8 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
 //   }
 // }
 static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
-                                 func::FuncOp func, uint64_t nx, uint64_t ny,
-                                 bool isCoo, uint32_t nTrailingP) {
+                                 func::FuncOp func, AffineMap xPerm,
+                                 uint64_t ny, uint32_t nTrailingP) {
   // Stable sort function doesn't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
@@ -1034,8 +1023,8 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
   SmallVector<Value> operands{lo, i};
   operands.append(args.begin() + xStartIdx, args.end());
   FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
-      builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, nx,
-      ny, isCoo, operands, createBinarySearchFunc);
+      builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
+      xPerm, ny, operands, createBinarySearchFunc);
   Value p = builder
                 .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
                                       operands)
@@ -1045,7 +1034,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
   operands[0] = operands[1] = i;
   SmallVector<Value> d;
   forEachIJPairInAllBuffers(
-      builder, loc, operands, nx, ny, isCoo,
+      builder, loc, operands, xPerm, ny,
       [&](uint64_t unused, Value i, Value unused2, Value buffer) {
         d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
       });
@@ -1061,7 +1050,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
     operands[1] = imj;
     operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
     forEachIJPairInAllBuffers(
-        builder, loc, operands, nx, ny, isCoo,
+        builder, loc, operands, xPerm, ny,
        [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
          Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
          builder.create<memref::StoreOp>(loc, t, buffer, imj);
@@ -1071,7 +1060,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
   builder.setInsertionPointAfter(forOpJ);
   operands[0] = operands[1] = p;
   forEachIJPairInAllBuffers(
-      builder, loc, operands, nx, ny, isCoo,
+      builder, loc, operands, xPerm, ny,
      [&](uint64_t k, Value p, Value usused, Value buffer) {
        builder.create<memref::StoreOp>(loc, d[k], buffer, p);
      });
@@ -1123,8 +1112,8 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
 //   }
 //
 static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
-                                func::FuncOp func, uint64_t nx, uint64_t ny,
-                                bool isCoo, uint32_t nTrailingP) {
+                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
+                                uint32_t nTrailingP) {
   assert(nTrailingP == 1 || nTrailingP == 0);
   bool isHybrid = (nTrailingP == 1);
   OpBuilder::InsertionGuard insertionGuard(builder);
@@ -1173,7 +1162,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
     // When len <= limit.
     builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
     FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
-        builder, func, TypeRange(), kSortStableFuncNamePrefix, nx, ny, isCoo,
+        builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny,
        ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
     builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
                                  ValueRange(args).drop_back(nTrailingP));
@@ -1193,7 +1182,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
       // When depth exceeds limit.
       builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
       FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
-          builder, func, TypeRange(), kHeapSortFuncNamePrefix, nx, ny, isCoo,
+          builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny,
          ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
       builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
                                    ValueRange(args).drop_back(nTrailingP));
@@ -1203,7 +1192,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
       builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
       args.back() = depthLimit;
       std::tie(lo, hi) =
-          createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
+          createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
       builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
 
       builder.setInsertionPointAfter(depthIf);
@@ -1216,7 +1205,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
     hi = lenIf.getResult(1);
   } else {
     std::tie(lo, hi) =
-        createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
+        createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
   }
 
   // New [lo, hi) for the next while-loop iteration.
@@ -1229,9 +1218,8 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
 
 /// Implements the rewriting for operator sort and sort_coo.
 template <typename OpTy>
-LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
-                                    uint64_t ny, bool isCoo,
-                                    PatternRewriter &rewriter) {
+LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
+                                    uint64_t ny, PatternRewriter &rewriter) {
   Location loc = op.getLoc();
   SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
 
@@ -1285,8 +1273,8 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
   }
 
   FlatSymbolRefAttr func =
-      getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
-                               ny, isCoo, operands, funcGenerator, nTrailingP);
+      getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
+                               xPerm, ny, operands, funcGenerator, nTrailingP);
   rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
   return success();
 }
@@ -1296,7 +1284,6 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
 //===---------------------------------------------------------------------===//
 
 namespace {
-
 /// Sparse rewriting rule for the push_back operator.
 struct PushBackRewriter : OpRewritePattern<PushBackOp> {
 public:
   using OpRewritePattern<PushBackOp>::OpRewritePattern;
@@ -1410,20 +1397,6 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
   bool enableBufferInitialization;
 };
 
-/// Sparse rewriting rule for the sort operator.
-struct SortRewriter : public OpRewritePattern<SortOp> {
-public:
-  using OpRewritePattern<SortOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(SortOp op,
-                                PatternRewriter &rewriter) const override {
-    SmallVector<Value> xys(op.getXs());
-    xys.append(op.getYs().begin(), op.getYs().end());
-    return matchAndRewriteSortOp(op, xys, op.getXs().size(), /*ny=*/0,
-                                 /*isCoo=*/false, rewriter);
-  }
-};
-
 /// Sparse rewriting rule for the sort_coo operator.
 struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
 public:
   using OpRewritePattern<SortCooOp>::OpRewritePattern;
 
@@ -1434,16 +1407,13 @@ struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
     SmallVector<Value> xys;
     xys.push_back(op.getXy());
     xys.append(op.getYs().begin(), op.getYs().end());
-    uint64_t nx = 1;
-    if (auto nxAttr = op.getNxAttr())
-      nx = nxAttr.getInt();
+    auto xPerm = op.getPermMap();
 
     uint64_t ny = 0;
     if (auto nyAttr = op.getNyAttr())
       ny = nyAttr.getInt();
 
-    return matchAndRewriteSortOp(op, xys, nx, ny,
-                                 /*isCoo=*/true, rewriter);
+    return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter);
   }
 };
 
@@ -1457,5 +1427,5 @@ void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
                                          bool enableBufferInitialization) {
   patterns.add<PushBackRewriter>(patterns.getContext(),
                                  enableBufferInitialization);
-  patterns.add<SortRewriter, SortCooRewriter>(patterns.getContext());
+  patterns.add<SortCooRewriter>(patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 557c5c471c4a7..4419c39c69927 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -890,8 +890,9 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
     // If the innermost level is ordered, we need to sort the coordinates
     // in the "added" array prior to applying the compression.
     if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
-      rewriter.create<SortOp>(loc, count, ValueRange{added}, ValueRange{},
-                              SparseTensorSortKind::HybridQuickSort);
+      rewriter.create<SortCooOp>(
+          loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
+          rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
     // While performing the insertions, we also need to reset the elements
     // of the values/filled-switch by only iterating over the set elements,
     // to ensure that the runtime complexity remains proportional to the
@@ -1486,9 +1487,10 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
       scf::IfOp ifOp =
          rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
-      rewriter.create<SortCooOp>(
-          loc, nse, xs, ValueRange{ys}, rewriter.getIndexAttr(lvlRank),
-          rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
+      auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
+      rewriter.create<SortCooOp>(loc, nse, xs, ValueRange{ys}, xPerm,
+                                 rewriter.getIndexAttr(0),
+                                 SparseTensorSortKind::HybridQuickSort);
       rewriter.setInsertionPointAfter(ifOp);
     }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index ca7d8a7850b0b..7d2f0c7f139cd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -207,7 +207,6 @@ struct SparseTensorCodegenPass
     ConversionTarget target(*ctx);
     // Most ops in the sparse dialect must go!
     target.addIllegalDialect<SparseTensorDialect>();
-    target.addLegalOp<SortOp>();
     target.addLegalOp<SortCooOp>();
     target.addLegalOp<PushBackOp>();
     // Storage specifier outlives sparse tensor pipeline.
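Editorial aside: for reference, the identity permutations the converters above attach via `getMultiDimIdentityMap` print as plain identity affine maps (illustrative only):

```mlir
#id1 = affine_map<(d0) -> (d0)>                  // lvlRank = 1 (compress case)
#id3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>  // e.g. lvlRank = 3 (new-op case)
```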
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 47f7dad08c8c9..277903dc55b74 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1206,29 +1206,23 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
       // Retrieve the values-array.
       Value y = genToValues(rewriter, loc, src);
       const auto encSrc = srcTp.getEncoding();
-      // Sort the COO tensor so that its elements are ordered via increasing
-      // coordinates for the storage ordering of the dst tensor. Use SortCoo
-      // if the COO tensor has the same ordering as the dst tensor.
-      if (dimRank > 1 && srcTp.hasSameDimToLvl(dstTp)) {
-        Value xs = genToCoordinatesBuffer(rewriter, loc, src);
-        rewriter.create<SortCooOp>(
-            loc, nnz, xs, ValueRange{y}, rewriter.getIndexAttr(dimRank),
-            rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
-      } else {
-        // Gather the coordinates-arrays in the dst tensor storage order.
-        SmallVector<Value> xs(dstLvlRank);
-        const Level srcLvlRank = srcTp.getLvlRank();
-        for (Level srcLvl = 0; srcLvl < srcLvlRank; srcLvl++) {
-          // FIXME: `toOrigDim` is deprecated
-          Dimension dim = toOrigDim(encSrc, srcLvl);
-          // FIXME: `toStoredDim` is deprecated
-          Level dstLvl = toStoredDim(encDst, dim);
-          xs[dstLvl] =
-              genToCoordinates(rewriter, loc, src, srcLvl, /*cooStart=*/0);
-        }
-        rewriter.create<SortOp>(loc, nnz, xs, ValueRange{y},
-                                SparseTensorSortKind::HybridQuickSort);
+      // Builds the dstLvl -> srcLvl permutation map.
+      SmallVector<AffineExpr> es(dstLvlRank);
+      const Level srcLvlRank = srcTp.getLvlRank();
+      for (Level srcLvl = 0; srcLvl < srcLvlRank; srcLvl++) {
+        // FIXME: `toOrigDim` is deprecated
+        Dimension dim = toOrigDim(encSrc, srcLvl);
+        // FIXME: `toStoredDim` is deprecated
+        Level dstLvl = toStoredDim(encDst, dim);
+        es[dstLvl] = rewriter.getAffineDimExpr(srcLvl);
       }
+      auto xPerm = AffineMap::get(dstLvlRank, 0, es, rewriter.getContext());
+      assert(xPerm.isPermutation()); // must be a permutation.
+
+      Value xs = genToCoordinatesBuffer(rewriter, loc, src);
+      rewriter.create<SortCooOp>(loc, nnz, xs, ValueRange{y}, xPerm,
+                                 rewriter.getIndexAttr(0),
+                                 SparseTensorSortKind::HybridQuickSort);
     }
 
     // For each element in the COO tensor, insert the element to the dst tensor.
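Editorial aside: a worked instance of the loop above, under assumed encodings (not taken from the tests). If the source stores level 0 before level 1 while the destination orders level 1 before level 0 (CSR-to-CSC style), then srcLvl 0 lands at dstLvl 1 and srcLvl 1 at dstLvl 0, so `es = [d1, d0]` and the resulting sort key is:

```mlir
#xPerm = affine_map<(d0, d1) -> (d1, d0)>
```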
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index 0036bd5c3310b..c96a55aa1e8b2 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -75,123 +75,64 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
 
 // -----
 
-// CHECK-LABEL: func.func private @_sparse_partition_1_i8_f32_index
-// CHECK-LABEL: func.func private @_sparse_qsort_1_i8_f32_index
-// CHECK-LABEL: func.func @sparse_sort_1d2v_quick
-func.func @sparse_sort_1d2v_quick(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
-   -> (memref<10xi8>, memref<?xf32>, memref<10xindex>) {
-  sparse_tensor.sort quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
-  return %arg1, %arg2, %arg3 : memref<10xi8>, memref<?xf32>, memref<10xindex>
-}
-
-// -----
-
-// Only check the generated supporting function now. We have integration test
-// to verify correctness of the generated code.
-//
-// CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG: func.func private @_sparse_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-LABEL: func.func @sparse_sort_3d_quick
-func.func @sparse_sort_3d_quick(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
-  sparse_tensor.sort quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-  return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-}
-
-// -----
-
-// Only check the generated supporting function now. We have integration test
-// to verify correctness of the generated code.
-//
-// CHECK-DAG: func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG: func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-DAG: func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: index) {
-// CHECK-DAG: func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG: func.func private @_sparse_hybrid_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: i64) {
-// CHECK-LABEL: func.func @sparse_sort_3d_hybrid
-func.func @sparse_sort_3d_hybrid(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
-  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-  return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-}
-
-// -----
-
-// Only check the generated supporting functions. We have integration test to
-// verify correctness of the generated code.
-//
-// CHECK-DAG: func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG: func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-LABEL: func.func @sparse_sort_3d_stable
-func.func @sparse_sort_3d_stable(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
-  sparse_tensor.sort insertion_sort_stable %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-  return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-}
-
-// -----
+#ID_MAP=affine_map<(d0, d1) -> (d0, d1)>
 
 // Only check the generated supporting functions. We have integration test to
 // verify correctness of the generated code.
 //
-// CHECK-DAG: func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: index) {
-// CHECK-DAG: func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-LABEL: func.func @sparse_sort_3d_heap
-func.func @sparse_sort_3d_heap(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
-  sparse_tensor.sort heap_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-  return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-}
-
-// -----
-
-// Only check the generated supporting functions. We have integration test to
-// verify correctness of the generated code.
-//
-// CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG: func.func private @_sparse_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG: func.func private @_sparse_partition_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG: func.func private @_sparse_qsort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
 // CHECK-LABEL: func.func @sparse_sort_coo_quick
 func.func @sparse_sort_coo_quick(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
-  sparse_tensor.sort_coo quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  sparse_tensor.sort_coo quick_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }
 
 // -----
 
+#ID_MAP=affine_map<(d0, d1) -> (d0, d1)>
+
 // Only check the generated supporting functions. We have integration test to
 // verify correctness of the generated code.
 //
-// CHECK-DAG: func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG: func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
-// CHECK-DAG: func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
-// CHECK-DAG: func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
-// CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG: func.func private @_sparse_hybrid_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: i64) {
+// CHECK-DAG: func.func private @_sparse_binary_search_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG: func.func private @_sparse_sort_stable_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG: func.func private @_sparse_shift_down_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
+// CHECK-DAG: func.func private @_sparse_heap_sort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG: func.func private @_sparse_partition_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG: func.func private @_sparse_hybrid_qsort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: i64) {
 // CHECK-LABEL: func.func @sparse_sort_coo_hybrid
 func.func @sparse_sort_coo_hybrid(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
-  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }
 
 // -----
 
+#ID_MAP=affine_map<(d0, d1) -> (d0, d1)>
+
 // Only check the generated supporting functions. We have integration test to
 // verify correctness of the generated code.
 //
-// CHECK-DAG: func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG: func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG: func.func private @_sparse_binary_search_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG: func.func private @_sparse_sort_stable_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
 // CHECK-LABEL: func.func @sparse_sort_coo_stable
 func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
-  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }
 
 // -----
 
+#ID_MAP=affine_map<(d0, d1) -> (d0, d1)>
+
 // Only check the generated supporting functions. We have integration test to
 // verify correctness of the generated code.
 //
-// CHECK-DAG: func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
-// CHECK-DAG: func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG: func.func private @_sparse_shift_down_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
+// CHECK-DAG: func.func private @_sparse_heap_sort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
 // CHECK-LABEL: func.func @sparse_sort_coo_heap
 func.func @sparse_sort_coo_heap(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
-  sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index f1317f23d6568..ea11a98b76ec6 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -436,7 +436,7 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
 // CHECK-DAG:  %[[A9:.*]] = arith.constant 0.000000e+00 : f64
 // CHECK-DAG:  %[[A10:.*]] = arith.constant 1 : index
 // CHECK-DAG:  %[[A11:.*]] = arith.constant 0 : index
-// CHECK:      sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]] : memref<?xindex>
+// CHECK:      sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
 // CHECK:      %[[A12:.*]]:4 = scf.for %[[A13:.*]] = %[[A11]] to %[[A7]] step %[[A10]] iter_args(%[[A14:.*]] = %[[A0]], %[[A15:.*]] = %[[A1]], %[[A16:.*]] = %[[A2]], %[[A17:.*]] = %[[A3]])
 // CHECK:        %[[A18:.*]] = memref.load %[[A6]]{{\[}}%[[A13]]] : memref<?xindex>
 // CHECK:        %[[A19:.*]] = memref.load %[[A4]]{{\[}}%[[A18]]] : memref<?xi1>
@@ -484,7 +484,7 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
 // CHECK:     %[[A11:.*]] = arith.constant 0.000000e+00 : f64
 // CHECK:     %[[A12:.*]] = arith.constant 1 : index
 // CHECK:     %[[A13:.*]] = arith.constant 0 : index
-// CHECK:     sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]] : memref<?xindex>
+// CHECK:     sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
 // CHECK:     %[[A14:.*]]:4 = scf.for %[[A15:.*]] = %[[A13]] to %[[A7]] step %[[A12]] iter_args(%[[A16:.*]] = %[[A0]], %[[A17:.*]] = %[[A1]], %[[A18:.*]] = %[[A2]], %[[A19:.*]] = %[[A3]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 // CHECK:       %[[A20:.*]] = memref.load %[[A6]]{{\[}}%[[A15]]] : memref<?xindex>
 // CHECK:       %[[A21:.*]] = memref.load %[[A4]]{{\[}}%[[A20]]] : memref<?xi1>
@@ -712,7 +712,7 @@ func.func @sparse_convert_element_type(%arg0: tensor<32xf32, #SparseVector>) ->
 // CHECK:       %[[A33:.*]] = call @getSparseTensorReaderReadToBuffers0F32(%[[A5]], %[[A32]], %[[A14]], %[[A15]])
 // CHECK:       %[[A34:.*]] = arith.cmpi eq, %[[A33]], %[[A1]] : i1
 // CHECK:       scf.if %[[A34]] {
-// CHECK:         sparse_tensor.sort_coo hybrid_quick_sort %[[A10]], %[[A14]] jointly %[[A15]] {nx = 2 : index, ny = 0 : index} : memref<?xindex> jointly memref<?xf32>
+// CHECK:         sparse_tensor.sort_coo hybrid_quick_sort %[[A10]], %[[A14]] jointly %[[A15]] {ny = 0 : index, perm_map = #{{.*}}} : memref<?xindex> jointly memref<?xf32>
 // CHECK:       }
 // CHECK:       memref.store %[[A10]], %[[A27]]{{\[}}%[[A2]]] : memref<?xindex>
 // CHECK:       %[[A36:.*]] = sparse_tensor.storage_specifier.set %[[A30]]  crd_mem_sz at 0 with %[[A11]]
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
index b3eb50f1755da..54cdfc690952d 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -178,7 +178,7 @@ func.func @sparse_convert_singleton(%arg0: tensor<?xf32, #SparseSingleton64>) ->
 // CHECK-RWT:   %[[VAL_16:.*]] = sparse_tensor.load %[[VAL_17:.*]] hasInserts : tensor<?x?x?xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK-RWT:   %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_16]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xf32>
 // CHECK-RWT:   %[[VAL_19:.*]] = sparse_tensor.coordinates_buffer %[[VAL_16]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK-RWT:   sparse_tensor.sort_coo hybrid_quick_sort %[[VAL_7]], %[[VAL_19]] jointly %[[VAL_18]] {nx = 3 : index, ny = 0 : index}
+// CHECK-RWT:   sparse_tensor.sort_coo hybrid_quick_sort %[[VAL_7]], %[[VAL_19]] jointly %[[VAL_18]] {ny = 0 : index, perm_map = #map}
 // CHECK-RWT:   %[[VAL_20:.*]] = bufferization.alloc_tensor(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]) size_hint=%[[VAL_7]]
 // CHECK-RWT:   %[[VAL_21:.*]] = sparse_tensor.foreach in %[[VAL_16]] init(%[[VAL_20]])
 // CHECK-RWT:   ^bb0(%[[VAL_22:.*]]: index, %[[VAL_23:.*]]: index, %[[VAL_24:.*]]: index, %[[VAL_25:.*]]: f32, %[[VAL_26:.*]]: tensor<?x?x?xf32, #sparse_tensor.encoding<{{.*}}>>):
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 71e6eebb30261..c0e813dcde7c5 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -790,60 +790,51 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> (
   return
 }
 
-// -----
-
-// TODO: a test case with empty xs doesn't work due to some parser issues.
-
-func.func @sparse_sort_x_type( %arg0: index, %arg1: memref<?xf32>) {
-  // expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
-  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1: memref<?xf32>
-}
-
-// -----
-
-func.func @sparse_sort_dim_too_small(%arg0: memref<10xindex>) {
-  %i20 = arith.constant 20 : index
-  // expected-error@+1 {{xs and ys need to have a dimension >= n: 10 < 20}}
-  sparse_tensor.sort insertion_sort_stable %i20, %arg0 : memref<10xindex>
-  return
-}
 
 // -----
 
-func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<10xi8>) {
-  // expected-error@+1 {{mismatch xs element types}}
-  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2 : memref<10xindex>, memref<10xi8>
-  return
-}
-
-// -----
+#MAP = affine_map<(i,j) -> (i,j)>
 
 func.func @sparse_sort_coo_x_type( %arg0: index, %arg1: memref<?xf32>) {
   // expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
-  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1: memref<?xf32>
+  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 {perm_map = #MAP} : memref<?xf32>
   return
 }
 
 // -----
 
+#MAP = affine_map<(i,j) -> (i,j)>
+
 func.func @sparse_sort_coo_x_too_small(%arg0: memref<50xindex>) {
   %i20 = arith.constant 20 : index
-  // expected-error@+1 {{Expected dimension(xy) >= n * (nx + ny) got 50 < 60}}
-  sparse_tensor.sort_coo hybrid_quick_sort %i20, %arg0 {nx = 2 : index, ny = 1 : index} : memref<50xindex>
+  // expected-error@+1 {{Expected dimension(xy) >= n * (rank(perm_map) + ny) got 50 < 60}}
+  sparse_tensor.sort_coo hybrid_quick_sort %i20, %arg0 {perm_map = #MAP, ny = 1 : index} : memref<50xindex>
   return
 }
 
 // -----
 
+#MAP = affine_map<(i,j) -> (i,j)>
+
 func.func @sparse_sort_coo_y_too_small(%arg0: memref<60xindex>, %arg1: memref<10xf32>) {
   %i20 = arith.constant 20 : index
   // expected-error@+1 {{Expected dimension(y) >= n got 10 < 20}}
-  sparse_tensor.sort_coo insertion_sort_stable %i20, %arg0 jointly %arg1 {nx = 2 : index, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
+  sparse_tensor.sort_coo insertion_sort_stable %i20, %arg0 jointly %arg1 {perm_map = #MAP, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
   return
 }
 
 // -----
 
+#NON_PERM_MAP = affine_map<(i,j) -> (i,i)>
+
+func.func @sparse_sort_coo_no_perm(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
+  // expected-error@+1 {{Expected a permutation map, got (d0, d1) -> (d0, d0)}}
+  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {perm_map = #NON_PERM_MAP, ny = 1 : index}: memref<?xindex>
+  return %arg1 : memref<?xindex>
+}
+
+// -----
+
 #CSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d1 : compressed)}>
 
 func.func @sparse_alloc_escapes(%arg0: index) -> tensor<10x?xf64, #CSR> {
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index d1262cb7aea02..d252fa559a154 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -612,79 +612,29 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> (
 
 // -----
 
-// CHECK-LABEL: func @sparse_sort_1d0v(
-// CHECK-SAME: %[[A:.*]]: index,
-// CHECK-SAME: %[[B:.*]]: memref<?xindex>)
-// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]] : memref<?xindex>
-// CHECK: return %[[B]]
-func.func @sparse_sort_1d0v(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
-  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 : memref<?xindex>
-  return %arg1 : memref<?xindex>
-}
-
-// -----
-
-// CHECK-LABEL: func @sparse_sort_1d2v(
-// CHECK-SAME: %[[A:.*]]: index,
-// CHECK-SAME: %[[B:.*]]: memref<20xindex>,
-// CHECK-SAME: %[[C:.*]]: memref<10xindex>,
-// CHECK-SAME: %[[D:.*]]: memref<?xf32>)
-// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]] jointly %[[C]], %[[D]] : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
-// CHECK: return %[[B]], %[[C]], %[[D]]
-func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<20xindex>, %arg2: memref<10xindex>, %arg3: memref<?xf32>) -> (memref<20xindex>, memref<10xindex>, memref<?xf32>) {
-  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
-  return %arg1, %arg2, %arg3 : memref<20xindex>, memref<10xindex>, memref<?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @sparse_sort_2d1v(
-// CHECK-SAME: %[[A:.*]]: index,
-// CHECK-SAME: %[[B:.*]]: memref<10xi8>,
-// CHECK-SAME: %[[C:.*]]: memref<20xi8>,
-// CHECK-SAME: %[[D:.*]]: memref<10xf64>)
-// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
-// CHECK: return %[[B]], %[[C]], %[[D]]
-func.func @sparse_sort_2d1v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
-  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
-  return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
-}
-
-// -----
-
-// CHECK-LABEL: func @sparse_sort_stable(
-// CHECK-SAME: %[[A:.*]]: index,
-// CHECK-SAME: %[[B:.*]]: memref<10xi8>,
-// CHECK-SAME: %[[C:.*]]: memref<20xi8>,
-// CHECK-SAME: %[[D:.*]]: memref<10xf64>)
-// CHECK: sparse_tensor.sort insertion_sort_stable %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
-// CHECK: return %[[B]], %[[C]], %[[D]]
-func.func @sparse_sort_stable(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
-  sparse_tensor.sort insertion_sort_stable %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
-  return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
-}
-
-// -----
+#ID_MAP = affine_map<(i,j) -> (i,j)>
 
 // CHECK-LABEL: func @sparse_sort_coo(
 // CHECK-SAME: %[[A:.*]]: index,
 // CHECK-SAME: %[[B:.*]]: memref<?xindex>)
-// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A]], %[[B]] {nx = 2 : index, ny = 1 : index} : memref<?xindex>
+// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A]], %[[B]] {ny = 1 : index, perm_map = #{{.*}}} : memref<?xindex>
 // CHECK: return %[[B]]
 func.func @sparse_sort_coo(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
-  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {nx = 2 : index, ny = 1 : index}: memref<?xindex>
+  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {perm_map = #ID_MAP, ny = 1 : index}: memref<?xindex>
   return %arg1 : memref<?xindex>
 }
 
 // -----
 
+#ID_MAP = affine_map<(i,j) -> (i,j)>
+
 // CHECK-LABEL: func @sparse_sort_coo_stable(
 // CHECK-SAME: %[[A:.*]]: index,
 // CHECK-SAME: %[[B:.*]]: memref<?xi64>,
 // CHECK-SAME: %[[C:.*]]: memref<?xf32>)
-// CHECK: sparse_tensor.sort_coo insertion_sort_stable %[[A]], %[[B]] jointly %[[C]] {nx = 2 : index, ny = 1 : index}
+// CHECK: sparse_tensor.sort_coo insertion_sort_stable %[[A]], %[[B]] jointly %[[C]] {ny = 1 : index, perm_map = #{{.*}}}
 // CHECK: return %[[B]], %[[C]]
 func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<?xi64>, %arg2: memref<?xf32>) -> (memref<?xi64>, memref<?xf32>) {
-  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2 {nx = 2 : index, ny = 1 : index}: memref<?xi64> jointly memref<?xf32>
+  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2 {perm_map = #ID_MAP, ny = 1 : index}: memref<?xi64> jointly memref<?xf32>
   return %arg1, %arg2 : memref<?xi64>, memref<?xf32>
 }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
index b31ac3ef3a254..5c308dc3c5623 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -116,7 +116,7 @@
 // CHECK: } {"Emitted from" = "linalg.generic"}
 // CHECK: scf.yield %[[VAL_64:.*]] : index
 // CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: sparse_tensor.sort hybrid_quick_sort %[[VAL_65:.*]], %[[VAL_33]] : memref<?xindex>
+// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[VAL_65:.*]], %[[VAL_33]]
 // CHECK: %[[VAL_66:.*]]:4 = scf.for %[[VAL_67:.*]] = %[[VAL_10]] to %[[VAL_65]] step %[[VAL_11]] iter_args(%[[VAL_68:.*]] = %[[VAL_36]], %[[VAL_69:.*]] = %[[VAL_37]], %[[VAL_70:.*]] = %[[VAL_38]], %[[VAL_71:.*]] = %[[VAL_39]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 // CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_67]]] : memref<4xindex>
 // CHECK: %[[VAL_73:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_72]]] : memref<4xf64>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
deleted file mode 100644
index 9e8ecad9cf282..0000000000000
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
+++ /dev/null
@@ -1,187 +0,0 @@
-//--------------------------------------------------------------------------------------------------
-// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
-//
-// Set-up that's shared across all tests in this directory. In principle, this
-// config could be moved to lit.local.cfg. However, there are downstream users that
-// do not use these LIT config files. Hence why this is kept inline.
-//
-// DEFINE: %{sparse_compiler_opts} = enable-runtime-library=true
-// DEFINE: %{sparse_compiler_opts_sve} = enable-arm-sve=true %{sparse_compiler_opts}
-// DEFINE: %{compile} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts}"
-// DEFINE: %{compile_sve} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts_sve}"
-// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
-// DEFINE: %{run_opts} = -e entry -entry-point-result=void
-// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
-// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
-//
-// DEFINE: %{env} =
-//--------------------------------------------------------------------------------------------------
-
-// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false
-// RUN: %{compile} | %{run} | FileCheck %s
-//
-// Do the same run, but now with vectorization.
-// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false vl=2 reassociate-fp-reductions=true enable-index-optimizations=true
-// RUN: %{compile} | %{run} | FileCheck %s
-//
-// Do the same run, but now with VLA vectorization.
-// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
-
-module {
-  func.func private @printMemref1dI32(%ptr : memref<?xi32>) attributes { llvm.emit_c_interface }
-
-  // Stores 5 values to the memref buffer.
-  func.func @storeValuesTo(%b: memref<?xi32>, %v0: i32, %v1: i32, %v2: i32,
-                           %v3: i32, %v4: i32) -> () {
-    %i0 = arith.constant 0 : index
-    %i1 = arith.constant 1 : index
-    %i2 = arith.constant 2 : index
-    %i3 = arith.constant 3 : index
-    %i4 = arith.constant 4 : index
-    memref.store %v0, %b[%i0] : memref<?xi32>
-    memref.store %v1, %b[%i1] : memref<?xi32>
-    memref.store %v2, %b[%i2] : memref<?xi32>
-    memref.store %v3, %b[%i3] : memref<?xi32>
-    memref.store %v4, %b[%i4] : memref<?xi32>
-    return
-  }
-
-  // The main driver.
-  func.func @entry() {
-    %c0 = arith.constant 0 : i32
-    %c1 = arith.constant 1 : i32
-    %c2 = arith.constant 2 : i32
-    %c3 = arith.constant 3 : i32
-    %c4 = arith.constant 4 : i32
-    %c5 = arith.constant 5 : i32
-    %c6 = arith.constant 6 : i32
-    %c7 = arith.constant 7 : i32
-    %c8 = arith.constant 8 : i32
-    %c9 = arith.constant 9 : i32
-    %c10 = arith.constant 10 : i32
-    %c100 = arith.constant 100 : i32
-
-    %i0 = arith.constant 0 : index
-    %i4 = arith.constant 4 : index
-    %i5 = arith.constant 5 : index
-
-    // Prepare a buffer.
-    %x0s = memref.alloc() : memref<5xi32>
-    %x0 = memref.cast %x0s : memref<5xi32> to memref<?xi32>
-    call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-
-    // Sort 0 elements.
-    // Quick sort.
-    // CHECK: [10, 2, 0, 5, 1]
-    sparse_tensor.sort quick_sort %i0, %x0 : memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-    // Stable sort.
-    // CHECK: [10, 2, 0, 5, 1]
-    sparse_tensor.sort insertion_sort_stable %i0, %x0 : memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-    // Heap sort.
-    // CHECK: [10, 2, 0, 5, 1]
-    sparse_tensor.sort heap_sort %i0, %x0 : memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-    // Hybrid sort.
-    // CHECK: [10, 2, 0, 5, 1]
-    sparse_tensor.sort hybrid_quick_sort %i0, %x0 : memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-
-    // Sort the first 4 elements, with the last valid value untouched.
-    // Quick sort.
-    // CHECK: [0, 2, 5, 10, 1]
-    sparse_tensor.sort quick_sort %i4, %x0 : memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-    // Stable sort.
-    // CHECK: [0, 2, 5, 10, 1]
-    call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort insertion_sort_stable %i4, %x0 : memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-    // Heap sort.
-    // CHECK: [0, 2, 5, 10, 1]
-    call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort heap_sort %i4, %x0 : memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-    // Hybrid sort.
-    // CHECK: [0, 2, 5, 10, 1]
-    sparse_tensor.sort hybrid_quick_sort %i4, %x0 : memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-
-    // Prepare more buffers of different dimensions.
-    %x1s = memref.alloc() : memref<10xi32>
-    %x1 = memref.cast %x1s : memref<10xi32> to memref<?xi32>
-    %x2s = memref.alloc() : memref<6xi32>
-    %x2 = memref.cast %x2s : memref<6xi32> to memref<?xi32>
-    %y0s = memref.alloc() : memref<7xi32>
-    %y0 = memref.cast %y0s : memref<7xi32> to memref<?xi32>
-
-    // Sort "parallel arrays".
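The parallel-array sorting exercised below is exactly the capability that `sort_coo` now subsumes: the three `xs` buffers and one `ys` buffer of this removed test become a single interleaved `xy` buffer plus a `perm_map`, as covered by sparse_rewrite_sort_coo.mlir further down.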
-    // CHECK: [1, 1, 2, 5, 10]
-    // CHECK: [3, 3, 1, 10, 1
-    // CHECK: [9, 9, 4, 7, 2
-    // CHECK: [7, 8, 10, 9, 6
-    call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort quick_sort %i5, %x0, %x1, %x2 jointly %y0
-      : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-    call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
-    call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
-    call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()
-    // Stable sort.
-    // CHECK: [1, 1, 2, 5, 10]
-    // CHECK: [3, 3, 1, 10, 1
-    // CHECK: [9, 9, 4, 7, 2
-    // CHECK: [8, 7, 10, 9, 6
-    call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort insertion_sort_stable %i5, %x0, %x1, %x2 jointly %y0
-      : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-    call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
-    call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
-    call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()
-    // Heap sort.
-    // CHECK: [1, 1, 2, 5, 10]
-    // CHECK: [3, 3, 1, 10, 1
-    // CHECK: [9, 9, 4, 7, 2
-    // CHECK: [7, 8, 10, 9, 6
-    call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort heap_sort %i5, %x0, %x1, %x2 jointly %y0
-      : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-    call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
-    call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
-    call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()
-
-    // Release the buffers.
-    memref.dealloc %x0 : memref<?xi32>
-    memref.dealloc %x1 : memref<?xi32>
-    memref.dealloc %x2 : memref<?xi32>
-    memref.dealloc %y0 : memref<?xi32>
-    return
-  }
-}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
index ca5dd00d02aff..394b9a8448b54 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
@@ -28,6 +28,8 @@
 // Do the same run, but now with VLA vectorization.
 // RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
 
+#ID_MAP = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+
 module {
   // Stores 5 values to the memref buffer.
   func.func @storeValuesTo(%b: memref<?xi32>, %v0: i32, %v1: i32, %v2: i32,
@@ -94,11 +96,11 @@ module {
     %y1 = memref.cast %y1s : memref<7xi32> to memref<?xi32>
 
     // Sort "parallel arrays".
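With the `#ID_MAP` defined above, each record's keys are compared as the permuted tuple `(x1, x2, x0)` instead of `(x0, x1, x2)`, so `x1` becomes the primary sort key; record 3 with `(x0, x1, x2) = (10, 5, 7)`, for instance, now sorts on the key `(5, 7, 10)`. That is why the expected vectors below are both printed in permuted buffer order and reordered relative to the old `nx`-based expectations.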
-    // CHECK: ( 1, 1, 3, 3, 10 )
-    // CHECK: ( 2, 10, 1, 1, 5 )
-    // CHECK: ( 4, 2, 9, 9, 7 )
-    // CHECK: ( 10, 6, 7, 8, 9 )
-    // CHECK: ( 7, 5, 7, 4, 9 )
+    // CHECK: ( 1, 1, 2, 5, 10 )
+    // CHECK: ( 9, 9, 4, 7, 2 )
+    // CHECK: ( 3, 3, 1, 10, 1 )
+    // CHECK: ( 7, 8, 10, 9, 6 )
+    // CHECK: ( 7, 4, 7, 9, 5 )
     call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
@@ -109,24 +111,25 @@ module {
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo quick_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+    sparse_tensor.sort_coo quick_sort %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
       : memref<?xi32> jointly memref<?xi32>
-    %x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
-    vector.print %x0v : vector<5xi32>
+    // Dump the buffers in the same order as perm_map so that the printed output is ordered.
     %x1v = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x1v : vector<5xi32>
     %x2v = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x2v : vector<5xi32>
+    %x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x0v : vector<5xi32>
     %y0v = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %y0v : vector<5xi32>
     %y1v = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
     vector.print %y1v : vector<5xi32>
 
     // Stable sort.
-    // CHECK: ( 1, 1, 3, 3, 10 )
-    // CHECK: ( 2, 10, 1, 1, 5 )
-    // CHECK: ( 4, 2, 9, 9, 7 )
-    // CHECK: ( 10, 6, 8, 7, 9 )
-    // CHECK: ( 7, 5, 4, 7, 9 )
+    // CHECK: ( 1, 1, 2, 5, 10 )
+    // CHECK: ( 9, 9, 4, 7, 2 )
+    // CHECK: ( 3, 3, 1, 10, 1 )
+    // CHECK: ( 8, 7, 10, 9, 6 )
+    // CHECK: ( 4, 7, 7, 9, 5 )
     call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
@@ -137,24 +140,24 @@ module {
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo insertion_sort_stable %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+    sparse_tensor.sort_coo insertion_sort_stable %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
      : memref<?xi32> jointly memref<?xi32>
-    %x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
-    vector.print %x0v2 : vector<5xi32>
     %x1v2 = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x1v2 : vector<5xi32>
     %x2v2 = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x2v2 : vector<5xi32>
+    %x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x0v2 : vector<5xi32>
     %y0v2 = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %y0v2 : vector<5xi32>
     %y1v2 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
     vector.print %y1v2 : vector<5xi32>
 
     // Heap sort.
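Two of the five records tie on the permuted key, and that tie is the only place where the algorithms may differ: `insertion_sort_stable` must keep tied records in input order, while `quick_sort` and `heap_sort` are free to swap them. This shows up only in the `y` columns of the expectations: `( 7, 8, ... )` and `( 7, 4, ... )` for the unstable sorts versus `( 8, 7, ... )` and `( 4, 7, ... )` for the stable run.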
-    // CHECK: ( 1, 1, 3, 3, 10 )
-    // CHECK: ( 2, 10, 1, 1, 5 )
-    // CHECK: ( 4, 2, 9, 9, 7 )
-    // CHECK: ( 10, 6, 8, 7, 9 )
-    // CHECK: ( 7, 5, 4, 7, 9 )
+    // CHECK: ( 1, 1, 2, 5, 10 )
+    // CHECK: ( 9, 9, 4, 7, 2 )
+    // CHECK: ( 3, 3, 1, 10, 1 )
+    // CHECK: ( 7, 8, 10, 9, 6 )
+    // CHECK: ( 7, 4, 7, 9, 5 )
     call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
@@ -165,14 +168,14 @@ module {
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo heap_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+    sparse_tensor.sort_coo heap_sort %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
      : memref<?xi32> jointly memref<?xi32>
-    %x0v3 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
-    vector.print %x0v3 : vector<5xi32>
     %x1v3 = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x1v3 : vector<5xi32>
     %x2v3 = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x2v3 : vector<5xi32>
+    %x0v3 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x0v3 : vector<5xi32>
     %y0v3 = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %y0v3 : vector<5xi32>
     %y1v3 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
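Taken together, the updated tests pin down the new contract: `xy` is a flat buffer of `n` records of `rank(perm_map) + ny` values each, and `perm_map` only permutes the keys before comparison. A self-contained sketch (function name and data are ours, chosen for illustration) using the identity map, under which `sort_coo` degenerates to a plain lexicographic sort of the records:

```mlir
#ID_MAP = affine_map<(i, j) -> (i, j)>

// %xy packs three (key0, key1) records with ny = 0:
//   [3, 7,  1, 9,  3, 5]  ==  (3, 7), (1, 9), (3, 5)
// Sorting n = 3 records in place rewrites the buffer to
//   [1, 9,  3, 5,  3, 7]  ==  (1, 9), (3, 5), (3, 7)
func.func @sort_three_records(%xy: memref<6xindex>) {
  %n = arith.constant 3 : index
  sparse_tensor.sort_coo quick_sort %n, %xy {perm_map = #ID_MAP, ny = 0 : index}
    : memref<6xindex>
  return
}
```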