12
12
13
13
#include " mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14
14
15
+ #include " mlir/Analysis/SliceAnalysis.h"
16
+ #include " mlir/Analysis/TopologicalSortUtils.h"
15
17
#include " mlir/Dialect/Affine/IR/AffineOps.h"
16
18
#include " mlir/Dialect/Arith/IR/Arith.h"
17
19
#include " mlir/Dialect/Arith/Utils/Utils.h"
@@ -1580,33 +1582,163 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1580
1582
return success ();
1581
1583
}
1582
1584
1583
- // / Fetches the OpOperand of the only user (and use) of the value `val` which
1584
- // / implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
1585
- // / failure otherwise.
1586
- static FailureOr<OpOperand *> getConsumerFromUses (Value val,
1587
- Block *containingOpBlock) {
1588
- // Check that the value has exactly one use which isn't a scf.yield or a
1589
- // tensor.parallel_insert_slice op.
1590
- OpOperand *operand = nullptr ;
1591
- for (OpOperand &opOperand : val.getUses ()) {
1592
- Operation *consumerOp = opOperand.getOwner ();
1593
- if (isa<scf::YieldOp, tensor::ParallelInsertSliceOp>(consumerOp))
1594
- continue ;
1595
- if (operand)
1596
- return failure ();
1597
- // TODO: We have to init result of consumer before scf.for, use
1598
- // DestinationStyleOpInterface to get result shape from init for now.
1599
- // Add support for other op such as op has InferTypeOpInterface.
1600
- if (!isa<TilingInterface>(consumerOp) ||
1601
- !isa<DestinationStyleOpInterface>(consumerOp))
1585
+ // / An utility to get the first user of the given loopOp. If any of user stay in
1586
+ // / different block of loopOp, return failure.
1587
+ static FailureOr<Operation *> getFirstUserOfLoop (Operation *loopOp) {
1588
+ if (!isa<LoopLikeOpInterface>(loopOp))
1589
+ return failure ();
1590
+ Operation *firstUserOfLoop = nullptr ;
1591
+ for (Operation *userOp : loopOp->getUsers ()) {
1592
+ // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
1593
+ // block with any other types of operation. Thus, just redirecting to its
1594
+ // parent `InParallelOp`. E.g.
1595
+ //
1596
+ // ```
1597
+ // %1 = scf.for {
1598
+ // ...
1599
+ // }
1600
+ // %2 = consumerOp ins(%1, ...)
1601
+ // scf.forall.in_parallel {
1602
+ // tensor.parallel_insert_slice %1
1603
+ // }
1604
+ // ```
1605
+ // where `InParallelOp` but not `ParallelInsertSlice` stays in the same
1606
+ // same block with `consumerOp`.
1607
+ if (isa<tensor::ParallelInsertSliceOp>(userOp))
1608
+ userOp = userOp->getParentOfType <scf::InParallelOp>();
1609
+
1610
+ if (loopOp->getBlock () != userOp->getBlock ())
1602
1611
return failure ();
1603
- if (containingOpBlock != consumerOp->getBlock ())
1612
+
1613
+ if (!firstUserOfLoop || userOp->isBeforeInBlock (firstUserOfLoop))
1614
+ firstUserOfLoop = userOp;
1615
+ }
1616
+ return firstUserOfLoop;
1617
+ }
1618
+
1619
+ // / This utility currently checks whether the first userOp of loop is NOT before
1620
+ // / the last defineOp of consumer operand. Because that we need to move the
1621
+ // / whole loop structure right before the `firstUserOfLoop`. This utility thus
1622
+ // / helps ensuring that no invalid IR is formed, i.e. no backward slice of
1623
+ // / consumerOp is dominated by the `firstUserOfLoop`. Saying that:
1624
+ // /
1625
+ // / ```
1626
+ // / %0 = scf.for() {
1627
+ // / ...
1628
+ // / }
1629
+ // / ...
1630
+ // / %1 = firstUserOfLoop(%0)
1631
+ // / ...
1632
+ // / %2 = lastDefOfConsumerOperand
1633
+ // / ...
1634
+ // / %3 = consumerOp(%2)
1635
+ // / ```
1636
+ // /
1637
+ // / If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it would
1638
+ // / be invalid to move the `loopOp` right before the `firstUserOfLoop`, a.k.a.
1639
+ // / use-def chain violation:
1640
+ // /
1641
+ // / ```
1642
+ // / %0:2 = scf.for() {
1643
+ // / // use before define error
1644
+ // / %3 = tiledConsumerOp(%2)
1645
+ // / }
1646
+ // / %1 = firstUserOfLoop(%0)
1647
+ // / ...
1648
+ // / %2 = lastDefOfConsumerOperand
1649
+ // / ```
1650
+ // /
1651
+ // / @param loopOp: loop operation
1652
+ // / @param consumerOp: consumer operation
1653
+ // / @param reorderOperations: the flag controls whether to reorder the backward
1654
+ // / slice w.r.t. the defineOp of `consumerOp` operands.
1655
+ // / @return: computed backward slice of consumerOp, but excluding those already
1656
+ // / dominates `firstUserOfLoop`.
1657
+ static FailureOr<llvm::SetVector<Operation *>>
1658
+ checkAssumptionForLoop (Operation *loopOp, Operation *consumerOp,
1659
+ bool reorderOperations) {
1660
+ FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop (loopOp);
1661
+ if (failed (firstUserOfLoop))
1662
+ return failure ();
1663
+
1664
+ BackwardSliceOptions options;
1665
+ DominanceInfo dominanceInfo;
1666
+ options.inclusive = true ;
1667
+ options.omitBlockArguments = true ;
1668
+ bool includeLoopOp = false ;
1669
+ options.filter = [&](Operation *op) {
1670
+ if (op == loopOp) {
1671
+ includeLoopOp = true ;
1672
+ return false ;
1673
+ }
1674
+ // Cut off the slice to not include any operation that already dominates
1675
+ // firstUserOfLoop.
1676
+ return !dominanceInfo.properlyDominates (op, *firstUserOfLoop);
1677
+ };
1678
+ llvm::SetVector<Operation *> slice;
1679
+ for (auto operand : consumerOp->getOperands ()) {
1680
+ getBackwardSlice (operand, &slice, options);
1681
+ }
1682
+
1683
+ if (!slice.empty ()) {
1684
+ // If consumerOp has one producer, which is also the user of loopOp.
1685
+ // E.g.
1686
+ // ```
1687
+ // %0 = %loopOp
1688
+ // %1 = consumerOp1 ins(%0)
1689
+ // %2 = consumerOp2 ins(%0, %1)
1690
+ // ```
1691
+ // We can not fuse consumerOp2 into loopOp due to UD chain, unless
1692
+ // consumerOp1 has already been fused into loopOp before.
1693
+ if (includeLoopOp || !reorderOperations)
1604
1694
return failure ();
1605
- operand = &opOperand;
1606
1695
}
1607
1696
1608
- if (operand)
1609
- return operand;
1697
+ return slice;
1698
+ }
1699
+
1700
+ // / Fetches the OpOperand of the first valid user (and use) of the value `val`
1701
+ // / which implements `TilingInterface` and `DestinationStyleOpInterface`.
1702
+ // / Returns failure otherwise.
1703
+ static FailureOr<OpOperand *> getConsumerFromLoopUses (RewriterBase &rewriter,
1704
+ Operation *loopOp,
1705
+ unsigned resultNumber) {
1706
+ if (!isa<LoopLikeOpInterface>(loopOp))
1707
+ return failure ();
1708
+ Value val = loopOp->getResult (resultNumber);
1709
+ Block *loopBlock = loopOp->getBlock ();
1710
+ for (OpOperand &opOperand : val.getUses ()) {
1711
+ Operation *consumerOp = opOperand.getOwner ();
1712
+ // Step 1. Check if the user is tilable.
1713
+ if (!isa<TilingInterface, DestinationStyleOpInterface>(consumerOp)) {
1714
+ // TODO: We have to init result of consumer before scf.for, use
1715
+ // DestinationStyleOpInterface to get result shape from init for now. Add
1716
+ // support for other op such as op has InferTypeOpInterface.
1717
+ continue ;
1718
+ }
1719
+ // Step 2. Check if user stay in the same block.
1720
+ if (loopBlock != consumerOp->getBlock ())
1721
+ continue ;
1722
+ // Step 3. Check if user has succeeding user. Otherwise, it usually
1723
+ // represents already tiled.
1724
+ if (consumerOp->use_empty ())
1725
+ continue ;
1726
+ // Step 4. Check assumption for loop with `reorderOperations` enabled.
1727
+ FailureOr<llvm::SetVector<Operation *>> slice =
1728
+ checkAssumptionForLoop (loopOp, consumerOp, true );
1729
+ if (failed (slice))
1730
+ continue ;
1731
+ // Step 5. If backward sice is not empty, move them before firstUserOfLoop.
1732
+ if (!slice->empty ()) {
1733
+ mlir::topologicalSort (*slice);
1734
+ FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop (loopOp);
1735
+ assert (succeeded (firstUserOfLoop) && " First user of loop is not found" );
1736
+ for (auto op : *slice) {
1737
+ rewriter.moveOpBefore (op, *firstUserOfLoop);
1738
+ }
1739
+ }
1740
+ return &opOperand;
1741
+ }
1610
1742
return failure ();
1611
1743
}
1612
1744
@@ -1659,7 +1791,8 @@ getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
1659
1791
// / 1. tensor.insert_slice has scf.yield as its only user.
1660
1792
// / 2. scf.for's corresponding result has only one use.
1661
1793
static FailureOr<OpOperand *>
1662
- getUntiledConsumerFromSlice (tensor::InsertSliceOp candidateSliceOp) {
1794
+ getUntiledConsumerFromSlice (RewriterBase &rewriter,
1795
+ tensor::InsertSliceOp candidateSliceOp) {
1663
1796
if (failed (checkAssumptionForFusingConsumer (candidateSliceOp)))
1664
1797
return failure ();
1665
1798
Value sliceResult = candidateSliceOp.getResult ();
@@ -1672,15 +1805,15 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
1672
1805
if (!forOp)
1673
1806
return failure ();
1674
1807
scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf (forOp).front ();
1675
- Value resultingValue = topLevelForOp->getResult (resultNumber);
1676
1808
1677
- return getConsumerFromUses (resultingValue , topLevelForOp-> getBlock () );
1809
+ return getConsumerFromLoopUses (rewriter , topLevelForOp, resultNumber );
1678
1810
}
1679
1811
1680
1812
// / Fetch the first untiled consumer of a scf.forall's result which is yielded
1681
1813
// / by a tensor.parallel_insert_slice.
1682
1814
static FailureOr<OpOperand *>
1683
- getUntiledConsumerFromSlice (tensor::ParallelInsertSliceOp candidateSliceOp) {
1815
+ getUntiledConsumerFromSlice (RewriterBase &rewriter,
1816
+ tensor::ParallelInsertSliceOp candidateSliceOp) {
1684
1817
// Step 1. Fetch the corresponding output
1685
1818
Value sliceDest = candidateSliceOp.getDest ();
1686
1819
auto iterArg = dyn_cast<BlockArgument>(sliceDest);
@@ -1693,45 +1826,22 @@ getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
1693
1826
auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
1694
1827
if (!forallOp)
1695
1828
return failure ();
1696
- Value resultingValue =
1697
- forallOp.getTiedOpResult (forallOp.getTiedOpOperand (iterArg));
1698
-
1699
- return getConsumerFromUses (resultingValue, containingOp->getBlock ());
1700
- }
1829
+ unsigned resultNumber =
1830
+ forallOp.getTiedOpResult (forallOp.getTiedOpOperand (iterArg))
1831
+ .getResultNumber ();
1701
1832
1702
- // / This utility currently checks whether the loop either :-
1703
- // / 1. Yields exactly one result.
1704
- // / 2. Has consumer op as its first user and other users to be in the same
1705
- // / containing block as that of consumer op's. Currently we clone the loop op
1706
- // / right before the consumer op in order to maintain a valid def-use chain.
1707
- // / This utility thus helps ensuring that no invalid IR is formed due to the
1708
- // / same.
1709
- static LogicalResult checkAssumptionForLoop (Operation *loopOp,
1710
- Operation *consumerOp) {
1711
- // Check if the loop op yields one result.
1712
- if (loopOp->getNumResults () == 1 )
1713
- return success ();
1714
- // Check if the consumerOp is the first user of the loopOp and if other users
1715
- // are in the same containing block as that of consumer op's.
1716
- Block *parentBlock = consumerOp->getBlock ();
1717
- for (Operation *userOp : loopOp->getUsers ()) {
1718
- if (userOp == consumerOp)
1719
- continue ;
1720
- if (parentBlock != userOp->getBlock () ||
1721
- !consumerOp->isBeforeInBlock (userOp))
1722
- return failure ();
1723
- }
1724
- return success ();
1833
+ return getConsumerFromLoopUses (rewriter, containingOp, resultNumber);
1725
1834
}
1726
1835
1727
1836
// / A utility to fetch an untiled consumer of
1728
1837
// / tensor.insert_slice/tensor.parallel_insert_slice.
1729
- static FailureOr<OpOperand *> getUntiledConsumerFromSlice (Operation *sliceOp) {
1838
+ static FailureOr<OpOperand *>
1839
+ getUntiledConsumerFromSlice (RewriterBase &rewriter, Operation *sliceOp) {
1730
1840
if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1731
- return getUntiledConsumerFromSlice (insertSlice);
1841
+ return getUntiledConsumerFromSlice (rewriter, insertSlice);
1732
1842
} else if (auto parallelInsertSlice =
1733
1843
dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1734
- return getUntiledConsumerFromSlice (parallelInsertSlice);
1844
+ return getUntiledConsumerFromSlice (rewriter, parallelInsertSlice);
1735
1845
} else {
1736
1846
return failure ();
1737
1847
}
@@ -1751,7 +1861,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1751
1861
// 1. Get the consumer of scf.for for the result yielded by
1752
1862
// tensor.insert_slice/parallel_insert_slice.
1753
1863
FailureOr<OpOperand *> maybeConsumerOpOperand =
1754
- getUntiledConsumerFromSlice (candidateSliceOp);
1864
+ getUntiledConsumerFromSlice (rewriter, candidateSliceOp);
1755
1865
if (failed (maybeConsumerOpOperand)) {
1756
1866
return rewriter.notifyMatchFailure (candidateSliceOp,
1757
1867
" could not fetch consumer to fuse" );
@@ -1787,11 +1897,11 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1787
1897
1788
1898
LoopLikeOpInterface outerMostLoop = nestedLoops.front ();
1789
1899
1790
- if (failed (checkAssumptionForLoop (outerMostLoop, consumerOp))) {
1900
+ // Check assumption for loop with `reorderOperations` disabled.
1901
+ if (failed (checkAssumptionForLoop (outerMostLoop, consumerOp, false ))) {
1791
1902
return rewriter.notifyMatchFailure (
1792
- outerMostLoop,
1793
- " containing loop op should either yield just one value or "
1794
- " have the consumer op as its first user" );
1903
+ outerMostLoop, " the first user of loop should not dominate any define "
1904
+ " of consumer operand(s)" );
1795
1905
}
1796
1906
1797
1907
OpBuilder::InsertionGuard g (rewriter);
@@ -1812,9 +1922,14 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1812
1922
1813
1923
Location loc = outerMostLoop->getLoc ();
1814
1924
1815
- // 3. Move the whole loop structure right before consumer Op, the dominance
1816
- // should be already ensured by `checkAssumptionForLoop`.
1817
- rewriter.moveOpBefore (outerMostLoop, consumerOp);
1925
+ // 3. Move the whole loop structure right before firstUserOfLoop, the
1926
+ // dominance should be already ensured by `checkAssumptionForLoop`.
1927
+ FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop (outerMostLoop);
1928
+ if (failed (firstUserOfLoop)) {
1929
+ return rewriter.notifyMatchFailure (
1930
+ outerMostLoop, " could not find the first user of outer most loop" );
1931
+ }
1932
+ rewriter.moveOpBefore (outerMostLoop, *firstUserOfLoop);
1818
1933
1819
1934
// 4. Set insertion point before terminator op of the loop and create a new
1820
1935
// tensor.insert_slice. In the scf.for case this is a clone of the
0 commit comments