Skip to content

Commit 51be8a3

Browse files
author
Peiming Liu
committed
[mlir][sparse] introduce sparse_tensor.iterate operation
1 parent 8aa061f commit 51be8a3

File tree

7 files changed

+519
-1
lines changed

7 files changed

+519
-1
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

+38
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,13 @@
1717
#include "mlir/IR/OpDefinition.h"
1818
#include "mlir/IR/OpImplementation.h"
1919
#include "mlir/IR/TensorEncoding.h"
20+
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2021
#include "mlir/Interfaces/InferTypeOpInterface.h"
22+
#include "mlir/Interfaces/LoopLikeInterface.h"
2123
#include "mlir/Interfaces/SideEffectInterfaces.h"
2224

25+
#include "llvm/ADT/bit.h"
26+
2327
//===----------------------------------------------------------------------===//
2428
//
2529
// Type aliases to help code be more self-documenting. Unfortunately
@@ -41,6 +45,40 @@ using Level = uint64_t;
4145
/// including the value `ShapedType::kDynamic` (for shapes).
4246
using Size = int64_t;
4347

48+
/// A simple wrapper to encode a bitset of defined (at most 64) levels.
49+
class LevelSet {
50+
uint64_t bits = 0;
51+
52+
public:
53+
LevelSet() = default;
54+
explicit LevelSet(uint64_t bits) : bits(bits) {}
55+
operator uint64_t() const { return bits; }
56+
57+
LevelSet &set(unsigned i) {
58+
assert(i < 64);
59+
bits |= 1 << i;
60+
return *this;
61+
}
62+
63+
LevelSet &operator|=(LevelSet lhs) {
64+
bits |= static_cast<uint64_t>(lhs);
65+
return *this;
66+
}
67+
68+
LevelSet &lshift(unsigned offset) {
69+
bits = bits << offset;
70+
return *this;
71+
}
72+
73+
bool operator[](unsigned i) const {
74+
assert(i < 64);
75+
return (bits & (1 << i)) != 0;
76+
}
77+
78+
unsigned count() const { return llvm::popcount(bits); }
79+
bool empty() const { return bits == 0; }
80+
};
81+
4482
} // namespace sparse_tensor
4583
} // namespace mlir
4684

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

+15
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,21 @@ class SparseTensor_Attr<string name,
1919
list<Trait> traits = []>
2020
: AttrDef<SparseTensor_Dialect, name, traits>;
2121

22+
//===----------------------------------------------------------------------===//
23+
// A simple bitset attribute wrapped over a single int64_t to encode a set of
24+
// sparse tensor levels.
25+
//===----------------------------------------------------------------------===//
26+
27+
def LevelSetAttr :
28+
TypedAttrBase<
29+
I64, "IntegerAttr",
30+
And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
31+
CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
32+
"LevelSet attribute"> {
33+
let returnType = [{::mlir::sparse_tensor::LevelSet}];
34+
let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
35+
}
36+
2237
//===----------------------------------------------------------------------===//
2338
// These attributes are just like `IndexAttr` except that they clarify whether
2439
// the index refers to a dimension (an axis of the semantic tensor) or a level

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

+100-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
1515
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
1616
include "mlir/Interfaces/InferTypeOpInterface.td"
1717
include "mlir/Interfaces/SideEffectInterfaces.td"
18+
include "mlir/Interfaces/ControlFlowInterfaces.td"
19+
include "mlir/Interfaces/LoopLikeInterface.td"
1820

1921
//===----------------------------------------------------------------------===//
2022
// Base class.
@@ -1304,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
13041306

13051307
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
13061308
ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
1307-
"ForeachOp"]>]> {
1309+
"ForeachOp", "IterateOp"]>]> {
13081310
let summary = "Yield from sparse_tensor set-like operations";
13091311
let description = [{
13101312
Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1513,6 +1515,103 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
15131515
let hasVerifier = 1;
15141516
}
15151517

1518+
def IterateOp : SparseTensor_Op<"iterate",
1519+
[RecursiveMemoryEffects, RecursivelySpeculatable,
1520+
DeclareOpInterfaceMethods<LoopLikeOpInterface,
1521+
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
1522+
"getYieldedValuesMutable"]>,
1523+
DeclareOpInterfaceMethods<RegionBranchOpInterface,
1524+
["getEntrySuccessorOperands"]>,
1525+
SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
1526+
1527+
let summary = "Iterate over a sparse iteration space";
1528+
let description = [{
1529+
The `sparse_tensor.iterate` operations represents a loop over the
1530+
provided iteration space extracted from a specific sparse tensor.
1531+
The operation defines an SSA value for a sparse iterator that points
1532+
to the current stored element in the sparse tensor and SSA values
1533+
for coordinates of the stored element. The coordinates are always
1534+
converted to `index` type despite of the underlying sparse tensor
1535+
storage. When coordinates are not used, the SSA values can be skipped
1536+
by `_` symbols, which usually leads to simpler generated code after
1537+
sparsification. For example:
1538+
1539+
```mlir
1540+
// The coordinate for level 0 is not used when iterating over a 2-D
1541+
// iteration space.
1542+
%sparse_tensor.iterate %iterator in %space at(_, %crd_1)
1543+
: !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
1544+
```
1545+
1546+
`sparse_tensor.iterate` can also operate on loop-carried variables
1547+
and returns the final values after loop termination.
1548+
The initial values of the variables are passed as additional SSA operands
1549+
to the iterator SSA value and used coordinate SSA values mentioned
1550+
above. The operation region has an argument for the iterator, variadic
1551+
arguments for specified (used) coordiates and followed by one argument
1552+
for each loop-carried variable, representing the value of the variable
1553+
at the current iteration.
1554+
The body region must contain exactly one block that terminates with
1555+
`sparse_tensor.yield`.
1556+
1557+
`sparse_tensor.iterate` results hold the final values after the last
1558+
iteration. If the `sparse_tensor.iterate` defines any values, a yield
1559+
must be explicitly present.
1560+
The number and types of the `sparse_tensor.iterate` results must match
1561+
the initial values in the iter_args binding and the yield operands.
1562+
1563+
1564+
A nested `sparse_tensor.iterate` example that prints all the coordinates
1565+
stored in the sparse input:
1566+
1567+
```mlir
1568+
func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
1569+
// Iterates over the first level of %sp
1570+
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
1571+
%r1 = sparse_tensor.iterate %it1 in %l1 at (%crd0)
1572+
: !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
1573+
// Iterates over the second level of %sp
1574+
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
1575+
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
1576+
%r2 = sparse_tensor.iterate %it2 in %l2 at (crd1)
1577+
: !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
1578+
vector.print %crd0 : index
1579+
vector.print %crd1 : index
1580+
}
1581+
}
1582+
}
1583+
1584+
```
1585+
}];
1586+
1587+
let arguments = (ins AnySparseIterSpace:$iterSpace,
1588+
Variadic<AnyType>:$initArgs,
1589+
LevelSetAttr:$crdUsedLvls);
1590+
let results = (outs Variadic<AnyType>:$results);
1591+
let regions = (region SizedRegion<1>:$region);
1592+
1593+
let extraClassDeclaration = [{
1594+
unsigned getSpaceDim() {
1595+
return getIterSpace().getType().getSpaceDim();
1596+
}
1597+
BlockArgument getIterator() {
1598+
return getRegion().getArguments().front();
1599+
}
1600+
Block::BlockArgListType getCrds() {
1601+
// The first block argument is iterator, the remaining arguments are
1602+
// referenced coordinates.
1603+
return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
1604+
}
1605+
unsigned getNumRegionIterArgs() {
1606+
return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
1607+
}
1608+
}];
1609+
1610+
let hasVerifier = 1;
1611+
let hasRegionVerifier = 1;
1612+
let hasCustomAssemblyFormat = 1;
1613+
}
1614+
15161615
//===----------------------------------------------------------------------===//
15171616
// Sparse Tensor Debugging and Test-Only Operations.
15181617
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)