Skip to content

Commit 481bd5d

Browse files
author
Peiming Liu
authored
[mlir][sparse] introduce sparse_tensor.extract_iteration_space operation. (#88554)
A `sparse_tensor.extract_space %tensor at %iterator` extracts a *sparse* iteration space defined `%tensor`, the operation to traverse the iteration space will be introduced in following PRs.
1 parent b955653 commit 481bd5d

File tree

5 files changed

+374
-0
lines changed

5 files changed

+374
-0
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

+60
Original file line numberDiff line numberDiff line change
@@ -1430,6 +1430,66 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
14301430
let hasVerifier = 1;
14311431
}
14321432

1433+
//===----------------------------------------------------------------------===//
1434+
// Sparse Tensor Iteration Operations.
1435+
//===----------------------------------------------------------------------===//
1436+
1437+
def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
1438+
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
1439+
1440+
let arguments = (ins AnySparseTensor:$tensor,
1441+
Optional<AnySparseIterator>:$parentIter,
1442+
LevelAttr:$loLvl, LevelAttr:$hiLvl);
1443+
1444+
let results = (outs AnySparseIterSpace:$resultSpace);
1445+
1446+
let summary = "Extracts an iteration space from a sparse tensor between certain levels";
1447+
let description = [{
1448+
Extracts a `!sparse_tensor.iter_space` from a sparse tensor between
1449+
certain (consecutive) levels. For sparse levels, it is usually done by
1450+
loading a postion range from the underlying sparse tensor storage.
1451+
E.g., for a compressed level, the iteration space is extracted by
1452+
[pos[i], pos[i+1]) supposing the the parent iterator points at `i`.
1453+
1454+
`tensor`: the input sparse tensor that defines the iteration space.
1455+
`parentIter`: the iterator for the previous level, at which the iteration space
1456+
at the current levels will be extracted.
1457+
`loLvl`, `hiLvl`: the level range between [loLvl, hiLvl) in the input tensor that
1458+
the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
1459+
iteration space.
1460+
1461+
The type of returned the value is automatically inferred to
1462+
`!sparse_tensor.iter_space<#INPUT_ENCODING, lvls = $loLvl to $hiLvl>`.
1463+
The returned iteration space can then be iterated over by
1464+
`sparse_tensor.iterate` operations to visit every stored element
1465+
(usually nonzeros) in the input sparse tensor.
1466+
1467+
Example:
1468+
```mlir
1469+
// Extracts a 1-D iteration space from a COO tensor at level 1.
1470+
%space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
1471+
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
1472+
```
1473+
}];
1474+
1475+
1476+
let extraClassDeclaration = [{
1477+
std::pair<Level, Level> getLvlRange() {
1478+
return std::make_pair(getLoLvl(), getHiLvl());
1479+
}
1480+
unsigned getSpaceDim() {
1481+
return getHiLvl() - getLoLvl();
1482+
}
1483+
ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
1484+
return getResultSpace().getType().getLvlTypes();
1485+
}
1486+
}];
1487+
1488+
let hasVerifier = 1;
1489+
let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
1490+
" attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
1491+
}
1492+
14331493
//===----------------------------------------------------------------------===//
14341494
// Sparse Tensor Debugging and Test-Only Operations.
14351495
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td

+97
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,101 @@ def SparseTensorStorageSpecifier
7272
: Type<CPred<"::llvm::isa<::mlir::sparse_tensor::StorageSpecifierType>($_self)">, "metadata",
7373
"::mlir::sparse_tensor::StorageSpecifierType">;
7474

75+
//===----------------------------------------------------------------------===//
76+
// Sparse Tensor Iteration Types.
77+
//===----------------------------------------------------------------------===//
78+
79+
def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
80+
let mnemonic = "iter_space";
81+
82+
let description = [{
83+
A sparse iteration space that represents an abstract N-D (sparse) iteration space
84+
extracted from a sparse tensor, i.e., a set of (crd_0, crd_1, ..., crd_N) for
85+
every stored element (usually nonzeros) in a sparse tensor between the specified
86+
[$loLvl, $hiLvl) levels.
87+
88+
Examples:
89+
90+
```mlir
91+
// An iteration space extracted from a CSR tensor between levels [0, 2).
92+
!iter_space<#CSR, lvls = 0 to 2>
93+
```
94+
}];
95+
96+
let parameters = (ins
97+
SparseTensorEncodingAttr : $encoding,
98+
"Level" : $loLvl,
99+
"Level" : $hiLvl
100+
);
101+
102+
let extraClassDeclaration = [{
103+
/// The the dimension of the iteration space.
104+
unsigned getSpaceDim() const {
105+
return getHiLvl() - getLoLvl();
106+
}
107+
108+
/// Get the level types for the iteration space.
109+
ArrayRef<LevelType> getLvlTypes() const {
110+
return getEncoding().getLvlTypes().slice(getLoLvl(), getSpaceDim());
111+
}
112+
113+
/// Whether the iteration space is unique (i.e., no duplicated coordinate).
114+
bool isUnique() {
115+
return !getLvlTypes().back().isa<LevelPropNonDefault::Nonunique>();
116+
}
117+
118+
/// Get the corresponding iterator type.
119+
::mlir::sparse_tensor::IteratorType getIteratorType() const;
120+
}];
121+
122+
let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
123+
}
124+
125+
def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
126+
let mnemonic = "iterator";
127+
128+
let description = [{
129+
An iterator that points to the current element in the corresponding iteration space.
130+
131+
Examples:
132+
133+
```mlir
134+
// An iterator that iterates over a iteration space of type `!iter_space<#CSR, lvls = 0 to 2>`
135+
!iterator<#CSR, lvls = 0 to 2>
136+
```
137+
}];
138+
139+
let parameters = (ins
140+
SparseTensorEncodingAttr : $encoding,
141+
"Level" : $loLvl,
142+
"Level" : $hiLvl
143+
);
144+
145+
let extraClassDeclaration = [{
146+
/// Get the corresponding iteration space type.
147+
::mlir::sparse_tensor::IterSpaceType getIterSpaceType() const;
148+
149+
unsigned getSpaceDim() const { return getIterSpaceType().getSpaceDim(); }
150+
ArrayRef<LevelType> getLvlTypes() const { return getIterSpaceType().getLvlTypes(); }
151+
bool isUnique() { return getIterSpaceType().isUnique(); }
152+
}];
153+
154+
let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
155+
}
156+
157+
def IsSparseSparseIterSpaceTypePred
158+
: CPred<"::llvm::isa<::mlir::sparse_tensor::IterSpaceType>($_self)">;
159+
160+
def IsSparseSparseIteratorTypePred
161+
: CPred<"::llvm::isa<::mlir::sparse_tensor::IteratorType>($_self)">;
162+
163+
def AnySparseIterSpace
164+
: Type<IsSparseSparseIterSpaceTypePred, "sparse iteration space",
165+
"::mlir::sparse_tensor::IterSpaceType">;
166+
167+
def AnySparseIterator
168+
: Type<IsSparseSparseIteratorTypePred, "sparse iterator",
169+
"::mlir::sparse_tensor::IteratorType">;
170+
171+
75172
#endif // SPARSETENSOR_TYPES

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

+110
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@
3030
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
3131
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
3232

33+
// Forward declarations, following custom print/parsing methods are referenced
34+
// by the generated code for SparseTensorTypes.td.
35+
static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
36+
mlir::sparse_tensor::Level &,
37+
mlir::sparse_tensor::Level &);
38+
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
39+
mlir::sparse_tensor::Level);
40+
3341
#define GET_TYPEDEF_CLASSES
3442
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
3543

@@ -1953,6 +1961,108 @@ LogicalResult SortOp::verify() {
19531961
return success();
19541962
}
19551963

1964+
//===----------------------------------------------------------------------===//
1965+
// Sparse Tensor Iteration Operations.
1966+
//===----------------------------------------------------------------------===//
1967+
1968+
IterSpaceType IteratorType::getIterSpaceType() const {
1969+
return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
1970+
getHiLvl());
1971+
}
1972+
1973+
IteratorType IterSpaceType::getIteratorType() const {
1974+
return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
1975+
}
1976+
1977+
/// Parses a level range in the form "$lo `to` $hi"
1978+
/// or simply "$lo" if $hi - $lo = 1
1979+
static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
1980+
Level &lvlHi) {
1981+
if (parser.parseInteger(lvlLo))
1982+
return failure();
1983+
1984+
if (succeeded(parser.parseOptionalKeyword("to"))) {
1985+
if (parser.parseInteger(lvlHi))
1986+
return failure();
1987+
} else {
1988+
lvlHi = lvlLo + 1;
1989+
}
1990+
1991+
if (lvlHi <= lvlLo)
1992+
parser.emitError(parser.getNameLoc(),
1993+
"expect larger level upper bound than lower bound");
1994+
1995+
return success();
1996+
}
1997+
1998+
/// Parses a level range in the form "$lo `to` $hi"
1999+
/// or simply "$lo" if $hi - $lo = 1
2000+
static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
2001+
IntegerAttr &lvlHiAttr) {
2002+
Level lvlLo, lvlHi;
2003+
if (parseLevelRange(parser, lvlLo, lvlHi))
2004+
return failure();
2005+
2006+
lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
2007+
lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
2008+
return success();
2009+
}
2010+
2011+
/// Prints a level range in the form "$lo `to` $hi"
2012+
/// or simply "$lo" if $hi - $lo = 1
2013+
static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2014+
2015+
if (lo + 1 == hi)
2016+
p << lo;
2017+
else
2018+
p << lo << " to " << hi;
2019+
}
2020+
2021+
/// Prints a level range in the form "$lo `to` $hi"
2022+
/// or simply "$lo" if $hi - $lo = 1
2023+
static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
2024+
IntegerAttr lvlHi) {
2025+
unsigned lo = lvlLo.getValue().getZExtValue();
2026+
unsigned hi = lvlHi.getValue().getZExtValue();
2027+
printLevelRange(p, lo, hi);
2028+
}
2029+
2030+
LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2031+
MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
2032+
DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
2033+
SmallVectorImpl<mlir::Type> &ret) {
2034+
2035+
ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2036+
SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
2037+
ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
2038+
adaptor.getHiLvl()));
2039+
return success();
2040+
}
2041+
2042+
LogicalResult ExtractIterSpaceOp::verify() {
2043+
if (getLoLvl() >= getHiLvl())
2044+
return emitOpError("expected smaller level low than level high");
2045+
2046+
TypedValue<IteratorType> pIter = getParentIter();
2047+
if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2048+
return emitOpError(
2049+
"parent iterator should be specified iff level lower bound equals 0");
2050+
}
2051+
2052+
if (pIter) {
2053+
IterSpaceType spaceTp = getResultSpace().getType();
2054+
if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2055+
return emitOpError(
2056+
"mismatch in parent iterator encoding and iteration space encoding.");
2057+
2058+
if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2059+
return emitOpError("parent iterator should be used to extract an "
2060+
"iteration space from a consecutive level.");
2061+
}
2062+
2063+
return success();
2064+
}
2065+
19562066
/// Materialize a single constant operation from a given attribute value with
19572067
/// the desired resultant type.
19582068
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,

mlir/test/Dialect/SparseTensor/invalid.mlir

+82
Original file line numberDiff line numberDiff line change
@@ -1012,3 +1012,85 @@ func.func @sparse_print(%arg0: tensor<10x10xf64>) {
10121012
sparse_tensor.print %arg0 : tensor<10x10xf64>
10131013
return
10141014
}
1015+
1016+
// -----
1017+
1018+
#COO = #sparse_tensor.encoding<{
1019+
map = (i, j) -> (
1020+
i : compressed(nonunique),
1021+
j : singleton(soa)
1022+
)
1023+
}>
1024+
1025+
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 2>) {
1026+
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' expect larger level upper bound than lower bound}}
1027+
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 to 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 2>
1028+
return
1029+
}
1030+
1031+
// -----
1032+
1033+
#COO = #sparse_tensor.encoding<{
1034+
map = (i, j) -> (
1035+
i : compressed(nonunique),
1036+
j : singleton(soa)
1037+
)
1038+
}>
1039+
1040+
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
1041+
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}}
1042+
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
1043+
return
1044+
}
1045+
1046+
// -----
1047+
1048+
#COO = #sparse_tensor.encoding<{
1049+
map = (i, j) -> (
1050+
i : compressed(nonunique),
1051+
j : singleton(soa)
1052+
)
1053+
}>
1054+
1055+
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
1056+
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}}
1057+
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 1 : tensor<4x8xf32, #COO>
1058+
return
1059+
}
1060+
1061+
// -----
1062+
1063+
#COO = #sparse_tensor.encoding<{
1064+
map = (i, j) -> (
1065+
i : compressed(nonunique),
1066+
j : singleton(soa)
1067+
)
1068+
}>
1069+
1070+
#CSR = #sparse_tensor.encoding<{
1071+
map = (i, j) -> (
1072+
i : dense,
1073+
j : compressed
1074+
)
1075+
}>
1076+
1077+
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 0>) {
1078+
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op mismatch in parent iterator encoding and iteration space encoding.}}
1079+
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 0>
1080+
return
1081+
}
1082+
1083+
// -----
1084+
1085+
#COO = #sparse_tensor.encoding<{
1086+
map = (i, j) -> (
1087+
i : compressed(nonunique),
1088+
j : singleton(soa)
1089+
)
1090+
}>
1091+
1092+
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
1093+
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be used to extract an iteration space from a consecutive level.}}
1094+
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
1095+
return
1096+
}

0 commit comments

Comments
 (0)