diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index cc134e7d953ec..9e79b6aca1c9b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -342,7 +342,7 @@ struct LevelType {
   /// Check if the `LevelType` needs coordinates array.
   constexpr bool isWithCrdLT() const {
     // All sparse levels has coordinate array.
-    return !isa<LevelFormat::Dense>();
+    return !isa<LevelFormat::Dense, LevelFormat::Batch>();
   }
 
   std::string toMLIRString() const {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index ca98665256be5..5d1db2323f95f 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -374,6 +374,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     /// is non-null (since no fixed result is valid for every dense-tensor).
     ::mlir::sparse_tensor::Level getLvlRank() const;
 
+    uint64_t getBatchLvlRank() const;
+
     //
     // lvlTypes methods.
     //
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
index 27dc39609cdad..ce34ae43d1c18 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
@@ -30,15 +30,15 @@ namespace sparse_tensor {
 ///   ; if dense:
 ///
 ///   ; if compressed:
-///     memref<? x pos>  positions        ; positions for level l
-///     memref<? x crd>  coordinates      ; coordinates for level l
-///   ; if loose-compressed:
-///     memref<? x pos>  positions        ; lo/hi position pairs for level l
-///     memref<? x crd>  coordinates      ; coordinates for level l
+///     memref<[batch] x ? x pos>  positions   ; positions for level l
+///     memref<[batch] x ? x crd>  coordinates ; coordinates for level l
+///   ; if loose-compressed:
+///     memref<[batch] x ? x pos>  positions   ; lo/hi pos pairs for level l
+///     memref<[batch] x ? x crd>  coordinates ; coordinates for level l
 ///   ; if singleton/2-out-of-4:
-///     memref<? x crd>  coordinates      ; coordinates for level l
+///     memref<[batch] x ? x crd>  coordinates ; coordinates for level l
 ///
-///   memref<? x eltType> values          ; values
+///   memref<[batch] x ? x eltType> values     ; values
 ///
 /// struct sparse_tensor.storage_specifier {
 ///   array<rank x int> lvlSizes          ; sizes/cardinalities for each level
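To make the revised storage-layout comment concrete, here is a standalone sketch (plain C++, not the actual StorageLayout API; the enum and helper are hypothetical stand-ins) that enumerates the buffers implied by a level-type list. Consistent with the Enums.h change above, batch levels behave like dense levels and contribute no per-level buffers of their own; they only add leading `[batch]` dimensions to the buffers of the sparse levels.

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Simplified stand-ins for the sparse_tensor level formats.
enum class LevelKind { Batch, Dense, Compressed, LooseCompressed, Singleton, NOutOfM };

// Lists the buffers of the storage scheme in the order given by the layout
// comment above: per-level positions/coordinates as dictated by the level
// format, followed by a single values buffer at the end.
void listFields(const std::vector<LevelKind> &lvls) {
  for (std::size_t l = 0; l < lvls.size(); ++l) {
    switch (lvls[l]) {
    case LevelKind::Batch: // batch is dense-like: no buffer of its own
    case LevelKind::Dense:
      break;
    case LevelKind::Compressed:
    case LevelKind::LooseCompressed: // positions hold lo/hi pairs here
      std::cout << "positions[" << l << "], coordinates[" << l << "]\n";
      break;
    case LevelKind::Singleton:
    case LevelKind::NOutOfM:
      std::cout << "coordinates[" << l << "]\n";
      break;
    }
  }
  std::cout << "values\n";
}

int main() {
  // #BCSR from the tests below: (batch, batch, dense, compressed) yields
  // positions[3], coordinates[3], values -- each with two leading batch dims.
  listFields({LevelKind::Batch, LevelKind::Batch, LevelKind::Dense,
              LevelKind::Compressed});
}
```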
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 1a090ddb782fd..c93a4fcd922c2 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -253,6 +253,14 @@ class SparseTensorType {
                                         CrdTransDirectionKind::dim2lvl);
   }
 
+  /// Returns the batched level-shape.
+  SmallVector<Size> getBatchLvlShape() const {
+    auto lvlShape = getEncoding().tranlateShape(getDimShape(),
+                                                CrdTransDirectionKind::dim2lvl);
+    lvlShape.truncate(getEncoding().getBatchLvlRank());
+    return lvlShape;
+  }
+
   /// Returns the type with an identity mapping.
   RankedTensorType getDemappedType() const {
     return RankedTensorType::get(getLvlShape(), getElementType(),
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index fd0ed26fbde07..69c3413f35ea9 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -126,13 +126,16 @@ void sparse_tensor::foreachFieldAndTypeInSparseTensor(
   const Type posType = stt.getPosType();
   const Type eltType = stt.getElementType();
 
+  SmallVector<int64_t> memrefShape = stt.getBatchLvlShape();
+  memrefShape.push_back(ShapedType::kDynamic);
+
   const Type specType = StorageSpecifierType::get(stt.getEncoding());
-  // memref<? x pos> positions
-  const Type posMemType = MemRefType::get({ShapedType::kDynamic}, posType);
-  // memref<? x crd> coordinates
-  const Type crdMemType = MemRefType::get({ShapedType::kDynamic}, crdType);
-  // memref<? x eltType> values
-  const Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
+  // memref<[batch] x ? x pos> positions
+  const Type posMemType = MemRefType::get(memrefShape, posType);
+  // memref<[batch] x ? x crd> coordinates
+  const Type crdMemType = MemRefType::get(memrefShape, crdType);
+  // memref<[batch] x ? x eltType> values
+  const Type valMemType = MemRefType::get(memrefShape, eltType);
 
   StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
                                    callback](FieldIndex fieldIdx,
@@ -336,6 +339,12 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
   return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
 }
 
+uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
+  ArrayRef<LevelType> lvlTypes = getLvlTypes();
+  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
+  return std::distance(lastBatch, lvlTypes.rend());
+}
+
 bool SparseTensorEncodingAttr::isAllDense() const {
   return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 0ccb11f3a6b85..d5eec4ae67e79 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1293,7 +1293,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
       Value tensor = fKind == SparseTensorFieldKind::ValMemRef
                          ? op.getValues()
                          : op.getLevels()[fIdx];
-
+      // TODO: handle batch.
      TypedValue<MemRefType> mem = genToMemref(rewriter, loc, tensor);
       if (mem.getType().getRank() > 1) {
         // Flattens the buffer to rank 1.
@@ -1322,9 +1322,8 @@
     for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
       assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
 
-      // FIXME: dim/lvl confusion!
       // Sets up the level size.
-      auto lvlSize = constantIndex(rewriter, loc, stt.getDimShape()[lvl]);
+      auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
       desc.setLvlSize(rewriter, loc, lvl, lvlSize);
       // We use a single AOS array to store the trailing COO, so there is only
       // one memory size to set for the entire COO section.
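The reverse-iterator idiom in `getBatchLvlRank()` above is easy to misread. Batch levels may only form a prefix of the level sequence, so the distance from the *last* batch level (searching from the back) to `rend()` is exactly the number of leading batch levels, and `find_if` returning `rend()` naturally yields 0 when there is no batch level. A minimal standalone sketch of the same idiom, with a plain enum standing in for `LevelType` and a lambda standing in for `isBatchLT`:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <vector>

enum class LevelKind { Batch, Dense, Compressed };

// Mirrors SparseTensorEncodingAttr::getBatchLvlRank(): since batch levels
// may only appear as the leading levels, the distance from the last batch
// level to rend() is exactly the size of the batch prefix.
uint64_t batchLvlRank(const std::vector<LevelKind> &lvlTypes) {
  auto lastBatch =
      std::find_if(lvlTypes.rbegin(), lvlTypes.rend(),
                   [](LevelKind k) { return k == LevelKind::Batch; });
  // No batch level at all: find_if returns rend(), so the distance is 0.
  return std::distance(lastBatch, lvlTypes.rend());
}

int main() {
  using LK = LevelKind;
  assert(batchLvlRank({LK::Batch, LK::Batch, LK::Dense, LK::Compressed}) == 2);
  assert(batchLvlRank({LK::Dense, LK::Compressed}) == 0);
  return 0;
}
```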
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index c1a976c84fecc..b63762485c961 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -34,6 +34,10 @@
   map = (d0, d1) -> (d1 : dense, d0 : compressed)
 }>
 
+#BCSR = #sparse_tensor.encoding<{
+  map = (d0, d1, d2, d3) -> (d0: batch, d1: batch, d2 : dense, d3 : compressed)
+}>
+
 #DCSR = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : compressed, d1 : compressed),
   crdWidth = 64,
@@ -182,6 +186,36 @@ func.func @sparse_csr(%arg0: tensor<?x?xf64, #CSR>) {
   return
 }
 
+// CHECK-LABEL: func @sparse_bcsr_0(
+// CHECK-SAME: %[[A1:.*0]]: memref<?x?x?xindex>,
+// CHECK-SAME: %[[A2:.*1]]: memref<?x?x?xindex>,
+// CHECK-SAME: %[[A3:.*]]: memref<?x?x?xf64>,
+// CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier
+// CHECK: return
+func.func @sparse_bcsr_0(%arg0: tensor<?x?x?x?xf64, #BCSR>) {
+  return
+}
+
+// CHECK-LABEL: func @sparse_bcsr_1(
+// CHECK-SAME: %[[A1:.*0]]: memref<?x?x?xindex>,
+// CHECK-SAME: %[[A2:.*1]]: memref<?x?x?xindex>,
+// CHECK-SAME: %[[A3:.*]]: memref<?x?x?xf64>,
+// CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier
+// CHECK: return
+func.func @sparse_bcsr_1(%arg0: tensor<?x?x4x2xf64, #BCSR>) {
+  return
+}
+
+// CHECK-LABEL: func @sparse_bcsr_2(
+// CHECK-SAME: %[[A1:.*0]]: memref<18x6x?xindex>,
+// CHECK-SAME: %[[A2:.*1]]: memref<18x6x?xindex>,
+// CHECK-SAME: %[[A3:.*]]: memref<18x6x?xf64>,
+// CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier
+// CHECK: return
+func.func @sparse_bcsr_2(%arg0: tensor<18x6x4x2xf64, #BCSR>) {
+  return
+}
+
 // CHECK-LABEL: func @sparse_dcsr(
 // CHECK-SAME: %[[A0:.*0]]: memref<?xi32>,
 // CHECK-SAME: %[[A1:.*1]]: memref<?xi64>,
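The expected types in `sparse_bcsr_2` follow mechanically from the code above: `getBatchLvlShape()` translates the dim-shape `18x6x4x2` through the (here identity) dim-to-lvl map, truncates it to the batch rank 2, and `foreachFieldAndTypeInSparseTensor()` appends one dynamic dimension, giving `memref<18x6x?x...>` for positions, coordinates, and values. A standalone sketch of that computation, assuming a permutation-only dim-to-lvl map (real maps can be affine) and with illustrative helper names:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = INT64_MIN; // stand-in for ShapedType::kDynamic

// Mirrors SparseTensorType::getBatchLvlShape() followed by the
// push_back(kDynamic) in foreachFieldAndTypeInSparseTensor(): translate the
// dim-shape to the lvl-shape, keep only the leading batch levels, then
// append one dynamic dimension for the actual sparse payload.
std::vector<int64_t> bufferShape(const std::vector<int64_t> &dimShape,
                                 const std::vector<unsigned> &dimToLvl,
                                 uint64_t batchRank) {
  std::vector<int64_t> lvlShape;
  for (unsigned d : dimToLvl) // permutation-only translation
    lvlShape.push_back(dimShape[d]);
  lvlShape.resize(batchRank); // truncate to the batch prefix
  lvlShape.push_back(kDynamic);
  return lvlShape;
}

int main() {
  // #BCSR: (d0, d1, d2, d3) -> (d0: batch, d1: batch, d2: dense,
  // d3: compressed) is an identity permutation with batch rank 2.
  std::vector<int64_t> shape = bufferShape({18, 6, 4, 2}, {0, 1, 2, 3}, 2);
  // Matches the sparse_bcsr_2 CHECK lines: memref<18x6x?x...>.
  assert((shape == std::vector<int64_t>{18, 6, kDynamic}));
  return 0;
}
```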