@@ -138,6 +138,12 @@ extern "C" {
         dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \
         dimRank, tensor); \
   } \
+  case Action::kFromReader: { \
+    assert(ptr && "Received nullptr for SparseTensorReader object"); \
+    SparseTensorReader &reader = *static_cast<SparseTensorReader *>(ptr); \
+    return static_cast<void *>(reader.readSparseTensor<P, C, V>( \
+        lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim)); \
+  } \
   case Action::kToCOO: { \
     assert(ptr && "Received nullptr for SparseTensorStorage object"); \
     auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr); \
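Note: the new `kFromReader` case slots into the `Action` switch that the surrounding macro stamps out once per `<P, C, V>` instantiation, so reader-based construction now shares the generic constructor's type dispatch instead of needing its own entry point (the second hunk below deletes that entry point). A minimal standalone sketch of this switch-inside-a-template pattern; all names here are hypothetical stand-ins, not the MLIR runtime API:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Hypothetical stand-ins for the runtime's Action enum and reader class.
enum class Action { kEmpty, kFromReader };

struct Reader {
  // Pretends to materialize typed sparse storage from a file.
  template <typename P, typename C, typename V>
  void *read() { return new V(42); }
};

// The runtime macro-expands a body like this once per <P, C, V>
// combination, so every Action dispatches straight to fully typed code.
template <typename P, typename C, typename V>
void *construct(Action action, void *ptr) {
  switch (action) {
  case Action::kEmpty:
    return new V(0);
  case Action::kFromReader: {
    assert(ptr && "Received nullptr for Reader object");
    Reader &reader = *static_cast<Reader *>(ptr);
    return reader.read<P, C, V>();
  }
  }
  return nullptr;
}

int main() {
  Reader r;
  auto *v = static_cast<double *>(
      construct<uint64_t, uint32_t, double>(Action::kFromReader, &r));
  std::printf("%g\n", *v); // prints 42
  delete v;
}
```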
@@ -442,113 +448,6 @@ void _mlir_ciface_getSparseTensorReaderDimSizes(
 MLIR_SPARSETENSOR_FOREVERY_V_O(IMPL_GETNEXT)
 #undef IMPL_GETNEXT
 
-void *_mlir_ciface_newSparseTensorFromReader(
-    void *p, StridedMemRefType<index_type, 1> *lvlSizesRef,
-    StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
-    StridedMemRefType<index_type, 1> *dim2lvlRef,
-    StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
-    OverheadType crdTp, PrimaryType valTp) {
-  assert(p);
-  SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
-  ASSERT_NO_STRIDE(lvlSizesRef);
-  ASSERT_NO_STRIDE(lvlTypesRef);
-  ASSERT_NO_STRIDE(dim2lvlRef);
-  ASSERT_NO_STRIDE(lvl2dimRef);
-  const uint64_t dimRank = reader.getRank();
-  const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef);
-  ASSERT_USIZE_EQ(lvlTypesRef, lvlRank);
-  ASSERT_USIZE_EQ(dim2lvlRef, dimRank);
-  ASSERT_USIZE_EQ(lvl2dimRef, lvlRank);
-  (void)dimRank;
-  const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
-  const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
-  const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
-  const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
-#define CASE(p, c, v, P, C, V) \
-  if (posTp == OverheadType::p && crdTp == OverheadType::c && \
-      valTp == PrimaryType::v) \
-    return static_cast<void *>(reader.readSparseTensor<P, C, V>( \
-        lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim));
-#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
-  // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
-  // This is safe because of the static_assert above.
-  if (posTp == OverheadType::kIndex)
-    posTp = OverheadType::kU64;
-  if (crdTp == OverheadType::kIndex)
-    crdTp = OverheadType::kU64;
-  // Double matrices with all combinations of overhead storage.
-  CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);
-  CASE(kU64, kU32, kF64, uint64_t, uint32_t, double);
-  CASE(kU64, kU16, kF64, uint64_t, uint16_t, double);
-  CASE(kU64, kU8, kF64, uint64_t, uint8_t, double);
-  CASE(kU32, kU64, kF64, uint32_t, uint64_t, double);
-  CASE(kU32, kU32, kF64, uint32_t, uint32_t, double);
-  CASE(kU32, kU16, kF64, uint32_t, uint16_t, double);
-  CASE(kU32, kU8, kF64, uint32_t, uint8_t, double);
-  CASE(kU16, kU64, kF64, uint16_t, uint64_t, double);
-  CASE(kU16, kU32, kF64, uint16_t, uint32_t, double);
-  CASE(kU16, kU16, kF64, uint16_t, uint16_t, double);
-  CASE(kU16, kU8, kF64, uint16_t, uint8_t, double);
-  CASE(kU8, kU64, kF64, uint8_t, uint64_t, double);
-  CASE(kU8, kU32, kF64, uint8_t, uint32_t, double);
-  CASE(kU8, kU16, kF64, uint8_t, uint16_t, double);
-  CASE(kU8, kU8, kF64, uint8_t, uint8_t, double);
-  // Float matrices with all combinations of overhead storage.
-  CASE(kU64, kU64, kF32, uint64_t, uint64_t, float);
-  CASE(kU64, kU32, kF32, uint64_t, uint32_t, float);
-  CASE(kU64, kU16, kF32, uint64_t, uint16_t, float);
-  CASE(kU64, kU8, kF32, uint64_t, uint8_t, float);
-  CASE(kU32, kU64, kF32, uint32_t, uint64_t, float);
-  CASE(kU32, kU32, kF32, uint32_t, uint32_t, float);
-  CASE(kU32, kU16, kF32, uint32_t, uint16_t, float);
-  CASE(kU32, kU8, kF32, uint32_t, uint8_t, float);
-  CASE(kU16, kU64, kF32, uint16_t, uint64_t, float);
-  CASE(kU16, kU32, kF32, uint16_t, uint32_t, float);
-  CASE(kU16, kU16, kF32, uint16_t, uint16_t, float);
-  CASE(kU16, kU8, kF32, uint16_t, uint8_t, float);
-  CASE(kU8, kU64, kF32, uint8_t, uint64_t, float);
-  CASE(kU8, kU32, kF32, uint8_t, uint32_t, float);
-  CASE(kU8, kU16, kF32, uint8_t, uint16_t, float);
-  CASE(kU8, kU8, kF32, uint8_t, uint8_t, float);
-  // Two-byte floats with both overheads of the same type.
-  CASE_SECSAME(kU64, kF16, uint64_t, f16);
-  CASE_SECSAME(kU64, kBF16, uint64_t, bf16);
-  CASE_SECSAME(kU32, kF16, uint32_t, f16);
-  CASE_SECSAME(kU32, kBF16, uint32_t, bf16);
-  CASE_SECSAME(kU16, kF16, uint16_t, f16);
-  CASE_SECSAME(kU16, kBF16, uint16_t, bf16);
-  CASE_SECSAME(kU8, kF16, uint8_t, f16);
-  CASE_SECSAME(kU8, kBF16, uint8_t, bf16);
-  // Integral matrices with both overheads of the same type.
-  CASE_SECSAME(kU64, kI64, uint64_t, int64_t);
-  CASE_SECSAME(kU64, kI32, uint64_t, int32_t);
-  CASE_SECSAME(kU64, kI16, uint64_t, int16_t);
-  CASE_SECSAME(kU64, kI8, uint64_t, int8_t);
-  CASE_SECSAME(kU32, kI64, uint32_t, int64_t);
-  CASE_SECSAME(kU32, kI32, uint32_t, int32_t);
-  CASE_SECSAME(kU32, kI16, uint32_t, int16_t);
-  CASE_SECSAME(kU32, kI8, uint32_t, int8_t);
-  CASE_SECSAME(kU16, kI64, uint16_t, int64_t);
-  CASE_SECSAME(kU16, kI32, uint16_t, int32_t);
-  CASE_SECSAME(kU16, kI16, uint16_t, int16_t);
-  CASE_SECSAME(kU16, kI8, uint16_t, int8_t);
-  CASE_SECSAME(kU8, kI64, uint8_t, int64_t);
-  CASE_SECSAME(kU8, kI32, uint8_t, int32_t);
-  CASE_SECSAME(kU8, kI16, uint8_t, int16_t);
-  CASE_SECSAME(kU8, kI8, uint8_t, int8_t);
-  // Complex matrices with wide overhead.
-  CASE_SECSAME(kU64, kC64, uint64_t, complex64);
-  CASE_SECSAME(kU64, kC32, uint64_t, complex32);
-
-  // Unsupported case (add above if needed).
-  MLIR_SPARSETENSOR_FATAL(
-      "unsupported combination of types: <P=%d, C=%d, V=%d>\n",
-      static_cast<int>(posTp), static_cast<int>(crdTp),
-      static_cast<int>(valTp));
-#undef CASE_SECSAME
-#undef CASE
-}
-
 void _mlir_ciface_outSparseTensorWriterMetaData(
     void *p, index_type dimRank, index_type nse,
     StridedMemRefType<index_type, 1> *dimSizesRef) {
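The deleted `_mlir_ciface_newSparseTensorFromReader` duplicated the type-dispatch table that the generic constructor already owns: each `CASE(p, c, v, P, C, V)` line tests the three runtime type codes and, on a match, returns a fully typed `readSparseTensor` instantiation, while `CASE_SECSAME` collapses the common case where positions and coordinates share an overhead type. The `kIndex` to `kU64` rewrite at the top keeps the table from doubling; per the original comment, a static_assert elsewhere in the file guarantees this aliasing is safe. A compilable sketch of the macro-table technique, shrunk to two type codes (hypothetical enums and helpers, not the runtime's):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Hypothetical type codes mirroring OverheadType/PrimaryType.
enum class Overhead { kU64, kU32 };
enum class Primary { kF64, kF32 };

// Stands in for the fully typed readSparseTensor instantiation.
template <typename C, typename V>
void *make() { return new V(0); }

void *dispatch(Overhead crdTp, Primary valTp) {
// Each CASE line maps a pair of runtime codes to one instantiation.
#define CASE(c, v, C, V) \
  if (crdTp == Overhead::c && valTp == Primary::v) \
    return make<C, V>();
  CASE(kU64, kF64, uint64_t, double);
  CASE(kU64, kF32, uint64_t, float);
  CASE(kU32, kF64, uint32_t, double);
  CASE(kU32, kF32, uint32_t, float);
#undef CASE
  // Unsupported combination: fail loudly, like MLIR_SPARSETENSOR_FATAL.
  std::fprintf(stderr, "unsupported combination of types\n");
  std::abort();
}

int main() {
  double *v = static_cast<double *>(dispatch(Overhead::kU32, Primary::kF64));
  std::printf("%g\n", *v); // prints 0
  delete v;
}
```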
@@ -635,34 +534,10 @@ char *getTensorFilename(index_type id) {
   return env;
 }
 
-void readSparseTensorShape(char *filename, std::vector<uint64_t> *out) {
-  assert(out && "Received nullptr for out-parameter");
-  SparseTensorReader reader(filename);
-  reader.openFile();
-  reader.readHeader();
-  reader.closeFile();
-  const uint64_t dimRank = reader.getRank();
-  const uint64_t *dimSizes = reader.getDimSizes();
-  out->reserve(dimRank);
-  out->assign(dimSizes, dimSizes + dimRank);
-}
-
-index_type getSparseTensorReaderRank(void *p) {
-  return static_cast<SparseTensorReader *>(p)->getRank();
-}
-
-bool getSparseTensorReaderIsSymmetric(void *p) {
-  return static_cast<SparseTensorReader *>(p)->isSymmetric();
-}
-
 index_type getSparseTensorReaderNSE(void *p) {
   return static_cast<SparseTensorReader *>(p)->getNSE();
 }
 
-index_type getSparseTensorReaderDimSize(void *p, index_type d) {
-  return static_cast<SparseTensorReader *>(p)->getDimSize(d);
-}
-
 void delSparseTensorReader(void *p) {
   delete static_cast<SparseTensorReader *>(p);
 }
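With `readSparseTensorShape`, `getSparseTensorReaderRank`, `getSparseTensorReaderIsSymmetric`, and `getSparseTensorReaderDimSize` gone, callers are left with a create / bulk-query / delete surface, where `_mlir_ciface_getSparseTensorReaderDimSizes` (visible in the second hunk's context) replaces the per-dimension getter. A hedged lifecycle sketch that would link against this runtime; `createSparseTensorReader`, the `index_type` alias, the memref struct layout, and the out-parameter aliasing behavior are assumptions, not shown in this diff:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Assumed declarations: only getSparseTensorReaderNSE,
// delSparseTensorReader, and _mlir_ciface_getSparseTensorReaderDimSizes
// appear in this diff; createSparseTensorReader, index_type, and the
// memref layout are assumptions about the surrounding runtime.
using index_type = uint64_t;
template <typename T, int N>
struct StridedMemRefType {
  T *basePtr;
  T *data;
  int64_t offset;
  int64_t sizes[N];
  int64_t strides[N];
};
extern "C" {
void *createSparseTensorReader(char *filename);
index_type getSparseTensorReaderNSE(void *p);
void _mlir_ciface_getSparseTensorReaderDimSizes(
    StridedMemRefType<index_type, 1> *out, void *p);
void delSparseTensorReader(void *p);
}

// Replacement for the deleted readSparseTensorShape helper, built on the
// surviving bulk query.
std::vector<index_type> readShape(char *filename) {
  void *reader = createSparseTensorReader(filename);
  // Assumption: the runtime points `ref` at reader-owned dimension sizes,
  // so copy them out before the reader is deleted.
  StridedMemRefType<index_type, 1> ref = {};
  _mlir_ciface_getSparseTensorReaderDimSizes(&ref, reader);
  std::vector<index_type> dimSizes(ref.data + ref.offset,
                                   ref.data + ref.offset + ref.sizes[0]);
  std::printf("nse = %llu\n",
              (unsigned long long)getSparseTensorReaderNSE(reader));
  delSparseTensorReader(reader);
  return dimSizes;
}
```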