Skip to content

Commit d392073

Browse files
authored
[mlir][sparse] simplify reader construction of new sparse tensor (llvm#69036)
Making the materialize-from-reader method part of the Swiss army knife suite again removes a lot of redundant boilerplate code and unifies the parameter setup into a single centralized utility. Furthermore, we have now minimized the number of entry points into the library that need a non-permutation map setup, simplifying what comes next.
1 parent 182a65a commit d392073

File tree

5 files changed

+31
-195
lines changed

5 files changed

+31
-195
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ enum class Action : uint32_t {
146146
kEmptyForward = 1,
147147
kFromCOO = 2,
148148
kSparseToSparse = 3,
149+
kFromReader = 4,
149150
kToCOO = 5,
150151
kPack = 7,
151152
kSortCOOInPlace = 8,

mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -115,16 +115,6 @@ MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_createCheckedSparseTensorReader(
115115
char *filename, StridedMemRefType<index_type, 1> *dimShapeRef,
116116
PrimaryType valTp);
117117

118-
/// Constructs a new sparse-tensor storage object with the given encoding,
119-
/// initializes it by reading all the elements from the file, and then
120-
/// closes the file.
121-
MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_newSparseTensorFromReader(
122-
void *p, StridedMemRefType<index_type, 1> *lvlSizesRef,
123-
StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
124-
StridedMemRefType<index_type, 1> *dim2lvlRef,
125-
StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
126-
OverheadType crdTp, PrimaryType valTp);
127-
128118
/// SparseTensorReader method to obtain direct access to the
129119
/// dimension-sizes array.
130120
MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_getSparseTensorReaderDimSizes(
@@ -197,24 +187,9 @@ MLIR_SPARSETENSOR_FOREVERY_V(DECL_DELCOO)
197187
/// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
198188
MLIR_CRUNNERUTILS_EXPORT char *getTensorFilename(index_type id);
199189

200-
/// Helper function to read the header of a file and return the
201-
/// shape/sizes, without parsing the elements of the file.
202-
MLIR_CRUNNERUTILS_EXPORT void readSparseTensorShape(char *filename,
203-
std::vector<uint64_t> *out);
204-
205-
/// Returns the rank of the sparse tensor being read.
206-
MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderRank(void *p);
207-
208-
/// Returns the is_symmetric bit for the sparse tensor being read.
209-
MLIR_CRUNNERUTILS_EXPORT bool getSparseTensorReaderIsSymmetric(void *p);
210-
211190
/// Returns the number of stored elements for the sparse tensor being read.
212191
MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderNSE(void *p);
213192

214-
/// Returns the size of a dimension for the sparse tensor being read.
215-
MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderDimSize(void *p,
216-
index_type d);
217-
218193
/// Releases the SparseTensorReader and closes the associated file.
219194
MLIR_CRUNNERUTILS_EXPORT void delSparseTensorReader(void *p);
220195

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,15 @@ class NewCallParams final {
199199
/// type-level information such as the encoding and sizes), generating
200200
/// MLIR buffers as needed, and returning `this` for method chaining.
201201
NewCallParams &genBuffers(SparseTensorType stt,
202-
ArrayRef<Value> dimSizesValues) {
202+
ArrayRef<Value> dimSizesValues,
203+
Value dimSizesBuffer = Value()) {
203204
assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
204205
// Sparsity annotations.
205206
params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
206207
// Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
207-
params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizesValues);
208+
params[kParamDimSizes] = dimSizesBuffer
209+
? dimSizesBuffer
210+
: allocaBuffer(builder, loc, dimSizesValues);
208211
params[kParamLvlSizes] =
209212
genMapBuffers(builder, loc, stt, dimSizesValues, params[kParamDimSizes],
210213
params[kParamDim2Lvl], params[kParamLvl2Dim]);
@@ -342,33 +345,15 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
342345
const auto stt = getSparseTensorType(op);
343346
if (!stt.hasEncoding())
344347
return failure();
345-
// Construct the reader opening method calls.
348+
// Construct the `reader` opening method calls.
346349
SmallVector<Value> dimShapesValues;
347350
Value dimSizesBuffer;
348351
Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
349352
dimShapesValues, dimSizesBuffer);
350-
// Now construct the lvlSizes, dim2lvl, and lvl2dim buffers.
351-
Value dim2lvlBuffer;
352-
Value lvl2dimBuffer;
353-
Value lvlSizesBuffer =
354-
genMapBuffers(rewriter, loc, stt, dimShapesValues, dimSizesBuffer,
355-
dim2lvlBuffer, lvl2dimBuffer);
356353
// Use the `reader` to parse the file.
357-
Type opaqueTp = getOpaquePointerType(rewriter);
358-
Type eltTp = stt.getElementType();
359-
Value valTp = constantPrimaryTypeEncoding(rewriter, loc, eltTp);
360-
SmallVector<Value, 8> params{
361-
reader,
362-
lvlSizesBuffer,
363-
genLvlTypesBuffer(rewriter, loc, stt),
364-
dim2lvlBuffer,
365-
lvl2dimBuffer,
366-
constantPosTypeEncoding(rewriter, loc, stt.getEncoding()),
367-
constantCrdTypeEncoding(rewriter, loc, stt.getEncoding()),
368-
valTp};
369-
Value tensor = createFuncCall(rewriter, loc, "newSparseTensorFromReader",
370-
opaqueTp, params, EmitCInterface::On)
371-
.getResult(0);
354+
Value tensor = NewCallParams(rewriter, loc)
355+
.genBuffers(stt, dimShapesValues, dimSizesBuffer)
356+
.genNewCall(Action::kFromReader, reader);
372357
// Free the memory for `reader`.
373358
createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
374359
EmitCInterface::Off);

mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp

Lines changed: 6 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,12 @@ extern "C" {
138138
dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \
139139
dimRank, tensor); \
140140
} \
141+
case Action::kFromReader: { \
142+
assert(ptr && "Received nullptr for SparseTensorReader object"); \
143+
SparseTensorReader &reader = *static_cast<SparseTensorReader *>(ptr); \
144+
return static_cast<void *>(reader.readSparseTensor<P, C, V>( \
145+
lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim)); \
146+
} \
141147
case Action::kToCOO: { \
142148
assert(ptr && "Received nullptr for SparseTensorStorage object"); \
143149
auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr); \
@@ -442,113 +448,6 @@ void _mlir_ciface_getSparseTensorReaderDimSizes(
442448
MLIR_SPARSETENSOR_FOREVERY_V_O(IMPL_GETNEXT)
443449
#undef IMPL_GETNEXT
444450

445-
void *_mlir_ciface_newSparseTensorFromReader(
446-
void *p, StridedMemRefType<index_type, 1> *lvlSizesRef,
447-
StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
448-
StridedMemRefType<index_type, 1> *dim2lvlRef,
449-
StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
450-
OverheadType crdTp, PrimaryType valTp) {
451-
assert(p);
452-
SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
453-
ASSERT_NO_STRIDE(lvlSizesRef);
454-
ASSERT_NO_STRIDE(lvlTypesRef);
455-
ASSERT_NO_STRIDE(dim2lvlRef);
456-
ASSERT_NO_STRIDE(lvl2dimRef);
457-
const uint64_t dimRank = reader.getRank();
458-
const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef);
459-
ASSERT_USIZE_EQ(lvlTypesRef, lvlRank);
460-
ASSERT_USIZE_EQ(dim2lvlRef, dimRank);
461-
ASSERT_USIZE_EQ(lvl2dimRef, lvlRank);
462-
(void)dimRank;
463-
const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
464-
const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
465-
const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
466-
const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
467-
#define CASE(p, c, v, P, C, V) \
468-
if (posTp == OverheadType::p && crdTp == OverheadType::c && \
469-
valTp == PrimaryType::v) \
470-
return static_cast<void *>(reader.readSparseTensor<P, C, V>( \
471-
lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim));
472-
#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
473-
// Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
474-
// This is safe because of the static_assert above.
475-
if (posTp == OverheadType::kIndex)
476-
posTp = OverheadType::kU64;
477-
if (crdTp == OverheadType::kIndex)
478-
crdTp = OverheadType::kU64;
479-
// Double matrices with all combinations of overhead storage.
480-
CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);
481-
CASE(kU64, kU32, kF64, uint64_t, uint32_t, double);
482-
CASE(kU64, kU16, kF64, uint64_t, uint16_t, double);
483-
CASE(kU64, kU8, kF64, uint64_t, uint8_t, double);
484-
CASE(kU32, kU64, kF64, uint32_t, uint64_t, double);
485-
CASE(kU32, kU32, kF64, uint32_t, uint32_t, double);
486-
CASE(kU32, kU16, kF64, uint32_t, uint16_t, double);
487-
CASE(kU32, kU8, kF64, uint32_t, uint8_t, double);
488-
CASE(kU16, kU64, kF64, uint16_t, uint64_t, double);
489-
CASE(kU16, kU32, kF64, uint16_t, uint32_t, double);
490-
CASE(kU16, kU16, kF64, uint16_t, uint16_t, double);
491-
CASE(kU16, kU8, kF64, uint16_t, uint8_t, double);
492-
CASE(kU8, kU64, kF64, uint8_t, uint64_t, double);
493-
CASE(kU8, kU32, kF64, uint8_t, uint32_t, double);
494-
CASE(kU8, kU16, kF64, uint8_t, uint16_t, double);
495-
CASE(kU8, kU8, kF64, uint8_t, uint8_t, double);
496-
// Float matrices with all combinations of overhead storage.
497-
CASE(kU64, kU64, kF32, uint64_t, uint64_t, float);
498-
CASE(kU64, kU32, kF32, uint64_t, uint32_t, float);
499-
CASE(kU64, kU16, kF32, uint64_t, uint16_t, float);
500-
CASE(kU64, kU8, kF32, uint64_t, uint8_t, float);
501-
CASE(kU32, kU64, kF32, uint32_t, uint64_t, float);
502-
CASE(kU32, kU32, kF32, uint32_t, uint32_t, float);
503-
CASE(kU32, kU16, kF32, uint32_t, uint16_t, float);
504-
CASE(kU32, kU8, kF32, uint32_t, uint8_t, float);
505-
CASE(kU16, kU64, kF32, uint16_t, uint64_t, float);
506-
CASE(kU16, kU32, kF32, uint16_t, uint32_t, float);
507-
CASE(kU16, kU16, kF32, uint16_t, uint16_t, float);
508-
CASE(kU16, kU8, kF32, uint16_t, uint8_t, float);
509-
CASE(kU8, kU64, kF32, uint8_t, uint64_t, float);
510-
CASE(kU8, kU32, kF32, uint8_t, uint32_t, float);
511-
CASE(kU8, kU16, kF32, uint8_t, uint16_t, float);
512-
CASE(kU8, kU8, kF32, uint8_t, uint8_t, float);
513-
// Two-byte floats with both overheads of the same type.
514-
CASE_SECSAME(kU64, kF16, uint64_t, f16);
515-
CASE_SECSAME(kU64, kBF16, uint64_t, bf16);
516-
CASE_SECSAME(kU32, kF16, uint32_t, f16);
517-
CASE_SECSAME(kU32, kBF16, uint32_t, bf16);
518-
CASE_SECSAME(kU16, kF16, uint16_t, f16);
519-
CASE_SECSAME(kU16, kBF16, uint16_t, bf16);
520-
CASE_SECSAME(kU8, kF16, uint8_t, f16);
521-
CASE_SECSAME(kU8, kBF16, uint8_t, bf16);
522-
// Integral matrices with both overheads of the same type.
523-
CASE_SECSAME(kU64, kI64, uint64_t, int64_t);
524-
CASE_SECSAME(kU64, kI32, uint64_t, int32_t);
525-
CASE_SECSAME(kU64, kI16, uint64_t, int16_t);
526-
CASE_SECSAME(kU64, kI8, uint64_t, int8_t);
527-
CASE_SECSAME(kU32, kI64, uint32_t, int64_t);
528-
CASE_SECSAME(kU32, kI32, uint32_t, int32_t);
529-
CASE_SECSAME(kU32, kI16, uint32_t, int16_t);
530-
CASE_SECSAME(kU32, kI8, uint32_t, int8_t);
531-
CASE_SECSAME(kU16, kI64, uint16_t, int64_t);
532-
CASE_SECSAME(kU16, kI32, uint16_t, int32_t);
533-
CASE_SECSAME(kU16, kI16, uint16_t, int16_t);
534-
CASE_SECSAME(kU16, kI8, uint16_t, int8_t);
535-
CASE_SECSAME(kU8, kI64, uint8_t, int64_t);
536-
CASE_SECSAME(kU8, kI32, uint8_t, int32_t);
537-
CASE_SECSAME(kU8, kI16, uint8_t, int16_t);
538-
CASE_SECSAME(kU8, kI8, uint8_t, int8_t);
539-
// Complex matrices with wide overhead.
540-
CASE_SECSAME(kU64, kC64, uint64_t, complex64);
541-
CASE_SECSAME(kU64, kC32, uint64_t, complex32);
542-
543-
// Unsupported case (add above if needed).
544-
MLIR_SPARSETENSOR_FATAL(
545-
"unsupported combination of types: <P=%d, C=%d, V=%d>\n",
546-
static_cast<int>(posTp), static_cast<int>(crdTp),
547-
static_cast<int>(valTp));
548-
#undef CASE_SECSAME
549-
#undef CASE
550-
}
551-
552451
void _mlir_ciface_outSparseTensorWriterMetaData(
553452
void *p, index_type dimRank, index_type nse,
554453
StridedMemRefType<index_type, 1> *dimSizesRef) {
@@ -635,34 +534,10 @@ char *getTensorFilename(index_type id) {
635534
return env;
636535
}
637536

638-
void readSparseTensorShape(char *filename, std::vector<uint64_t> *out) {
639-
assert(out && "Received nullptr for out-parameter");
640-
SparseTensorReader reader(filename);
641-
reader.openFile();
642-
reader.readHeader();
643-
reader.closeFile();
644-
const uint64_t dimRank = reader.getRank();
645-
const uint64_t *dimSizes = reader.getDimSizes();
646-
out->reserve(dimRank);
647-
out->assign(dimSizes, dimSizes + dimRank);
648-
}
649-
650-
index_type getSparseTensorReaderRank(void *p) {
651-
return static_cast<SparseTensorReader *>(p)->getRank();
652-
}
653-
654-
bool getSparseTensorReaderIsSymmetric(void *p) {
655-
return static_cast<SparseTensorReader *>(p)->isSymmetric();
656-
}
657-
658537
index_type getSparseTensorReaderNSE(void *p) {
659538
return static_cast<SparseTensorReader *>(p)->getNSE();
660539
}
661540

662-
index_type getSparseTensorReaderDimSize(void *p, index_type d) {
663-
return static_cast<SparseTensorReader *>(p)->getDimSize(d);
664-
}
665-
666541
void delSparseTensorReader(void *p) {
667542
delete static_cast<SparseTensorReader *>(p);
668543
}

mlir/test/Dialect/SparseTensor/conversion.mlir

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,11 @@ func.func @sparse_dim3d_const(%arg0: tensor<10x20x30xf64, #SparseTensor>) -> ind
7878
// CHECK-DAG: %[[DimShape0:.*]] = memref.alloca() : memref<1xindex>
7979
// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<1xindex> to memref<?xindex>
8080
// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
81-
// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex>
82-
// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex>
8381
// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi8>
8482
// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi8> to memref<?xi8>
85-
// CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}})
83+
// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex>
84+
// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex>
85+
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimShape]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
8686
// CHECK: call @delSparseTensorReader(%[[Reader]])
8787
// CHECK: return %[[T]] : !llvm.ptr<i8>
8888
func.func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
@@ -96,11 +96,11 @@ func.func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector>
9696
// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<2xindex> to memref<?xindex>
9797
// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
9898
// CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
99-
// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
100-
// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
10199
// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
102100
// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
103-
// CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}})
101+
// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
102+
// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
103+
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
104104
// CHECK: call @delSparseTensorReader(%[[Reader]])
105105
// CHECK: return %[[T]] : !llvm.ptr<i8>
106106
func.func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
@@ -114,15 +114,15 @@ func.func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
114114
// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<3xindex> to memref<?xindex>
115115
// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
116116
// CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
117-
// CHECK: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
118-
// CHECK: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
119-
// CHECK: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
120-
// CHECK: %[[Lvl2Dim:.*]] = memref.cast %[[Lvl2Dim0]] : memref<3xindex> to memref<?xindex>
121-
// CHECK: %[[LvlSizes0:.*]] = memref.alloca() : memref<3xindex>
122-
// CHECK: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<3xindex> to memref<?xindex>
123-
// CHECK: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
124-
// CHECK: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
125-
// CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[LvlSizes]], %[[LvlTypes]], %[[Dim2Lvl]], %[[Lvl2Dim]], %{{.*}}, %{{.*}}, %{{.*}})
117+
// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
118+
// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
119+
// CHECK-DAG: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
120+
// CHECK-DAG: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
121+
// CHECK-DAG: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
122+
// CHECK-DAG: %[[Lvl2Dim:.*]] = memref.cast %[[Lvl2Dim0]] : memref<3xindex> to memref<?xindex>
123+
// CHECK-DAG: %[[LvlSizes0:.*]] = memref.alloca() : memref<3xindex>
124+
// CHECK-DAG: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<3xindex> to memref<?xindex>
125+
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[LvlSizes]], %[[LvlTypes]], %[[Dim2Lvl]], %[[Lvl2Dim]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
126126
// CHECK: call @delSparseTensorReader(%[[Reader]])
127127
// CHECK: return %[[T]] : !llvm.ptr<i8>
128128
func.func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {

0 commit comments

Comments (0)