@@ -33,6 +33,15 @@ using namespace mlir::sparse_tensor;
3333
3434namespace {
3535
// Sparse formats supported by cuSparse.
// kNone is the sentinel "no admissible format" value; the others name the
// standard sparse matrix layouts (coordinate, compressed row, compressed
// column, blocked row). Enumerator order starts at kNone == 0.
enum class CuSparseFormat {
  kNone,
  kCOO,
  kCSR,
  kCSC,
  kBSR, // TODO: coming soon!
};
3645// ===----------------------------------------------------------------------===//
3746// Helper methods.
3847// ===----------------------------------------------------------------------===//
@@ -385,73 +394,92 @@ static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
385394 return false ;
386395}
387396
388- // / Determines if the given value is a dense tensor instead of a sparse one .
397+ // / Test for dense tensor.
389398static bool isDenseTensor (Value v) {
390- return (sparse_tensor::getSparseTensorType (v).isAllDense ());
399+ auto sTp = getSparseTensorType (v);
400+ return sTp .getDimRank () == sTp .getLvlRank () && sTp .isAllDense ();
401+ }
402+
403+ // / Test for suitable positions/coordinates width.
404+ static bool isAdmissibleMetaData (SparseTensorType &aTp) {
405+ return (aTp.getPosWidth () == 0 || aTp.getPosWidth () >= 16 ) &&
406+ (aTp.getCrdWidth () == 0 || aTp.getCrdWidth () >= 16 );
391407}
392408
393- // / Test for sorted COO with suitable data and coordinates types .
409+ // / Test for sorted COO matrix with suitable metadata .
394410static bool isAdmissibleCOO (SparseTensorType &aTp) {
395- return aTp.isCompressedLvl (0 ) && aTp.isOrderedLvl (0 ) && !aTp.isUniqueLvl (0 ) &&
411+ return aTp.getDimRank () == 2 && aTp.getLvlRank () == 2 && aTp.isIdentity () &&
412+ aTp.isCompressedLvl (0 ) && aTp.isOrderedLvl (0 ) && !aTp.isUniqueLvl (0 ) &&
396413 aTp.isSingletonLvl (1 ) && aTp.isOrderedLvl (1 ) && aTp.isUniqueLvl (1 ) &&
397- (aTp.getElementType ().isF64 () || aTp.getElementType ().isF32 ()) &&
398- (aTp.getCrdWidth () == 0 || aTp.getCrdWidth () == 32 ||
399- aTp.getCrdWidth () == 64 );
414+ isAdmissibleMetaData (aTp);
400415}
401416
402- // / Test for CSR with suitable data and coordinates types .
417+ // / Test for CSR matrix with suitable metadata .
403418static bool isAdmissibleCSR (SparseTensorType &aTp) {
404- return aTp.isDenseLvl (0 ) && aTp.isCompressedLvl (1 ) && aTp.isOrderedLvl (1 ) &&
405- aTp.isUniqueLvl (1 ) &&
406- (aTp.getElementType ().isF64 () || aTp.getElementType ().isF32 ()) &&
407- (aTp.getCrdWidth () == 0 || aTp.getCrdWidth () == 32 ||
408- aTp.getCrdWidth () == 64 );
419+ return aTp.getDimRank () == 2 && aTp.getLvlRank () == 2 && aTp.isIdentity () &&
420+ aTp.isDenseLvl (0 ) && aTp.isCompressedLvl (1 ) && aTp.isOrderedLvl (1 ) &&
421+ aTp.isUniqueLvl (1 ) && isAdmissibleMetaData (aTp);
409422}
410423
411- // / Test for admissible types on operands (with output parameter `isCOO`).
412- static bool areAdmissibleTypes (SparseTensorType aTp, SparseTensorType bTp,
413- SparseTensorType cTp, bool enableRT,
414- bool isMatVec, bool &isCOO) {
424+ // / Test for CSC matrix with suitable metadata.
425+ static bool isAdmissibleCSC (SparseTensorType &aTp) {
426+ return aTp.getDimRank () == 2 && aTp.getLvlRank () == 2 && !aTp.isIdentity () &&
427+ aTp.isPermutation () && aTp.isDenseLvl (0 ) && aTp.isCompressedLvl (1 ) &&
428+ aTp.isOrderedLvl (1 ) && aTp.isUniqueLvl (1 ) && isAdmissibleMetaData (aTp);
429+ }
430+
431+ // / Returns a suitable sparse format for the operation and given operand
432+ // / types with cuSparse, or kNone if none is available.
433+ static CuSparseFormat getCuSparseFormat (SparseTensorType aTp,
434+ SparseTensorType bTp,
435+ SparseTensorType cTp, bool enableRT,
436+ bool isMatVec) {
437+ // The other operands have a dense type.
415438 if (bTp.hasEncoding () || cTp.hasEncoding ())
416- return false ;
417- if ( isAdmissibleCOO (aTp)) {
418- isCOO = true ;
439+ return CuSparseFormat:: kNone ;
440+ // Now check for suitable operand type for the main operand.
441+ if ( isAdmissibleCOO (aTp))
419442#ifdef CUSPARSE_COO_AOS
420- return isMatVec;
443+ return isMatVec ? CuSparseFormat:: kCOO : CuSparseFormat:: kNone ;
421444#else
422- return enableRT;
445+ return enableRT ? CuSparseFormat:: kCOO : CuSparseFormat:: kNone ;
423446#endif
424- }
425- return isAdmissibleCSR (aTp);
447+ if (isAdmissibleCSR (aTp))
448+ return CuSparseFormat::kCSR ;
449+ if (isAdmissibleCSC (aTp))
450+ return CuSparseFormat::kCSC ;
451+ return CuSparseFormat::kNone ;
426452}
427453
428454// / Generates the first positions/coordinates of a sparse matrix.
429455static Value genFirstPosOrCrds (OpBuilder &builder, Location loc, Value a,
430- bool isCOO , bool enableRT) {
431- if (isCOO ) {
456+ CuSparseFormat format , bool enableRT) {
457+ if (format == CuSparseFormat:: kCOO ) {
432458 // Library uses SoA COO, direct IR uses AoS COO.
433459 if (enableRT)
434460 return genToCoordinates (builder, loc, a, 0 , /* cooStart=*/ 0 );
435461 return genToCoordinatesBuffer (builder, loc, a);
436462 }
437- // CSR uses positions.
463+ // Formats CSR/CSC and BSR use positions at 1 .
438464 return genToPositions (builder, loc, a, 1 );
439465}
440466
441467// / Generates the second coordinates of a sparse matrix.
442468static Value genSecondCrds (OpBuilder &builder, Location loc, Value a,
443- bool isCOO, bool enableRT) {
469+ CuSparseFormat format, bool enableRT) {
470+ bool isCOO = format == CuSparseFormat::kCOO ;
444471 if (isCOO && !enableRT)
445472 return Value (); // nothing needed
473+ // Formats CSR/CSC and BSR use coordinates at 1.
446474 return genToCoordinates (builder, loc, a, 1 , /* cooStart=*/ isCOO ? 0 : 2 );
447475}
448476
449- // / Generates the sparse matrix multiplication .
477+ // / Generates the sparse matrix handle .
450478static Operation *genSpMat (OpBuilder &builder, Location loc, Type handleTp,
451479 Type tokenTp, Value token, Value sz1, Value sz2,
452480 Value nseA, Value rowA, Value colA, Value valA,
453- bool isCOO , bool enableRT) {
454- if (isCOO ) {
481+ CuSparseFormat format , bool enableRT) {
482+ if (format == CuSparseFormat:: kCOO ) {
455483 // Library uses SoA COO, direct IR uses AoS COO.
456484 if (enableRT) {
457485 assert (colA);
@@ -467,7 +495,11 @@ static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
467495#endif
468496 }
469497 assert (colA);
470- return builder.create <gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
498+ if (format == CuSparseFormat::kCSR )
499+ return builder.create <gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
500+ sz2, nseA, rowA, colA, valA);
501+ assert (format == CuSparseFormat::kCSC );
502+ return builder.create <gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
471503 sz2, nseA, rowA, colA, valA);
472504}
473505
@@ -484,12 +516,12 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
484516 bool isZeroCopy =
485517 gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy ;
486518
487- // Only admissible sparse matrix format and dense vectors.
488- bool isCOO = false ;
519+ // Only admissible sparse matrix format and dense vectors (no BSR).
489520 SparseTensorType aTp = getSparseTensorType (a);
490521 SparseTensorType xTp = getSparseTensorType (x);
491522 SparseTensorType yTp = getSparseTensorType (y);
492- if (!areAdmissibleTypes (aTp, xTp, yTp, enableRT, /* isMatVec=*/ true , isCOO))
523+ auto format = getCuSparseFormat (aTp, xTp, yTp, enableRT, /* isMatVec=*/ true );
524+ if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR )
493525 return failure ();
494526
495527 // Start sparse kernel and copy data from host to device.
@@ -499,8 +531,8 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
499531 Value nseA = rewriter.create <NumberOfEntriesOp>(loc, a);
500532 Value szY = linalg::createOrFoldDimOp (rewriter, loc, a, 0 );
501533 Value szX = linalg::createOrFoldDimOp (rewriter, loc, a, 1 );
502- Value memR = genFirstPosOrCrds (rewriter, loc, a, isCOO , enableRT);
503- Value memC = genSecondCrds (rewriter, loc, a, isCOO , enableRT);
534+ Value memR = genFirstPosOrCrds (rewriter, loc, a, format , enableRT);
535+ Value memC = genSecondCrds (rewriter, loc, a, format , enableRT);
504536 Value memV = genToValues (rewriter, loc, a);
505537 Value memX, memY;
506538 Value castR, castC, castV, castX, castY;
@@ -535,7 +567,7 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
535567 Value token = genFirstWait (rewriter, loc);
536568 Operation *spGenA =
537569 genSpMat (rewriter, loc, spmatHandleTp, tokenTp, token, szY, szX, nseA,
538- rowA, colA, valA, isCOO , enableRT);
570+ rowA, colA, valA, format , enableRT);
539571 Value spMatA = spGenA->getResult (0 );
540572 token = spGenA->getResult (1 );
541573 auto dvecX = rewriter.create <gpu::CreateDnTensorOp>(
@@ -546,7 +578,6 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
546578 loc, dnTensorHandleTp, tokenTp, token, vecY, szY);
547579 Value dnY = dvecY.getResult (0 );
548580 token = dvecY.getAsyncToken ();
549-
550581 auto dnYType = llvm::cast<ShapedType>(y.getType ()).getElementType ();
551582
552583 // Precompute buffersize for SpMV.
@@ -610,12 +641,12 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
610641 bool isZeroCopy =
611642 gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy ;
612643
613- // Only admissible sparse matrix format and dense matrices.
614- bool isCOO = false ;
644+ // Only admissible sparse matrix format and dense matrices (no BSR).
615645 SparseTensorType aTp = getSparseTensorType (a);
616646 SparseTensorType bTp = getSparseTensorType (b);
617647 SparseTensorType cTp = getSparseTensorType (c);
618- if (!areAdmissibleTypes (aTp, bTp, cTp, enableRT, /* isMatVec=*/ false , isCOO))
648+ auto format = getCuSparseFormat (aTp, bTp, cTp, enableRT, /* isMatVec=*/ false );
649+ if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR )
619650 return failure ();
620651
621652 // Start sparse kernel and copy data from host to device.
@@ -626,8 +657,8 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
626657 Value szm = linalg::createOrFoldDimOp (rewriter, loc, a, 0 );
627658 Value szk = linalg::createOrFoldDimOp (rewriter, loc, a, 1 );
628659 Value szn = linalg::createOrFoldDimOp (rewriter, loc, b, 1 );
629- Value memR = genFirstPosOrCrds (rewriter, loc, a, isCOO , enableRT);
630- Value memC = genSecondCrds (rewriter, loc, a, isCOO , enableRT);
660+ Value memR = genFirstPosOrCrds (rewriter, loc, a, format , enableRT);
661+ Value memC = genSecondCrds (rewriter, loc, a, format , enableRT);
631662 Value memV = genToValues (rewriter, loc, a);
632663 Value bufB, bufC;
633664 Value castR, castC, castV, castB, castBufC;
@@ -661,7 +692,7 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
661692 Value token = genFirstWait (rewriter, loc);
662693 Operation *spGenA =
663694 genSpMat (rewriter, loc, spMatHandleTp, tokenTp, token, szm, szk, nseA,
664- rowA, colA, valA, isCOO , enableRT);
695+ rowA, colA, valA, format , enableRT);
665696 Value spMatA = spGenA->getResult (0 );
666697 token = spGenA->getResult (1 );
667698 auto dmatB = rewriter.create <gpu::CreateDnTensorOp>(
@@ -674,7 +705,6 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
674705 SmallVector<Value>{szm, szn});
675706 Value dnC = dmatC.getResult (0 );
676707 token = dmatC.getAsyncToken ();
677-
678708 auto dmatCType = llvm::cast<ShapedType>(c.getType ()).getElementType ();
679709
680710 // Precompute buffersize for SpMM.
@@ -686,7 +716,6 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
686716 auto buf = genAllocBuffer (rewriter, loc, bufferSz, token);
687717 Value buffer = buf.getResult (0 );
688718 token = buf.getAsyncToken ();
689-
690719 auto dnCType = llvm::cast<ShapedType>(c.getType ()).getElementType ();
691720
692721 // Perform the SpMM.
@@ -738,7 +767,7 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
738767 SmallVector<Value> tokens;
739768
740769 // Only CSR <- CSR x CSR supported.
741- bool isCOO = false ;
770+ auto format = CuSparseFormat:: kCSR ;
742771 SparseTensorType aTp = getSparseTensorType (a);
743772 SparseTensorType bTp = getSparseTensorType (b);
744773 SparseTensorType cTp = getSparseTensorType (c);
@@ -755,11 +784,11 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
755784 Value szm = linalg::createOrFoldDimOp (rewriter, loc, a, 0 );
756785 Value szk = linalg::createOrFoldDimOp (rewriter, loc, a, 1 );
757786 Value szn = linalg::createOrFoldDimOp (rewriter, loc, b, 1 );
758- Value amemR = genFirstPosOrCrds (rewriter, loc, a, isCOO , enableRT);
759- Value amemC = genSecondCrds (rewriter, loc, a, isCOO , enableRT);
787+ Value amemR = genFirstPosOrCrds (rewriter, loc, a, format , enableRT);
788+ Value amemC = genSecondCrds (rewriter, loc, a, format , enableRT);
760789 Value amemV = genToValues (rewriter, loc, a);
761- Value bmemR = genFirstPosOrCrds (rewriter, loc, b, isCOO , enableRT);
762- Value bmemC = genSecondCrds (rewriter, loc, b, isCOO , enableRT);
790+ Value bmemR = genFirstPosOrCrds (rewriter, loc, b, format , enableRT);
791+ Value bmemC = genSecondCrds (rewriter, loc, b, format , enableRT);
763792 Value bmemV = genToValues (rewriter, loc, b);
764793 Value rowA = genAllocCopy (rewriter, loc, amemR, tokens);
765794 Value colA = genAllocCopy (rewriter, loc, amemC, tokens);
@@ -778,12 +807,12 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
778807 Value token = genFirstWait (rewriter, loc);
779808 Operation *spGenA =
780809 genSpMat (rewriter, loc, spmatHandleTp, tokenTp, token, szm, szk, nseA,
781- rowA, colA, valA, isCOO , enableRT);
810+ rowA, colA, valA, format , enableRT);
782811 Value spMatA = spGenA->getResult (0 );
783812 token = spGenA->getResult (1 );
784813 Operation *spGenB =
785814 genSpMat (rewriter, loc, spmatHandleTp, tokenTp, token, szk, szn, nseB,
786- rowB, colB, valB, isCOO , enableRT);
815+ rowB, colB, valB, format , enableRT);
787816 Value spMatB = spGenB->getResult (0 );
788817 token = spGenB->getResult (1 );
789818
@@ -802,7 +831,7 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
802831 token = e3 .getAsyncToken ();
803832 Operation *spGenC =
804833 genSpMat (rewriter, loc, spmatHandleTp, tokenTp, token, szm, szn, zero,
805- rowC, colC, valC, isCOO , enableRT);
834+ rowC, colC, valC, format , enableRT);
806835 Value spMatC = spGenC->getResult (0 );
807836 token = spGenC->getResult (1 );
808837
@@ -1046,14 +1075,13 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
10461075 bool isZeroCopy =
10471076 gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy ;
10481077
1049- // Only admissible sparse matrix format and dense matrices, no COO.
1050- bool isCOO = false ;
1078+ // Only admissible sparse matrix format (no COO/CSC) and dense matrices.
10511079 SparseTensorType aTp = getSparseTensorType (a);
10521080 SparseTensorType bTp = getSparseTensorType (b);
10531081 SparseTensorType cTp = getSparseTensorType (c);
1054- if (! areAdmissibleTypes ( cTp, bTp, aTp, enableRT, false , isCOO))
1055- return failure ();
1056- if (isCOO )
1082+ auto format = getCuSparseFormat ( cTp, bTp, aTp, enableRT, /* isMatVec= */ false );
1083+ if (format == CuSparseFormat:: kNone || format == CuSparseFormat:: kCOO ||
1084+ format == CuSparseFormat:: kCSC )
10571085 return failure ();
10581086
10591087 // The SDDMM does the in-place operation.
@@ -1072,8 +1100,8 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
10721100 Value bufB = genTensorToMemref (rewriter, loc, b);
10731101 if (!isZeroCopy)
10741102 matB = isZeroCopy ? bufB : genAllocCopy (rewriter, loc, bufB, tokens);
1075- Value memR = genFirstPosOrCrds (rewriter, loc, c, isCOO , enableRT);
1076- Value memC = genSecondCrds (rewriter, loc, c, isCOO , enableRT);
1103+ Value memR = genFirstPosOrCrds (rewriter, loc, c, format , enableRT);
1104+ Value memC = genSecondCrds (rewriter, loc, c, format , enableRT);
10771105 Value memV = genToValues (rewriter, loc, c);
10781106 Value castB, castA, castR, castC, castV;
10791107 if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA ) {
@@ -1108,10 +1136,9 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
11081136 loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn});
11091137 Value dnB = dmatB.getResult (0 );
11101138 token = dmatB.getAsyncToken ();
1111-
11121139 Operation *spGenC =
11131140 genSpMat (rewriter, loc, spMatHandleTp, tokenTp, token, szm, szn, nseC,
1114- rowC, colC, valC, isCOO , enableRT);
1141+ rowC, colC, valC, format , enableRT);
11151142 Value spMatC = spGenC->getResult (0 );
11161143 token = spGenC->getResult (1 );
11171144 auto dnCType = llvm::cast<ShapedType>(c.getType ()).getElementType ();
0 commit comments