Skip to content

Commit 0c1c0d5

Browse files
Jerry Wu and qedawkins authored
[MLIR] Add patterns to bubble-up pack and push-down unpack through collapse/expand shape ops (#85297)
Add DataLayoutPropagation patterns to bubble-up pack and push-down unpack through collapse/expand shape ops. --------- Co-authored-by: Quinn Dawkins <[email protected]>
1 parent 1095f71 commit 0c1c0d5

File tree

2 files changed

+462
-1
lines changed

2 files changed

+462
-1
lines changed

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 302 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/Utils/IndexingUtils.h"
1818
#include "mlir/IR/Dominance.h"
1919
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20+
#include "llvm/ADT/TypeSwitch.h"
2021
#include "llvm/Support/Debug.h"
2122
#include <optional>
2223

@@ -552,6 +553,305 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
552553
ControlPropagationFn controlFn;
553554
};
554555

556+
/// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
557+
///
558+
/// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
559+
/// targetShape [16, 16, 32, 1], it returns [1, 2]. Because for pos 0, the
560+
/// inner-most projected dim in pos [0, 1] is 1. And for pos 2, the inner-most
561+
/// non-unit projected dims in pos [2, 3] is 2.
562+
///
563+
/// If all candidates in a reassociation are unit dims, it chooses the
564+
/// inner-most dim pos.
565+
static SmallVector<int64_t>
566+
projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
567+
ArrayRef<ReassociationIndices> reassocIndices,
568+
ArrayRef<int64_t> targetShape) {
569+
SmallVector<int64_t> projectedDimsPos;
570+
for (auto pos : dimsPos) {
571+
// In the case all dims are unit, this will return the inner-most one.
572+
int64_t projectedPos = reassocIndices[pos].back();
573+
for (auto i : llvm::reverse(reassocIndices[pos])) {
574+
int64_t dim = targetShape[i];
575+
if (dim > 1 || ShapedType::isDynamic(dim)) {
576+
projectedPos = i;
577+
break;
578+
}
579+
}
580+
projectedDimsPos.push_back(projectedPos);
581+
}
582+
return projectedDimsPos;
583+
}
584+
585+
/// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
586+
static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
587+
ArrayRef<int64_t> shape,
588+
ArrayRef<int64_t> tileSizes) {
589+
for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
590+
int64_t dim = shape[pos];
591+
if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
592+
return false;
593+
}
594+
return true;
595+
}
596+
597+
/// Permutate the reassociation indices and reindex them in the sequence order.
598+
/// Returns the next dim pos in the sequence.
599+
///
600+
/// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
601+
/// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
602+
/// [[0], [1, 2]].
603+
static int64_t applyPermutationAndReindexReassoc(
604+
SmallVector<ReassociationIndices> &reassocIndices,
605+
ArrayRef<int64_t> permutation) {
606+
applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
607+
int64_t nextPos = 0;
608+
for (ReassociationIndices &indices : reassocIndices) {
609+
for (auto &index : indices) {
610+
index = nextPos;
611+
nextPos += 1;
612+
}
613+
}
614+
return nextPos;
615+
}
616+
617+
/// Bubble up pack op through collapse shape op when the packed dims can be
/// projected to the dims before collapsing. This is possible when the inner
/// tile sizes can divide the projected dims.
///
/// For example:
///
/// %collapsed = tensor.collapse_shape %in [[0, 1], 2]
///     : tensor<?x16x4xf32> into tensor<?x4xf32>
/// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
///     inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
///     : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
///
/// can be transformed into:
///
/// %pack = tensor.pack %in outer_dims_perm = [1, 2]
///     inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
///     : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
/// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4]
///     : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
///
/// Returns failure (without modifying the IR) when a projected dim is not
/// divisible by its inner tile size.
static LogicalResult
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                   tensor::PackOp packOp,
                                   PatternRewriter &rewriter) {
  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();

  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      collapseOp.getReassociationIndices();
  // Project inner tile pos to the dim pos before collapsing. For example, if
  // dims [x, y] is collapsed into [z], packing on dim z can be projected back
  // to pack on dim y.
  //
  // Project to inner-most non-unit dims to increase the chance that they can be
  // divided by the inner tile sizes. This is correct because for [..., x, 1],
  // packing on dim 1 is equivalent to packing on dim x.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);

  // Bail out when any projected dim cannot be evenly tiled (dynamic dims
  // count as non-divisible).
  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
                                  innerTileSizes)) {
    return failure();
  }
  // Expand the outer dims permutation with the associated source dims for the
  // new permutation after bubbling. This is because moving a collapsed dim is
  // equivalent to moving the associated source dims together.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm) {
    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
                            reassocIndices[outerPos].begin(),
                            reassocIndices[outerPos].end());
  }

  // Create the new pack op operating directly on the collapse op's source,
  // with the projected inner dims and the expanded outer permutation.
  auto emptyOp = tensor::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newPackOp = rewriter.create<tensor::PackOp>(
      packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
      packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // First apply the permutation on the reassociations of the outer dims.
  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
  // -> [[0], [1, 2]]
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  // Then add direct mapping for the inner tile dims (each maps to itself in
  // the collapse, so each gets a singleton group).
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  // Re-create the collapse after the new pack and replace the original pack's
  // uses; the result type matches the original pack op's type.
  auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
      collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
  rewriter.replaceOp(packOp, newCollapseOp);

  return success();
}
696+
697+
class BubbleUpPackOpThroughReshapeOp final
698+
: public OpRewritePattern<tensor::PackOp> {
699+
public:
700+
BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
701+
: OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
702+
703+
LogicalResult matchAndRewrite(tensor::PackOp packOp,
704+
PatternRewriter &rewriter) const override {
705+
Operation *srcOp = packOp.getSource().getDefiningOp();
706+
// Currently only support when the pack op is the only user.
707+
if (!srcOp || !(srcOp->getNumResults() == 1) ||
708+
!srcOp->getResult(0).hasOneUse()) {
709+
return failure();
710+
}
711+
// Currently only support static inner tile sizes.
712+
if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
713+
return ShapedType::isDynamic(size);
714+
})) {
715+
return failure();
716+
}
717+
718+
// User controlled propagation function.
719+
if (!controlFn(srcOp))
720+
return failure();
721+
722+
return TypeSwitch<Operation *, LogicalResult>(srcOp)
723+
.Case([&](tensor::CollapseShapeOp op) {
724+
return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
725+
})
726+
.Default([](Operation *) { return failure(); });
727+
}
728+
729+
private:
730+
ControlPropagationFn controlFn;
731+
};
732+
733+
/// Push down unpack op through expand shape op when the packed dims can be
/// projected to the dims after expanding. This is possible when the inner tile
/// sizes can divide the projected dims.
///
/// For example:
///
/// %unpack = tensor.unpack %in outer_dims_perm = [0, 1]
///     inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
///     : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
/// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
///     : tensor<?x256xf32> into tensor<?x256x256xf32>
///
/// can be transformed into:
///
/// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
///     : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
/// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
///     inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
///     : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
///
/// Returns failure (without modifying the IR) when a projected dim is not
/// divisible by its inner tile size.
static LogicalResult
pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
                                   tensor::ExpandShapeOp expandOp,
                                   PatternRewriter &rewriter) {
  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();

  ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      expandOp.getReassociationIndices();
  // Project inner tile pos to the dim pos after expanding. For example, if dims
  // [z] is expanded into [x, y], unpacking on dim z can be projected to unpack
  // on dim y.
  //
  // Project to inner-most non-unit dims to increase the chance that they can be
  // divided by the inner tile sizes. This is correct because for [..., x, 1],
  // unpacking on dim 1 is equivalent to unpacking on dim x.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);

  // Bail out when any projected dim cannot be evenly tiled (dynamic dims
  // count as non-divisible).
  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
                                  innerTileSizes)) {
    return failure();
  }
  // Expand the outer dims permutation with the associated expanded dims for the
  // new permutation after pushing. This is because moving a source dim is
  // equivalent to moving the associated expanded dims together.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm) {
    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
                            reassocIndices[outerPos].begin(),
                            reassocIndices[outerPos].end());
  }

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // First apply the permutation on the reassociations of the outer dims.
  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
  // -> [[0], [1, 2]]
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  // Then add direct mapping for the inner tile dims (each maps to itself in
  // the expand, so each gets a singleton group).
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  // Infer the packed type of the new expand result so the expand produces a
  // tensor the new unpack can consume directly.
  RankedTensorType newExpandType =
      tensor::PackOp::inferPackedType(expandOp.getType(), innerTileSizes,
                                      projectedInnerDimsPos, newOuterDimsPerm);
  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
      expandOp.getLoc(), newExpandType, unPackOp.getSource(),
      newReassocIndices);

  // Re-create the unpack after the new expand and replace the original expand
  // op's uses with it.
  auto emptyOp = tensor::UnPackOp::createDestinationTensor(
      rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
      unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
      projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
  rewriter.replaceOp(expandOp, newUnPackOp);

  return success();
}
816+
817+
class PushDownUnPackOpThroughReshapeOp final
818+
: public OpRewritePattern<tensor::UnPackOp> {
819+
public:
820+
PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
821+
ControlPropagationFn fun)
822+
: OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
823+
}
824+
825+
LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
826+
PatternRewriter &rewriter) const override {
827+
Value result = unPackOp.getResult();
828+
// Currently only support unpack op with the single user.
829+
if (!result.hasOneUse()) {
830+
return failure();
831+
}
832+
// Currently only support static inner tile sizes.
833+
if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
834+
return ShapedType::isDynamic(size);
835+
})) {
836+
return failure();
837+
}
838+
839+
Operation *consumerOp = *result.user_begin();
840+
// User controlled propagation function.
841+
if (!controlFn(consumerOp))
842+
return failure();
843+
844+
return TypeSwitch<Operation *, LogicalResult>(consumerOp)
845+
.Case([&](tensor::ExpandShapeOp op) {
846+
return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
847+
})
848+
.Default([](Operation *) { return failure(); });
849+
}
850+
851+
private:
852+
ControlPropagationFn controlFn;
853+
};
854+
555855
// TODO: Relax this restriction. We should unpack a generic op also
556856
// in the presence of multiple unpack ops as producers.
557857
/// Return the unpacked operand, if present, for the current generic op.
@@ -774,6 +1074,7 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
7741074
const ControlPropagationFn &controlPackUnPackPropagation) {
7751075
patterns
7761076
.insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
777-
PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
1077+
BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
1078+
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
7781079
patterns.getContext(), controlPackUnPackPropagation);
7791080
}

0 commit comments

Comments
 (0)