Skip to content

Commit 0d1f957

Browse files
authored
[mlir][sparse] support type conversion from batched sparse tensors to memrefs. (#83163)
1 parent a76c524 commit 0d1f957

File tree

7 files changed

+69
-17
lines changed

7 files changed

+69
-17
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ struct LevelType {
342342
/// Check if the `LevelType` needs coordinates array.
343343
constexpr bool isWithCrdLT() const {
344344
// All sparse levels have a coordinate array.
345-
return !isa<LevelFormat::Dense>();
345+
return !isa<LevelFormat::Dense, LevelFormat::Batch>();
346346
}
347347

348348
std::string toMLIRString() const {

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
374374
/// is non-null (since no fixed result is valid for every dense-tensor).
375375
::mlir::sparse_tensor::Level getLvlRank() const;
376376

377+
uint64_t getBatchLvlRank() const;
378+
377379
//
378380
// lvlTypes methods.
379381
//

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,15 @@ namespace sparse_tensor {
3030
/// ; if dense:
3131
/// <nothing>
3232
/// ; if compressed:
33-
/// memref<? x pos> positions ; positions for level l
34-
/// memref<? x crd> coordinates ; coordinates for level l
35-
/// ; if loose-compressed:
36-
/// memref<? x pos> positions ; lo/hi position pairs for level l
37-
/// memref<? x crd> coordinates ; coordinates for level l
33+
/// memref<[batch] x ? x pos> positions ; positions for level l
34+
/// memref<[batch] x ? x crd> coordinates ; coordinates for level l
35+
/// ; if loose-compressed:
36+
/// memref<[batch] x ? x pos> positions ; lo/hi pos pairs for level l
37+
/// memref<[batch] x ? x crd> coordinates ; coordinates for level l
3838
/// ; if singleton/2-out-of-4:
39-
/// memref<? x crd> coordinates ; coordinates for level l
39+
/// memref<[batch] x ? x crd> coordinates ; coordinates for level l
4040
///
41-
/// memref<? x eltType> values ; values
41+
/// memref<[batch] x ? x eltType> values ; values
4242
///
4343
/// struct sparse_tensor.storage_specifier {
4444
/// array<rank x int> lvlSizes ; sizes/cardinalities for each level

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,14 @@ class SparseTensorType {
253253
CrdTransDirectionKind::dim2lvl);
254254
}
255255

256+
/// Returns the batched Level-shape (the leading batch levels of the level shape).
257+
SmallVector<Size> getBatchLvlShape() const {
258+
auto lvlShape = getEncoding().tranlateShape(getDimShape(),
259+
CrdTransDirectionKind::dim2lvl);
260+
lvlShape.truncate(getEncoding().getBatchLvlRank());
261+
return lvlShape;
262+
}
263+
256264
/// Returns the type with an identity mapping.
257265
RankedTensorType getDemappedType() const {
258266
return RankedTensorType::get(getLvlShape(), getElementType(),

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,16 @@ void sparse_tensor::foreachFieldAndTypeInSparseTensor(
126126
const Type posType = stt.getPosType();
127127
const Type eltType = stt.getElementType();
128128

129+
SmallVector<int64_t> memrefShape = stt.getBatchLvlShape();
130+
memrefShape.push_back(ShapedType::kDynamic);
131+
129132
const Type specType = StorageSpecifierType::get(stt.getEncoding());
130-
// memref<? x pos> positions
131-
const Type posMemType = MemRefType::get({ShapedType::kDynamic}, posType);
132-
// memref<? x crd> coordinates
133-
const Type crdMemType = MemRefType::get({ShapedType::kDynamic}, crdType);
134-
// memref<? x eltType> values
135-
const Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
133+
// memref<[batch] x ? x pos> positions
134+
const Type posMemType = MemRefType::get(memrefShape, posType);
135+
// memref<[batch] x ? x crd> coordinates
136+
const Type crdMemType = MemRefType::get(memrefShape, crdType);
137+
// memref<[batch] x ? x eltType> values
138+
const Type valMemType = MemRefType::get(memrefShape, eltType);
136139

137140
StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
138141
callback](FieldIndex fieldIdx,
@@ -336,6 +339,12 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
336339
return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
337340
}
338341

342+
uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
343+
ArrayRef<LevelType> lvlTypes = getLvlTypes();
344+
auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
345+
return std::distance(lastBatch, lvlTypes.rend());
346+
}
347+
339348
bool SparseTensorEncodingAttr::isAllDense() const {
340349
return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
341350
}

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,7 +1293,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
12931293
Value tensor = fKind == SparseTensorFieldKind::ValMemRef
12941294
? op.getValues()
12951295
: op.getLevels()[fIdx];
1296-
1296+
// TODO: handle batch.
12971297
TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
12981298
if (mem.getType().getRank() > 1) {
12991299
// Flattens the buffer to rank 1.
@@ -1322,9 +1322,8 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
13221322
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
13231323
assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
13241324

1325-
// FIXME: dim/lvl confusion!
13261325
// Sets up the level size.
1327-
auto lvlSize = constantIndex(rewriter, loc, stt.getDimShape()[lvl]);
1326+
auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
13281327
desc.setLvlSize(rewriter, loc, lvl, lvlSize);
13291328
// We use a single AOS array to store the trailing COO, so there is only
13301329
// one memory size to set for the entire COO section.

mlir/test/Dialect/SparseTensor/codegen.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@
3434
map = (d0, d1) -> (d1 : dense, d0 : compressed)
3535
}>
3636

37+
#BCSR = #sparse_tensor.encoding<{
38+
map = (d0, d1, d2, d3) -> (d0: batch, d1: batch, d2 : dense, d3 : compressed)
39+
}>
40+
3741
#DCSR = #sparse_tensor.encoding<{
3842
map = (d0, d1) -> (d0 : compressed, d1 : compressed),
3943
crdWidth = 64,
@@ -182,6 +186,36 @@ func.func @sparse_csr(%arg0: tensor<?x?xf64, #CSR>) {
182186
return
183187
}
184188

189+
// CHECK-LABEL: func @sparse_bcsr_0(
190+
// CHECK-SAME: %[[A1:.*0]]: memref<?x2x?xindex>,
191+
// CHECK-SAME: %[[A2:.*1]]: memref<?x2x?xindex>,
192+
// CHECK-SAME: %[[A3:.*]]: memref<?x2x?xf64>,
193+
// CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier
194+
// CHECK: return
195+
func.func @sparse_bcsr_0(%arg0: tensor<?x2x?x?xf64, #BCSR>) {
196+
return
197+
}
198+
199+
// CHECK-LABEL: func @sparse_bcsr_1(
200+
// CHECK-SAME: %[[A1:.*0]]: memref<?x?x?xindex>,
201+
// CHECK-SAME: %[[A2:.*1]]: memref<?x?x?xindex>,
202+
// CHECK-SAME: %[[A3:.*]]: memref<?x?x?xf64>,
203+
// CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier
204+
// CHECK: return
205+
func.func @sparse_bcsr_1(%arg0: tensor<?x?x?x?xf64, #BCSR>) {
206+
return
207+
}
208+
209+
// CHECK-LABEL: func @sparse_bcsr_2(
210+
// CHECK-SAME: %[[A1:.*0]]: memref<18x6x?xindex>,
211+
// CHECK-SAME: %[[A2:.*1]]: memref<18x6x?xindex>,
212+
// CHECK-SAME: %[[A3:.*]]: memref<18x6x?xf64>,
213+
// CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier
214+
// CHECK: return
215+
func.func @sparse_bcsr_2(%arg0: tensor<18x6x4x2xf64, #BCSR>) {
216+
return
217+
}
218+
185219
// CHECK-LABEL: func @sparse_dcsr(
186220
// CHECK-SAME: %[[A0:.*0]]: memref<?xi32>,
187221
// CHECK-SAME: %[[A1:.*1]]: memref<?xi64>,

0 commit comments

Comments
 (0)