@@ -31,6 +31,11 @@ using namespace mlir::sparse_tensor;
3131
3232namespace {
3333
34+ static constexpr uint64_t DimSizesIdx = 0 ;
35+ static constexpr uint64_t DimCursorIdx = 1 ;
36+ static constexpr uint64_t MemSizesIdx = 2 ;
37+ static constexpr uint64_t FieldsIdx = 3 ;
38+
3439// ===----------------------------------------------------------------------===//
3540// Helper methods.
3641// ===----------------------------------------------------------------------===//
@@ -90,11 +95,17 @@ static Optional<Value> sizeFromTensorAtDim(OpBuilder &rewriter, Location loc,
9095 .getResult ();
9196}
9297
98+ // / Translates field index to memSizes index.
99+ static unsigned getMemSizesIndex (unsigned field) {
100+ assert (FieldsIdx <= field);
101+ return field - FieldsIdx;
102+ }
103+
93104// / Returns field index of sparse tensor type for pointers/indices, when set.
94105static unsigned getFieldIndex (Type type, unsigned ptrDim, unsigned idxDim) {
95106 assert (getSparseTensorEncoding (type));
96107 RankedTensorType rType = type.cast <RankedTensorType>();
97- unsigned field = 2 ; // start past sizes
108+ unsigned field = FieldsIdx ; // start past header
98109 unsigned ptr = 0 ;
99110 unsigned idx = 0 ;
100111 for (unsigned r = 0 , rank = rType.getShape ().size (); r < rank; r++) {
@@ -140,6 +151,7 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
140151 //
141152 // struct {
142153 // memref<rank x index> dimSizes ; size in each dimension
154+ // memref<rank x index> dimCursor ; cursor in each dimension
143155 // memref<n x index> memSizes ; sizes of ptrs/inds/values
144156 // ; per-dimension d:
145157 // ; if dense:
@@ -153,11 +165,11 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
153165 // };
154166 //
155167 unsigned rank = rType.getShape ().size ();
156- // The dimSizes array.
157- fields.push_back (MemRefType::get ({rank}, indexType));
158- // The memSizes array.
159168 unsigned lastField = getFieldIndex (type, -1u , -1u );
160- fields.push_back (MemRefType::get ({lastField - 2 }, indexType));
169+ // The dimSizes array, dimCursor array, and memSizes array.
170+ fields.push_back (MemRefType::get ({rank}, indexType));
171+ fields.push_back (MemRefType::get ({rank}, indexType));
172+ fields.push_back (MemRefType::get ({getMemSizesIndex (lastField)}, indexType));
161173 // Per-dimension storage.
162174 for (unsigned r = 0 ; r < rank; r++) {
163175 // Dimension level types apply in order to the reordered dimension.
@@ -179,7 +191,7 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
179191 return success ();
180192}
181193
182- // / Create allocation operation.
194+ // / Creates allocation operation.
183195static Value createAllocation (OpBuilder &builder, Location loc, Type type,
184196 Value sz) {
185197 auto memType = MemRefType::get ({ShapedType::kDynamicSize }, type);
@@ -220,14 +232,16 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
220232 else
221233 sizes.push_back (constantIndex (builder, loc, shape[r]));
222234 }
223- // The dimSizes array.
235+ // The dimSizes array, dimCursor array, and memSizes array.
236+ unsigned lastField = getFieldIndex (type, -1u , -1u );
224237 Value dimSizes =
225238 builder.create <memref::AllocOp>(loc, MemRefType::get ({rank}, indexType));
226- fields.push_back (dimSizes);
227- // The sizes array.
228- unsigned lastField = getFieldIndex (type, -1u , -1u );
239+ Value dimCursor =
240+ builder.create <memref::AllocOp>(loc, MemRefType::get ({rank}, indexType));
229241 Value memSizes = builder.create <memref::AllocOp>(
230- loc, MemRefType::get ({lastField - 2 }, indexType));
242+ loc, MemRefType::get ({getMemSizesIndex (lastField)}, indexType));
243+ fields.push_back (dimSizes);
244+ fields.push_back (dimCursor);
231245 fields.push_back (memSizes);
232246 // Per-dimension storage.
233247 for (unsigned r = 0 ; r < rank; r++) {
@@ -277,23 +291,17 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count,
277291 return forOp;
278292}
279293
280- // / Translates field index to memSizes index.
281- static unsigned getMemSizesIndex (unsigned field) {
282- assert (2 <= field);
283- return field - 2 ;
284- }
285-
286294// / Creates a pushback op for given field and updates the fields array
287295// / accordingly.
288296static void createPushback (OpBuilder &builder, Location loc,
289297 SmallVectorImpl<Value> &fields, unsigned field,
290298 Value value) {
291- assert (2 <= field && field < fields.size ());
299+ assert (FieldsIdx <= field && field < fields.size ());
292300 Type etp = fields[field].getType ().cast <ShapedType>().getElementType ();
293301 if (value.getType () != etp)
294302 value = builder.create <arith::IndexCastOp>(loc, etp, value);
295303 fields[field] = builder.create <PushBackOp>(
296- loc, fields[field].getType (), fields[1 ], fields[field], value,
304+ loc, fields[field].getType (), fields[MemSizesIdx ], fields[field], value,
297305 APInt (64 , getMemSizesIndex (field)));
298306}
299307
@@ -312,8 +320,8 @@ static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
312320 return ; // TODO: add codegen
313321 // push_back memSizes indices-0 index
314322 // push_back memSizes values value
315- createPushback (builder, loc, fields, 3 , indices[0 ]);
316- createPushback (builder, loc, fields, 4 , value);
323+ createPushback (builder, loc, fields, FieldsIdx + 1 , indices[0 ]);
324+ createPushback (builder, loc, fields, FieldsIdx + 2 , value);
317325}
318326
319327// / Generations insertion finalization code.
@@ -329,9 +337,9 @@ static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
329337 // push_back memSizes pointers-0 memSizes[2]
330338 Value zero = constantIndex (builder, loc, 0 );
331339 Value two = constantIndex (builder, loc, 2 );
332- Value size = builder.create <memref::LoadOp>(loc, fields[1 ], two);
333- createPushback (builder, loc, fields, 2 , zero);
334- createPushback (builder, loc, fields, 2 , size);
340+ Value size = builder.create <memref::LoadOp>(loc, fields[MemSizesIdx ], two);
341+ createPushback (builder, loc, fields, FieldsIdx , zero);
342+ createPushback (builder, loc, fields, FieldsIdx , size);
335343}
336344
337345// ===----------------------------------------------------------------------===//
@@ -759,7 +767,7 @@ class SparseNumberOfEntriesConverter
759767 unsigned lastField = fields.size () - 1 ;
760768 Value field =
761769 constantIndex (rewriter, op.getLoc (), getMemSizesIndex (lastField));
762- rewriter.replaceOpWithNewOp <memref::LoadOp>(op, fields[1 ], field);
770+ rewriter.replaceOpWithNewOp <memref::LoadOp>(op, fields[MemSizesIdx ], field);
763771 return success ();
764772 }
765773};
0 commit comments