@@ -636,81 +636,112 @@ enum VectorMemoryAccessKind {
 Gather
 };
 
-/// Check whether \p val can be used for calculating an index for a contiguous
-/// load operation. This means that \p val should either:
-///   * be invariant with respect to \p linalgOp, or
-///   * increment by 1 with every loop iterator (when \p shouldBeConstant is
-///   false).
-/// Parameters \p trailingLoopDim and \p shouldBeConstant are used to analyze
-/// `linalg.index` ops.
-static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
-                                size_t trailingLoopDim, bool shouldBeConstant) {
-  auto *block = linalgOp.getBlock();
+/// Checks whether \p val can be used for calculating a loop invariant index.
+static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
 
-  // Bail out if this is a block argument for this linalg.generic Op.
+  auto targetShape = linalgOp.getStaticLoopRanges();
+  assert(((llvm::count_if(targetShape,
+                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
+         "n-D vectors are not yet supported");
+  assert(targetShape.back() != 1 &&
+         "1-D vectors with the trailing dim equal 1 are not yet supported");
+
+  // Blocks outside _this_ linalg.generic are effectively loop invariant.
+  // However, analysing block arguments for _this_ linalg.generic Op is a bit
+  // tricky. Just bail out in the latter case.
   // TODO: We could try analysing the corresponding affine map here.
-  if (val.dyn_cast<BlockArgument>())
+  auto *block = linalgOp.getBlock();
+  if (isa<BlockArgument>(val))
     return llvm::all_of(block->getArguments(),
                         [&val](Value v) { return (v != val); });
 
   Operation *defOp = val.getDefiningOp();
   assert(defOp && "This is neither a block argument nor an operation result");
 
-  // We know that we are reading into a 1-D tensor like this:
-  // `tensor<1x1x4xi32>`. Given this assumption, the following Op:
-  //    * `%idx = linalg.index dim : index`,
-  // will either:
-  //    1. produce a constant when `dim` _is not_ the trailing loop dim, or
-  //    2. increment with stride one when `dim` _is_ the trailing loop dim.
+  // IndexOp is loop invariant as long as its result remains constant across
+  // iterations. Given the assumptions on the loop ranges above, only the
+  // trailing loop dim ever changes.
+  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
-    return shouldBeConstant ? (indexOp.getDim() != trailingLoopDim)
-                            : (indexOp.getDim() == trailingLoopDim);
+    return (indexOp.getDim() != trailingLoopDim);
 
   auto *ancestor = block->findAncestorOpInBlock(*defOp);
 
-  // Values defined outside `linalgOp`.
+  // Values defined outside `linalgOp` are loop invariant.
   if (!ancestor)
     return true;
 
-  // Values defined inside `linalgOp`, which are constant.
-  if (dyn_cast<arith::ConstantOp>(ancestor))
+  // Values defined inside `linalgOp`, which are constant, are loop invariant.
+  if (isa<arith::ConstantOp>(ancestor))
     return true;
 
-  // Conservatively reject Ops that could lead to non-contiguous accesses.
-  if (!isa<arith::AddIOp, arith::SubIOp, linalg::IndexOp>(ancestor))
-    return false;
-
   bool result = true;
   for (auto op : ancestor->getOperands())
-    result &=
-        isContiguousLoadIdx(linalgOp, op, trailingLoopDim, shouldBeConstant);
+    result &= isLoopInvariantIdx(linalgOp, op);
 
   return result;
 }
 
-/// Check whether the calculation of \p val is based on linalg.index Op with
-/// the dim attribute matching \p dim.
-static bool isBasedOnIndexOp(LinalgOp &linalgOp, Value &val, size_t dim) {
-  auto *block = linalgOp.getBlock();
-  auto targetShape = linalgOp.getStaticLoopRanges();
+/// Check whether \p val could be used for calculating the trailing index for a
+/// contiguous load operation.
+///
+/// There are currently 3 types of values that are allowed here:
+///   1. loop-invariant values,
+///   2. values that increment by 1 with every loop iteration,
+///   3. results of basic arithmetic operations (linear and continuous)
+///      involving 1., 2. and 3.
+/// This method returns True if indeed only such values are used in calculating
+/// \p val.
+///
+/// Additionally, the trailing index for a contiguous load operation should
+/// increment by 1 with every loop iteration, i.e. be based on:
+///   * `linalg.index <dim>`,
+/// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
+/// updated to `true` when such an op is found.
+static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
+                                bool &foundIndexOp) {
 
-  if (val.isa<BlockArgument>())
-    return false;
+  auto targetShape = linalgOp.getStaticLoopRanges();
+  assert(((llvm::count_if(targetShape,
+                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
+         "n-D vectors are not yet supported");
+  assert(targetShape.back() != 1 &&
+         "1-D vectors with the trailing dim 1 are not yet supported");
+
+  // Blocks outside _this_ linalg.generic are effectively loop invariant.
+  // However, analysing block arguments for _this_ linalg.generic Op is a bit
+  // tricky. Just bail out in the latter case.
+  // TODO: We could try analysing the corresponding affine map here.
+  auto *block = linalgOp.getBlock();
+  if (isa<BlockArgument>(val))
+    return llvm::all_of(block->getArguments(),
+                        [&val](Value v) { return (v != val); });
 
   Operation *defOp = val.getDefiningOp();
   assert(defOp && "This is neither a block argument nor an operation result");
 
-  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
-    return (indexOp.getDim() == dim);
+  // Given the assumption on the loop ranges above, only the trailing loop
+  // index is not constant.
+  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
+    foundIndexOp = (indexOp.getDim() == trailingLoopDim);
+    return true;
+  }
 
   auto *ancestor = block->findAncestorOpInBlock(*defOp);
 
   if (!ancestor)
     return false;
 
+  // Conservatively reject Ops that could lead to indices with stride other
+  // than 1.
+  if (!isa<arith::AddIOp, arith::SubIOp, arith::ConstantOp, linalg::IndexOp>(
+          ancestor))
+    return false;
+
   bool result = false;
   for (auto op : ancestor->getOperands())
-    result |= isBasedOnIndexOp(linalgOp, op, dim);
+    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
 
   return result;
 }
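
To make the recursion in these two helpers easier to follow, here is a minimal, self-contained C++ sketch of the same idea written against toy types rather than MLIR (the Expr struct, its Kind values, isLoopInvariant and isContiguousIdx are invented for illustration; they only mirror the structure of isLoopInvariantIdx and isContiguousLoadIdx above and are not the actual implementation). A value is loop invariant if it is a constant, defined outside the body, or built purely from loop-invariant operands; a valid trailing index may additionally mix the trailing `linalg.index` into additions and subtractions, and the foundIndexOp flag records whether that index actually appears.

#include <cassert>
#include <vector>

// Toy stand-ins for SSA values inside the linalg.generic body (illustration
// only -- not the MLIR API).
struct Expr {
  enum Kind {
    Constant,      // e.g. arith.constant
    OutsideDef,    // defined outside the linalg.generic
    TrailingIndex, // linalg.index <trailing dim> -- changes every iteration
    OtherIndex,    // linalg.index of a unit-trip-count dim -- always 0
    Add,           // arith.addi
    Sub,           // arith.subi
    Mul            // arith.muli -- would break the stride-1 property
  } kind;
  std::vector<const Expr *> operands;
};

// Rough analogue of isLoopInvariantIdx: invariant leaves, or any compound
// value whose operands are all loop invariant.
static bool isLoopInvariant(const Expr &e) {
  switch (e.kind) {
  case Expr::Constant:
  case Expr::OutsideDef:
  case Expr::OtherIndex:
    return true;
  case Expr::TrailingIndex:
    return false; // increments with every iteration
  default:
    break;
  }
  bool result = true;
  for (const Expr *op : e.operands)
    result &= isLoopInvariant(*op);
  return result;
}

// Rough analogue of isContiguousLoadIdx: only additive arithmetic over
// constants and loop indices is accepted, and foundIndexOp records whether
// the trailing loop index contributes to the value.
static bool isContiguousIdx(const Expr &e, bool &foundIndexOp) {
  if (e.kind == Expr::TrailingIndex) {
    foundIndexOp = true;
    return true;
  }
  if (e.kind == Expr::OtherIndex)
    return true; // constant across iterations
  if (e.kind != Expr::Add && e.kind != Expr::Sub && e.kind != Expr::Constant)
    return false; // e.g. Mul could yield a stride other than 1
  bool result = false;
  for (const Expr *op : e.operands)
    result |= isContiguousIdx(*op, foundIndexOp);
  return result;
}

int main() {
  // Models `%idx + %ofs`, where %idx is the trailing linalg.index and %ofs is
  // defined outside the linalg.generic.
  Expr idx{Expr::TrailingIndex, {}};
  Expr ofs{Expr::OutsideDef, {}};
  Expr sum{Expr::Add, {&idx, &ofs}};

  bool foundIndexOp = false;
  assert(!isLoopInvariant(sum));              // varies across iterations...
  assert(isContiguousIdx(sum, foundIndexOp)); // ...but only with stride 1
  assert(foundIndexOp);
  return 0;
}

In the real pass the same walk is driven over MLIR values via getDefiningOp() and findAncestorOpInBlock(), as shown in the hunk above.
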
@@ -725,7 +756,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
 
   auto targetShape = linalgOp.getStaticLoopRanges();
 
-  // Assume that it's a gather load when reading _into_:
+  // 1. Assume that it's a gather load when reading _into_:
   //    * an n-D vector, like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
   //    * a 1-D vector with the trailing dim equal 1, e.g. `tensor<1x4x1xi32>`.
   // TODO: Relax these conditions.
@@ -736,44 +767,36 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
 
   auto inputShape = extractOp.getTensor().getType().cast<ShapedType>();
 
-  // Assume that it's a gather load when reading _from_ a tensor for which the
-  // trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
+  // 2. Assume that it's a gather load when reading _from_ a tensor for which
+  // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
   // TODO: Relax this condition.
   if (inputShape.getShape().back() == 1)
     return VectorMemoryAccessKind::Gather;
 
-  // The trailing loop dim is needed when analyzing ops like:
-  //   * %idx = `linalg.index <dim> : index`.
-  auto trailingLoopDim = targetShape.size() - 1;
-
   bool isContiguous = true;
 
-  // Iterate over all indices. Analyze the way each index is calculated and
-  // decide whether it is suitable for a contiguous load (e.g. loop invariant).
+  // 3a. Analyze the leading indices of `extractOp`.
+  // Look at the way each index is calculated and decide whether it is suitable
+  // for a contiguous load, i.e. whether it's loop invariant.
   auto indices = extractOp.getIndices();
-  for (auto [i, indexVal] : llvm::enumerate(indices)) {
-    if (inputShape.getShape()[i] == 1) {
-      // This index will always be equal 0, so it is a loop-invariant constant.
-      continue;
-    }
+  auto leadIndices = ValueRange(indices.drop_back(1));
 
-    // Should this index be loop invariant?
-    //  * _no_ if this is the trailing index,
-    //  * _yes_ otherwise.
-    auto extractOpBottomIdx = indices.size() - 1;
-    bool loopInvariantIndex = (i != extractOpBottomIdx);
+  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
+    if (inputShape.getShape()[i] == 1)
+      continue;
 
-    isContiguous &= isContiguousLoadIdx(linalgOp, indexVal, trailingLoopDim,
-                                        loopInvariantIndex);
+    isContiguous &= isLoopInvariantIdx(linalgOp, indexVal);
   }
 
-  // The trailing index in the extract Op must increment with every iteration,
-  // which means that it must be based on a loop index. Given the assumption
-  // on the output tensor, only the trailing loop index is not constant, so
-  // that's what we need to check against.
+  // 3b. Analyze the trailing index for `extractOp`.
   auto extractOpTrailingIdx = indices.back();
+  // For contiguous loads, the trailing `extractOp` index should increment with
+  // every loop iteration. This effectively means that it must be based on the
+  // trailing loop index. This is what the following bool captures.
+  bool foundIndexOp = false;
   isContiguous &=
-      isBasedOnIndexOp(linalgOp, extractOpTrailingIdx, trailingLoopDim);
+      isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
+  isContiguous &= foundIndexOp;
 
   if (isContiguous) {
     LDBG("Found contiguous load: " << extractOp);
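
To summarise how getTensorExtractMemoryAccessPattern combines these checks, here is a small, self-contained sketch of the decision flow with the per-index analyses abstracted into two booleans (the classify function, the AccessKind enum and all parameter names are invented for this illustration; only Contiguous and Gather correspond to kinds visible in this diff, and the exact shape checks in the real function may differ since they sit outside the changed lines).

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative only -- names invented, not the MLIR implementation.
enum class AccessKind { Contiguous, Gather };

// targetShape: static loop ranges of the linalg.generic.
// inputShape:  shape of the tensor that tensor.extract reads from.
// leadIndicesLoopInvariant: result of step 3a (loop-invariance of the
//                           leading indices).
// trailingIndexIsStride1:   result of step 3b (trailing index based on the
//                           trailing linalg.index, stride 1).
static AccessKind classify(const std::vector<int64_t> &targetShape,
                           const std::vector<int64_t> &inputShape,
                           bool leadIndicesLoopInvariant,
                           bool trailingIndexIsStride1) {
  // 1. Reading _into_ an n-D vector, or a 1-D vector with trailing dim 1,
  //    is treated as a gather.
  int64_t nonUnitDims = 0;
  for (int64_t d : targetShape)
    nonUnitDims += (d > 1);
  if (nonUnitDims != 1 || targetShape.back() == 1)
    return AccessKind::Gather;

  // 2. Reading _from_ a tensor with trailing dim 1 is treated as a gather.
  if (inputShape.back() == 1)
    return AccessKind::Gather;

  // 3. Otherwise the load is contiguous iff the leading indices are loop
  //    invariant and the trailing index advances by 1 every iteration.
  return (leadIndicesLoopInvariant && trailingIndexIsStride1)
             ? AccessKind::Contiguous
             : AccessKind::Gather;
}

int main() {
  // Loop ranges 1x1x4 (only the trailing dim iterates) reading from a 3x8
  // tensor: loop-invariant leading indices plus a stride-1 trailing index
  // give a contiguous load.
  assert(classify({1, 1, 4}, {3, 8}, true, true) == AccessKind::Contiguous);
  // Same shapes, but the trailing index is not based on the trailing loop
  // index, so the access falls back to a gather.
  assert(classify({1, 1, 4}, {3, 8}, true, false) == AccessKind::Gather);
  return 0;
}

In the real function, steps 3a and 3b are performed by walking the index computations with the two helpers added in the first hunk.
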