Commit 7a078b6

[mlir][linalg] Refine how contiguous loads are identified
Vectorization of `tensor.extract` using contiguous loads (`vector.transfer_read`) was introduced in [1]. This patch updates and refines the existing logic so that more cases of contiguous loads can be identified, and adds more tests.

Specifically, contiguous load operations are identified by making sure that:

1. the non-trailing indices for `tensor.extract` are loop invariant (so, e.g., there are no "jumps" from one row to another between iterations), and
2. the trailing index for `tensor.extract` increments by 1 with every loop iteration (so that it is always adjacent elements that are loaded).

This patch introduces:

* `isLoopInvariantIdx` for step 1., and
* `isContiguousLoadIdx` for step 2.

These new methods replace:

* `isContiguousLoadIdx`, and
* `isBasedOnIndexOp`.

Both approaches lead to a similar end result (none of the existing tests required updating). However, with the updated approach it is much easier to treat the trailing and non-trailing indices separately, and to add more cases for which contiguous loads can be used. A sketch of the access pattern this targets follows below.

[1] https://reviews.llvm.org/D141998

Differential Revision: https://reviews.llvm.org/D145385
1 parent c7fcae5 commit 7a078b6
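To make the two conditions concrete, here is a hedged sketch of the kind of `linalg.generic` the refinement targets. All names and shapes (`@contiguous_extract`, `tensor<16x16xi32>`, `%base`) are invented for illustration and are not taken from the patch or its tests: the leading `tensor.extract` index is loop invariant, and the trailing index is `linalg.index` over the trailing loop dim combined with an invariant offset via `arith.addi`, so the access stays stride-1 and can be vectorized with `vector.transfer_read` rather than `vector.gather`.

```mlir
#id = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @contiguous_extract(%src: tensor<16x16xi32>,
                              %init: tensor<1x1x4xi32>,
                              %base: index) -> tensor<1x1x4xi32> {
  %r = linalg.generic {indexing_maps = [#id],
                       iterator_types = ["parallel", "parallel", "parallel"]}
      outs(%init : tensor<1x1x4xi32>) {
  ^bb0(%out: i32):
    // Loop invariant: dim 1 is a unit loop dim, so this index never changes.
    %i = linalg.index 1 : index
    // Increments by 1 every iteration: dim 2 is the trailing loop dim.
    %j = linalg.index 2 : index
    // Still stride 1: an invariant offset plus the trailing index.
    %col = arith.addi %base, %j : index
    %v = tensor.extract %src[%i, %col] : tensor<16x16xi32>
    linalg.yield %v : i32
  } -> tensor<1x1x4xi32>
  return %r : tensor<1x1x4xi32>
}
```

Under the old logic the distinction between `%i` and `%col` was driven by a `shouldBeConstant` flag threaded through a single recursive walk; the patch splits that walk into the two helpers shown in the diff below.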

File tree

2 files changed: +293, -65 lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 88 additions & 65 deletions
```diff
@@ -636,81 +636,112 @@ enum VectorMemoryAccessKind {
   Gather
 };
 
-/// Check whether \p val can be used for calculating an index for a contiguous
-/// load operation. This means that \p val should either:
-///   * be invariant with respect to \p linalgOp, or
-///   * increment by 1 with every loop iterator (when \p shouldBeConstant is
-///     false).
-/// Parameters \p trailingLoopDim and \p shouldBeConstant are used to analyze
-/// `linalg.index` ops.
-static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
-                                size_t trailingLoopDim, bool shouldBeConstant) {
-  auto *block = linalgOp.getBlock();
+/// Checks whether \p val can be used for calculating a loop invariant index.
+static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
 
-  // Bail out if this is a block argument for this linalg.generic Op.
+  auto targetShape = linalgOp.getStaticLoopRanges();
+  assert(((llvm::count_if(targetShape,
+                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
+         "n-D vectors are not yet supported");
+  assert(targetShape.back() != 1 &&
+         "1-D vectors with the trailing dim equal 1 are not yet supported");
+
+  // Blocks outside _this_ linalg.generic are effectively loop invariant.
+  // However, analysing block arguments for _this_ linalg.generic Op is a bit
+  // tricky. Just bail out in the latter case.
   // TODO: We could try analysing the corresponding affine map here.
-  if (val.dyn_cast<BlockArgument>())
+  auto *block = linalgOp.getBlock();
+  if (isa<BlockArgument>(val))
     return llvm::all_of(block->getArguments(),
                         [&val](Value v) { return (v != val); });
 
   Operation *defOp = val.getDefiningOp();
   assert(defOp && "This is neither a block argument nor an operation result");
 
-  // We know that we are reading into a 1-D tensor like this:
-  // `tensor<1x1x4xi32>`. Given this assumption, the following Op:
-  //   * `%idx = linalg.index dim : index`,
-  // will either:
-  //   1. produce a constant when `dim` _is not_ the trailing loop dim, or
-  //   2. increment with stride one when `dim` _is_ the trailing loop dim.
+  // IndexOp is loop invariant as long as its result remains constant across
+  // iterations. Given the assumptions on the loop ranges above, only the
+  // trailing loop dim ever changes.
+  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
-    return shouldBeConstant ? (indexOp.getDim() != trailingLoopDim)
-                            : (indexOp.getDim() == trailingLoopDim);
+    return (indexOp.getDim() != trailingLoopDim);
 
   auto *ancestor = block->findAncestorOpInBlock(*defOp);
 
-  // Values defined outside `linalgOp`.
+  // Values defined outside `linalgOp` are loop invariant.
   if (!ancestor)
     return true;
 
-  // Values defined inside `linalgOp`, which are constant.
-  if (dyn_cast<arith::ConstantOp>(ancestor))
+  // Values defined inside `linalgOp`, which are constant, are loop invariant.
+  if (isa<arith::ConstantOp>(ancestor))
     return true;
 
-  // Conservatively reject Ops that could lead to non-contiguous accesses.
-  if (!isa<arith::AddIOp, arith::SubIOp, linalg::IndexOp>(ancestor))
-    return false;
-
   bool result = true;
   for (auto op : ancestor->getOperands())
-    result &=
-        isContiguousLoadIdx(linalgOp, op, trailingLoopDim, shouldBeConstant);
+    result &= isLoopInvariantIdx(linalgOp, op);
 
   return result;
 }
 
-/// Check whether the calculation of \p val is based on linalg.index Op with
-/// the dim attribute matching \p dim.
-static bool isBasedOnIndexOp(LinalgOp &linalgOp, Value &val, size_t dim) {
-  auto *block = linalgOp.getBlock();
-  auto targetShape = linalgOp.getStaticLoopRanges();
+/// Check whether \p val could be used for calculating the trailing index for a
+/// contiguous load operation.
+///
+/// There are currently 3 types of values that are allowed here:
+///   1. loop-invariant values,
+///   2. values that increment by 1 with every loop iteration,
+///   3. results of basic arithmetic operations (linear and continuous)
+///      involving 1., 2. and 3.
+/// This method returns true if indeed only such values are used in calculating
+/// \p val.
+///
+/// Additionally, the trailing index for a contiguous load operation should
+/// increment by 1 with every loop iteration, i.e. be based on:
+///   * `linalg.index <dim>`,
+/// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
+/// updated to `true` when such an op is found.
+static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
+                                bool &foundIndexOp) {
 
-  if (val.isa<BlockArgument>())
-    return false;
+  auto targetShape = linalgOp.getStaticLoopRanges();
+  assert(((llvm::count_if(targetShape,
+                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
+         "n-D vectors are not yet supported");
+  assert(targetShape.back() != 1 &&
+         "1-D vectors with the trailing dim 1 are not yet supported");
+
+  // Blocks outside _this_ linalg.generic are effectively loop invariant.
+  // However, analysing block arguments for _this_ linalg.generic Op is a bit
+  // tricky. Just bail out in the latter case.
+  // TODO: We could try analysing the corresponding affine map here.
+  auto *block = linalgOp.getBlock();
+  if (isa<BlockArgument>(val))
+    return llvm::all_of(block->getArguments(),
+                        [&val](Value v) { return (v != val); });
 
   Operation *defOp = val.getDefiningOp();
   assert(defOp && "This is neither a block argument nor an operation result");
 
-  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
-    return (indexOp.getDim() == dim);
+  // Given the assumption on the loop ranges above, only the trailing loop
+  // index is not constant.
+  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
+    foundIndexOp = (indexOp.getDim() == trailingLoopDim);
+    return true;
+  }
 
   auto *ancestor = block->findAncestorOpInBlock(*defOp);
 
   if (!ancestor)
     return false;
 
+  // Conservatively reject Ops that could lead to indices with stride other
+  // than 1.
+  if (!isa<arith::AddIOp, arith::SubIOp, arith::ConstantOp, linalg::IndexOp>(
+          ancestor))
+    return false;
+
   bool result = false;
   for (auto op : ancestor->getOperands())
-    result |= isBasedOnIndexOp(linalgOp, op, dim);
+    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
 
   return result;
 }
```
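As a counter-example for the allow-list in `isContiguousLoadIdx`, consider a trailing index scaled with `arith.muli`. This is a hypothetical sketch (names and shapes invented, not one of the patch's tests): `arith.muli` is not among the accepted ops (`arith.addi`, `arith.subi`, `arith.constant`, `linalg.index`), so the chain is conservatively rejected and the access falls back to a gather.

```mlir
#id = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @strided_is_gather(%src: tensor<16x16xi32>,
                             %init: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %r = linalg.generic {indexing_maps = [#id],
                       iterator_types = ["parallel", "parallel", "parallel"]}
      outs(%init : tensor<1x1x4xi32>) {
  ^bb0(%out: i32):
    %j = linalg.index 2 : index
    // Stride 2, not 1: `arith.muli` is rejected by the allow-list above.
    %jj = arith.muli %j, %c2 : index
    %v = tensor.extract %src[%c0, %jj] : tensor<16x16xi32>
    linalg.yield %v : i32
  } -> tensor<1x1x4xi32>
  return %r : tensor<1x1x4xi32>
}
```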
```diff
@@ -725,7 +756,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
 
   auto targetShape = linalgOp.getStaticLoopRanges();
 
-  // Assume that it's a gather load when reading _into_:
+  // 1. Assume that it's a gather load when reading _into_:
   //    * an n-D vector, like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
   //    * a 1-D vector with the trailing dim equal 1, e.g. `tensor<1x4x1xi32>`.
   // TODO: Relax these conditions.
```
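Condition 1. can be illustrated with an output that has two non-unit dims (again a hypothetical sketch with invented shapes): regardless of how the extract indices are computed, reading into such an n-D vector is currently classified as a gather.

```mlir
#id = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @nd_output_is_gather(%src: tensor<16x16xi32>,
                               %init: tensor<2x1x4xi32>) -> tensor<2x1x4xi32> {
  %r = linalg.generic {indexing_maps = [#id],
                       iterator_types = ["parallel", "parallel", "parallel"]}
      outs(%init : tensor<2x1x4xi32>) {
  ^bb0(%out: i32):
    // Two non-unit loop dims (2 and 4): condition 1. applies before any
    // per-index analysis is done.
    %i = linalg.index 0 : index
    %j = linalg.index 2 : index
    %v = tensor.extract %src[%i, %j] : tensor<16x16xi32>
    linalg.yield %v : i32
  } -> tensor<2x1x4xi32>
  return %r : tensor<2x1x4xi32>
}
```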
```diff
@@ -736,44 +767,36 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
 
   auto inputShape = extractOp.getTensor().getType().cast<ShapedType>();
 
-  // Assume that it's a gather load when reading _from_ a tensor for which the
-  // trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
+  // 2. Assume that it's a gather load when reading _from_ a tensor for which
+  // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
   // TODO: Relax this condition.
   if (inputShape.getShape().back() == 1)
     return VectorMemoryAccessKind::Gather;
 
-  // The trailing loop dim is needed when analyzing ops like:
-  //    * `%idx = linalg.index <dim> : index`.
-  auto trailingLoopDim = targetShape.size() - 1;
-
   bool isContiguous = true;
 
-  // Iterate over all indices. Analyze the way each index is calculated and
-  // decide whether it is suitable for a contiguous load (e.g. loop invariant).
+  // 3a. Analyze the leading indices of `extractOp`.
+  // Look at the way each index is calculated and decide whether it is suitable
+  // for a contiguous load, i.e. whether it's loop invariant.
   auto indices = extractOp.getIndices();
-  for (auto [i, indexVal] : llvm::enumerate(indices)) {
-    if (inputShape.getShape()[i] == 1) {
-      // This index will always be equal 0, so it is a loop-invariant constant.
-      continue;
-    }
+  auto leadIndices = ValueRange(indices.drop_back(1));
 
-    // Should this index be loop invariant?
-    //   * _no_ if this is the trailing index,
-    //   * _yes_ otherwise.
-    auto extractOpBottomIdx = indices.size() - 1;
-    bool loopInvariantIndex = (i != extractOpBottomIdx);
+  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
+    if (inputShape.getShape()[i] == 1)
+      continue;
 
-    isContiguous &= isContiguousLoadIdx(linalgOp, indexVal, trailingLoopDim,
-                                        loopInvariantIndex);
+    isContiguous &= isLoopInvariantIdx(linalgOp, indexVal);
   }
 
-  // The trailing index in the extract Op must increment with every iteration,
-  // which means that it must be based on a loop index. Given the assumption
-  // on the output tensor, only the trailing loop index is not constant, so
-  // that's what we need to check against.
+  // 3b. Analyze the trailing index for `extractOp`.
   auto extractOpTrailingIdx = indices.back();
+  // For contiguous loads, the trailing `extractOp` index should increment with
+  // every loop iteration. This effectively means that it must be based on the
+  // trailing loop index. This is what the following bool captures.
+  bool foundIndexOp = false;
   isContiguous &=
-      isBasedOnIndexOp(linalgOp, extractOpTrailingIdx, trailingLoopDim);
+      isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
+  isContiguous &= foundIndexOp;
 
   if (isContiguous) {
     LDBG("Found contiguous load: " << extractOp);
```
