Skip to content

Commit 44c02a1

Browse files
author
Peiming Liu
committed
[mlir][sparse] introduce sparse_tensor.iterate operation
1 parent 481bd5d commit 44c02a1

File tree

7 files changed

+519
-1
lines changed

7 files changed

+519
-1
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

+38
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,13 @@
1717
#include "mlir/IR/OpDefinition.h"
1818
#include "mlir/IR/OpImplementation.h"
1919
#include "mlir/IR/TensorEncoding.h"
20+
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2021
#include "mlir/Interfaces/InferTypeOpInterface.h"
22+
#include "mlir/Interfaces/LoopLikeInterface.h"
2123
#include "mlir/Interfaces/SideEffectInterfaces.h"
2224

25+
#include "llvm/ADT/bit.h"
26+
2327
//===----------------------------------------------------------------------===//
2428
//
2529
// Type aliases to help code be more self-documenting. Unfortunately
@@ -41,6 +45,40 @@ using Level = uint64_t;
4145
/// including the value `ShapedType::kDynamic` (for shapes).
4246
using Size = int64_t;
4347

48+
/// A simple wrapper to encode a bitset of defined (at most 64) levels.
49+
class LevelSet {
50+
uint64_t bits = 0;
51+
52+
public:
53+
LevelSet() = default;
54+
explicit LevelSet(uint64_t bits) : bits(bits) {}
55+
operator uint64_t() const { return bits; }
56+
57+
LevelSet &set(unsigned i) {
58+
assert(i < 64);
59+
bits |= 1 << i;
60+
return *this;
61+
}
62+
63+
LevelSet &operator|=(LevelSet lhs) {
64+
bits |= static_cast<uint64_t>(lhs);
65+
return *this;
66+
}
67+
68+
LevelSet &lshift(unsigned offset) {
69+
bits = bits << offset;
70+
return *this;
71+
}
72+
73+
bool operator[](unsigned i) const {
74+
assert(i < 64);
75+
return (bits & (1 << i)) != 0;
76+
}
77+
78+
unsigned count() const { return llvm::popcount(bits); }
79+
bool empty() const { return bits == 0; }
80+
};
81+
4482
} // namespace sparse_tensor
4583
} // namespace mlir
4684

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

+15
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,21 @@ class SparseTensor_Attr<string name,
1919
list<Trait> traits = []>
2020
: AttrDef<SparseTensor_Dialect, name, traits>;
2121

22+
//===----------------------------------------------------------------------===//
23+
// A simple bitset attribute wrapped over a single int64_t to encode a set of
24+
// sparse tensor levels.
25+
//===----------------------------------------------------------------------===//
26+
27+
def LevelSetAttr :
28+
TypedAttrBase<
29+
I64, "IntegerAttr",
30+
And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
31+
CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
32+
"LevelSet attribute"> {
33+
let returnType = [{::mlir::sparse_tensor::LevelSet}];
34+
let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
35+
}
36+
2237
//===----------------------------------------------------------------------===//
2338
// These attributes are just like `IndexAttr` except that they clarify whether
2439
// the index refers to a dimension (an axis of the semantic tensor) or a level

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

+100-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
1515
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
1616
include "mlir/Interfaces/InferTypeOpInterface.td"
1717
include "mlir/Interfaces/SideEffectInterfaces.td"
18+
include "mlir/Interfaces/ControlFlowInterfaces.td"
19+
include "mlir/Interfaces/LoopLikeInterface.td"
1820

1921
//===----------------------------------------------------------------------===//
2022
// Base class.
@@ -1277,7 +1279,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
12771279

12781280
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
12791281
ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
1280-
"ForeachOp"]>]>,
1282+
"ForeachOp", "IterateOp"]>]>,
12811283
Arguments<(ins Variadic<AnyType>:$results)> {
12821284
let summary = "Yield from sparse_tensor set-like operations";
12831285
let description = [{
@@ -1490,6 +1492,103 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
14901492
" attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
14911493
}
14921494

1495+
def IterateOp : SparseTensor_Op<"iterate",
1496+
[RecursiveMemoryEffects, RecursivelySpeculatable,
1497+
DeclareOpInterfaceMethods<LoopLikeOpInterface,
1498+
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
1499+
"getYieldedValuesMutable"]>,
1500+
DeclareOpInterfaceMethods<RegionBranchOpInterface,
1501+
["getEntrySuccessorOperands"]>,
1502+
SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
1503+
1504+
let arguments = (ins AnySparseIterSpace:$iterSpace,
1505+
Variadic<AnyType>:$initArgs,
1506+
LevelSetAttr:$crdUsedLvls);
1507+
let results = (outs Variadic<AnyType>:$results);
1508+
let regions = (region SizedRegion<1>:$region);
1509+
1510+
let summary = "Iterate over a sparse iteration space";
1511+
let description = [{
1512+
The `sparse_tensor.iterate` operations represents a loop over the
1513+
provided iteration space extracted from a specific sparse tensor.
1514+
The operation defines an SSA value for a sparse iterator that points
1515+
to the current stored element in the sparse tensor and SSA values
1516+
for coordinates of the stored element. The coordinates are always
1517+
converted to `index` type despite of the underlying sparse tensor
1518+
storage. When coordinates are not used, the SSA values can be skipped
1519+
by `_` symbols, which usually leads to simpler generated code after
1520+
sparsification. For example:
1521+
1522+
```mlir
1523+
// The coordinate for level 0 is not used when iterating over a 2-D
1524+
// iteration space.
1525+
%sparse_tensor.iterate %iterator in %space at(_, %crd_1)
1526+
: !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
1527+
```
1528+
1529+
`sparse_tensor.iterate` can also operate on loop-carried variables
1530+
and returns the final values after loop termination.
1531+
The initial values of the variables are passed as additional SSA operands
1532+
to the iterator SSA value and used coordinate SSA values mentioned
1533+
above. The operation region has an argument for the iterator, variadic
1534+
arguments for specified (used) coordiates and followed by one argument
1535+
for each loop-carried variable, representing the value of the variable
1536+
at the current iteration.
1537+
The body region must contain exactly one block that terminates with
1538+
`sparse_tensor.yield`.
1539+
1540+
`sparse_tensor.iterate` results hold the final values after the last
1541+
iteration. If the `sparse_tensor.iterate` defines any values, a yield
1542+
must be explicitly present.
1543+
The number and types of the `sparse_tensor.iterate` results must match
1544+
the initial values in the iter_args binding and the yield operands.
1545+
1546+
1547+
A nested `sparse_tensor.iterate` example that prints all the coordinates
1548+
stored in the sparse input:
1549+
1550+
```mlir
1551+
func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
1552+
// Iterates over the first level of %sp
1553+
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
1554+
%r1 = sparse_tensor.iterate %it1 in %l1 at (%crd0)
1555+
: !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
1556+
// Iterates over the second level of %sp
1557+
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
1558+
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
1559+
%r2 = sparse_tensor.iterate %it2 in %l2 at (crd1)
1560+
: !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
1561+
vector.print %crd0 : index
1562+
vector.print %crd1 : index
1563+
}
1564+
}
1565+
}
1566+
1567+
```
1568+
}];
1569+
1570+
let extraClassDeclaration = [{
1571+
unsigned getSpaceDim() {
1572+
return getIterSpace().getType().getSpaceDim();
1573+
}
1574+
BlockArgument getIterator() {
1575+
return getRegion().getArguments().front();
1576+
}
1577+
Block::BlockArgListType getCrds() {
1578+
// The first block argument is iterator, the remaining arguments are
1579+
// referenced coordinates.
1580+
return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
1581+
}
1582+
unsigned getNumRegionIterArgs() {
1583+
return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
1584+
}
1585+
}];
1586+
1587+
let hasVerifier = 1;
1588+
let hasRegionVerifier = 1;
1589+
let hasCustomAssemblyFormat = 1;
1590+
}
1591+
14931592
//===----------------------------------------------------------------------===//
14941593
// Sparse Tensor Debugging and Test-Only Operations.
14951594
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)