@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+ include "mlir/Interfaces/ControlFlowInterfaces.td"
+ include "mlir/Interfaces/LoopLikeInterface.td"

//===----------------------------------------------------------------------===//
// Base class.
@@ -1277,7 +1279,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu

def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
    ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
-                "ForeachOp"]>]>,
+                "ForeachOp", "IterateOp"]>]>,
    Arguments<(ins Variadic<AnyType>:$results)> {
  let summary = "Yield from sparse_tensor set-like operations";
  let description = [{
@@ -1430,6 +1432,154 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
  let hasVerifier = 1;
}

+ //===----------------------------------------------------------------------===//
+ // Sparse Tensor Iteration Operations.
+ //===----------------------------------------------------------------------===//
+
+ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
+     [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+
+   let arguments = (ins AnySparseTensor:$tensor,
+                        Optional<AnySparseIterator>:$parentIter,
+                        LevelAttr:$loLvl, LevelAttr:$hiLvl);
+
+   let results = (outs AnySparseIterSpace:$resultSpace);
+
+   let summary = "Extract an iteration space from a sparse tensor between certain levels";
+   let description = [{
+     Extracts a `!sparse_tensor.iter_space` from a sparse tensor between
+     certain (consecutive) levels.
+
+     `tensor`: the input sparse tensor that defines the iteration space.
+     `parentIter`: the iterator at the previous level, whose position determines
+     where the iteration space for the current levels is extracted.
+     `loLvl`, `hiLvl`: the level range [loLvl, hiLvl) in the input tensor that
+     the returned iteration space covers. `hiLvl - loLvl` defines the dimension
+     of the iteration space.
+
+     Example:
+     ```mlir
+     // Extracts a 1-D iteration space from a COO tensor at level 1.
+     %space = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+       : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+     ```
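+
+     For illustration, when no parent iterator is given the space is extracted
+     starting at the root of the tensor. A sketch that follows the same
+     assembly format (the `%root` name is only an example):
+     ```mlir
+     // Extracts the 1-D iteration space at level 0 of a COO tensor.
+     %root = sparse_tensor.extract_iteration_space %sp lvls = 0
+       : tensor<4x8xf32, #COO>
+     ```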
+   }];
+
+   let extraClassDeclaration = [{
+     std::pair<Level, Level> getLvlRange() {
+       return std::make_pair(getLoLvl(), getHiLvl());
+     }
+     unsigned getSpaceDim() {
+       return getHiLvl() - getLoLvl();
+     }
+     ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
+       return getResultSpace().getType().getLvlTypes();
+     }
+   }];
+
+   let hasVerifier = 1;
+   let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
+                        " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
+ }
+
+ def IterateOp : SparseTensor_Op<"iterate",
+     [RecursiveMemoryEffects, RecursivelySpeculatable,
+      DeclareOpInterfaceMethods<LoopLikeOpInterface,
+         ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
+          "getYieldedValuesMutable"]>,
+      DeclareOpInterfaceMethods<RegionBranchOpInterface,
+         ["getEntrySuccessorOperands"]>,
+      SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
+
+   let arguments = (ins AnySparseIterSpace:$iterSpace,
+                        Variadic<AnyType>:$initArgs,
+                        LevelSetAttr:$crdUsedLvls);
+   let results = (outs Variadic<AnyType>:$results);
+   let regions = (region SizedRegion<1>:$region);
+
+   let summary = "Iterate over a sparse iteration space";
+   let description = [{
+     The `sparse_tensor.iterate` operation represents a loop over the
+     provided iteration space extracted from a specific sparse tensor.
+     The operation defines an SSA value for a sparse iterator that points
+     to the current stored element in the sparse tensor, and SSA values
+     for the coordinates of that element. The coordinates are always
+     converted to `index` type, regardless of the underlying sparse tensor
+     storage. Coordinates that are not used can be skipped with the `_`
+     symbol, which usually leads to simpler generated code after
+     sparsification. For example:
+
+     ```mlir
+     // The coordinate for level 0 is not used when iterating over a 2-D
+     // iteration space.
+     sparse_tensor.iterate %iterator in %space at(_, %crd_1)
+       : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
+     ```
+
+     `sparse_tensor.iterate` can also operate on loop-carried variables
+     and returns their final values after loop termination.
+     The initial values of the variables are passed as additional SSA
+     operands after the iteration space operand. The operation region has
+     one argument for the iterator, variadic arguments for the specified
+     (used) coordinates, followed by one argument for each loop-carried
+     variable, representing the value of that variable at the current
+     iteration.
+     The body region must contain exactly one block that terminates with
+     `sparse_tensor.yield`.
+
+     The `sparse_tensor.iterate` results hold the final values after the
+     last iteration. If the operation defines any values, a yield must be
+     explicitly present.
+     The number and types of the results must match the initial values in
+     the `iter_args` binding and the yield operands.
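+
+     For illustration only (the operation uses a custom assembly format, so
+     the exact printed form is defined by its parser/printer), a loop that
+     carries an `index` counter might be sketched as follows, where `%space`,
+     `%c0`, and `%c1` are assumed to be defined earlier:
+
+     ```mlir
+     // Hypothetical sketch: `iter_args` binds the loop-carried variable,
+     // the yield passes its updated value to the next iteration, and the
+     // final value is returned as %count.
+     %count = sparse_tensor.iterate %it in %space at(%crd) iter_args(%n = %c0)
+       : !sparse_tensor.iter_space<#CSR, lvls = 0 to 1> -> index {
+       %n1 = arith.addi %n, %c1 : index
+       sparse_tensor.yield %n1 : index
+     }
+     ```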
1536
+
1537
+
1538
+ A nested `sparse_tensor.iterate` example that prints all the coordinates
1539
+ stored in the sparse input:
1540
+
1541
+ ```mlir
1542
+ func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
1543
+ // Iterates over the first level of %sp
1544
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
1545
+ %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd0)
1546
+ : !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
1547
+ // Iterates over the second level of %sp
1548
+ %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
1549
+ : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
1550
+ %r2 = sparse_tensor.iterate %it2 in %l2 at (crd1)
1551
+ : !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
1552
+ vector.print %crd0 : index
1553
+ vector.print %crd1 : index
1554
+ }
1555
+ }
1556
+ }
1557
+
1558
+ ```
1559
+ }];
1560
+
1561
+ let extraClassDeclaration = [{
1562
+ unsigned getSpaceDim() {
1563
+ return getIterSpace().getType().getSpaceDim();
1564
+ }
1565
+ BlockArgument getIterator() {
1566
+ return getRegion().getArguments().front();
1567
+ }
1568
+ Block::BlockArgListType getCrds() {
1569
+ // The first block argument is iterator, the remaining arguments are
1570
+ // referenced coordinates.
1571
+ return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
1572
+ }
1573
+ unsigned getNumRegionIterArgs() {
1574
+ return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
1575
+ }
1576
+ }];
1577
+
1578
+ let hasVerifier = 1;
1579
+ let hasRegionVerifier = 1;
1580
+ let hasCustomAssemblyFormat = 1;
1581
+ }
+
//===----------------------------------------------------------------------===//
// Sparse Tensor Debugging and Test-Only Operations.
//===----------------------------------------------------------------------===//