 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <optional>

@@ -552,6 +553,305 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
   ControlPropagationFn controlFn;
 };

+/// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
+///
+/// For example, given dimsPos [0, 1], reassocIndices [[0, 1], [2, 3]], and
+/// targetShape [16, 16, 32, 1], it returns [1, 2]. For pos 0, the inner-most
+/// projected dim in pos [0, 1] is 1. For pos 1, the inner-most non-unit
+/// projected dim in pos [2, 3] is 2.
+///
+/// If all candidates in a reassociation are unit dims, it chooses the
+/// inner-most dim pos.
+static SmallVector<int64_t>
+projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
+                                 ArrayRef<ReassociationIndices> reassocIndices,
+                                 ArrayRef<int64_t> targetShape) {
+  SmallVector<int64_t> projectedDimsPos;
+  for (auto pos : dimsPos) {
+    // In the case all dims are unit, this will return the inner-most one.
+    int64_t projectedPos = reassocIndices[pos].back();
+    for (auto i : llvm::reverse(reassocIndices[pos])) {
+      int64_t dim = targetShape[i];
+      if (dim > 1 || ShapedType::isDynamic(dim)) {
+        projectedPos = i;
+        break;
+      }
+    }
+    projectedDimsPos.push_back(projectedPos);
+  }
+  return projectedDimsPos;
+}
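To see the projection concretely, here is a minimal standalone sketch of the same logic using plain STL types. It is illustrative only; `kDynamic` stands in for MLIR's `ShapedType::kDynamic` sentinel and is an assumption of this sketch.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for ShapedType::kDynamic (assumed sentinel for a dynamic dim).
constexpr int64_t kDynamic = INT64_MIN;

std::vector<int64_t>
projectToInnerMost(const std::vector<int64_t> &dimsPos,
                   const std::vector<std::vector<int64_t>> &reassoc,
                   const std::vector<int64_t> &targetShape) {
  std::vector<int64_t> projected;
  for (int64_t pos : dimsPos) {
    // Fall back to the inner-most dim if every candidate is a unit dim.
    int64_t projectedPos = reassoc[pos].back();
    for (auto it = reassoc[pos].rbegin(); it != reassoc[pos].rend(); ++it) {
      int64_t dim = targetShape[*it];
      if (dim > 1 || dim == kDynamic) {
        projectedPos = *it;
        break;
      }
    }
    projected.push_back(projectedPos);
  }
  return projected;
}

int main() {
  // The doc-comment example: dimsPos [0, 1], reassocIndices [[0, 1], [2, 3]],
  // targetShape [16, 16, 32, 1].
  for (int64_t pos :
       projectToInnerMost({0, 1}, {{0, 1}, {2, 3}}, {16, 16, 32, 1}))
    std::cout << pos << ' '; // prints: 1 2
  std::cout << '\n';
}
```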
+
+/// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
+static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
+                                       ArrayRef<int64_t> shape,
+                                       ArrayRef<int64_t> tileSizes) {
+  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
+    int64_t dim = shape[pos];
+    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
+      return false;
+  }
+  return true;
+}
+
+/// Permute the reassociation indices and reindex them in sequence order.
+/// Returns the next available dim pos after reindexing.
+///
+/// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
+/// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
+/// [[0], [1, 2]].
+static int64_t applyPermutationAndReindexReassoc(
+    SmallVector<ReassociationIndices> &reassocIndices,
+    ArrayRef<int64_t> permutation) {
+  applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
+  int64_t nextPos = 0;
+  for (ReassociationIndices &indices : reassocIndices) {
+    for (auto &index : indices) {
+      index = nextPos;
+      nextPos += 1;
+    }
+  }
+  return nextPos;
+}
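The reindexing is what keeps the reassociation map valid after the outer dims are permuted. A standalone trace of the same two steps, using plain STL stand-ins for the MLIR helpers (illustrative, not part of the patch):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

using Reassoc = std::vector<int64_t>;

// Permute the reassociation groups, then renumber every index consecutively
// from zero; returns the next free position.
int64_t permuteAndReindex(std::vector<Reassoc> &reassoc,
                          const std::vector<int64_t> &perm) {
  std::vector<Reassoc> permuted;
  permuted.reserve(perm.size());
  for (int64_t p : perm)
    permuted.push_back(reassoc[p]);
  reassoc = std::move(permuted);
  int64_t nextPos = 0;
  for (Reassoc &indices : reassoc)
    for (int64_t &index : indices)
      index = nextPos++;
  return nextPos;
}

int main() {
  // [[0, 1], [2]] permuted by [1, 0] -> [[2], [0, 1]] -> reindexed [[0], [1, 2]].
  std::vector<Reassoc> reassoc = {{0, 1}, {2}};
  int64_t next = permuteAndReindex(reassoc, {1, 0});
  for (const Reassoc &group : reassoc) {
    for (int64_t i : group)
      std::cout << i << ' ';
    std::cout << "/ ";
  }
  std::cout << "next=" << next << '\n'; // prints: 0 / 1 2 / next=3
}
```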
+
+/// Bubble up pack op through collapse shape op when the packed dims can be
+/// projected to the dims before collapsing. This is possible when the inner
+/// tile sizes can divide the projected dims.
+///
+/// For example:
+///
+/// %collapsed = tensor.collapse_shape %in [[0, 1], 2]
+///     : tensor<?x16x4xf32> into tensor<?x4xf32>
+/// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
+///     inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
+///     : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
+///
+/// can be transformed into:
+///
+/// %pack = tensor.pack %in outer_dims_perm = [0, 1, 2]
+///     inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
+///     : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
+/// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4]
+///     : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
+static LogicalResult
+bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
+                                   tensor::PackOp packOp,
+                                   PatternRewriter &rewriter) {
+  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
+  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+
+  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
+  SmallVector<ReassociationIndices> reassocIndices =
+      collapseOp.getReassociationIndices();
+  // Project inner tile pos to the dim pos before collapsing. For example, if
+  // dims [x, y] are collapsed into [z], packing on dim z can be projected
+  // back to pack on dim y.
+  //
+  // Project to inner-most non-unit dims to increase the chance that they can
+  // be divided by the inner tile sizes. This is correct because for
+  // [..., x, 1], packing on dim 1 is equivalent to packing on dim x.
+  SmallVector<int64_t> projectedInnerDimsPos =
+      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
+
+  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
+                                  innerTileSizes)) {
+    return failure();
+  }
+  // Expand the outer dims permutation with the associated source dims for the
+  // new permutation after bubbling. This is because moving a collapsed dim is
+  // equivalent to moving the associated source dims together.
+  SmallVector<int64_t> newOuterDimsPerm;
+  for (auto outerPos : outerDimsPerm) {
+    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
+                            reassocIndices[outerPos].begin(),
+                            reassocIndices[outerPos].end());
+  }
+
+  auto emptyOp = tensor::PackOp::createDestinationTensor(
+      rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
+      projectedInnerDimsPos, newOuterDimsPerm);
+  auto newPackOp = rewriter.create<tensor::PackOp>(
+      packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
+      packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
+
+  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
+  // First apply the permutation on the reassociations of the outer dims.
+  // For example, given the permutation [1, 0], the reassociations
+  // [[0, 1], [2]] become [[0], [1, 2]].
+  int64_t nextPos =
+      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
+  // Then add direct mapping for the inner tile dims.
+  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
+    newReassocIndices.push_back({nextPos});
+    nextPos += 1;
+  }
+
+  auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
+      collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
+  rewriter.replaceOp(packOp, newCollapseOp);
+
+  return success();
+}
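Tracing the doc-comment example through this function makes the intermediate values easier to follow. These are hand-computed for illustration, not output from the patch:

```cpp
// Hand-computed trace of the example above:
//   srcShape              = [?, 16, 4]
//   reassocIndices        = [[0, 1], [2]]
//   innerDimsPos          = [0, 1], innerTileSizes = [8, 1]
//   projectedInnerDimsPos = [1, 2]     // 16 is the inner-most non-unit dim
//                                      // of group [0, 1]; 4 of group [2]
//   divisibility check    : 16 % 8 == 0 and 4 % 1 == 0 -> proceed
//   newOuterDimsPerm      = [0, 1, 2]  // [0, 1] expanded via reassociations
//   newReassocIndices     = [[0, 1], [2], [3], [4]]
//   new pack result type  : tensor<?x2x4x8x1xf32>
```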
+
+class BubbleUpPackOpThroughReshapeOp final
+    : public OpRewritePattern<tensor::PackOp> {
+public:
+  BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
+      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    Operation *srcOp = packOp.getSource().getDefiningOp();
+    // Currently only supports the case where the pack op is the only user.
+    if (!srcOp || srcOp->getNumResults() != 1 ||
+        !srcOp->getResult(0).hasOneUse()) {
+      return failure();
+    }
+    // Currently only supports static inner tile sizes.
+    if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
+          return ShapedType::isDynamic(size);
+        })) {
+      return failure();
+    }
+
+    // User-controlled propagation function.
+    if (!controlFn(srcOp))
+      return failure();
+
+    return TypeSwitch<Operation *, LogicalResult>(srcOp)
+        .Case([&](tensor::CollapseShapeOp op) {
+          return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
+        })
+        .Default([](Operation *) { return failure(); });
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
+/// Push down unpack op through expand shape op when the packed dims can be
+/// projected to the dims after expanding. This is possible when the inner tile
+/// sizes can divide the projected dims.
+///
+/// For example:
+///
+/// %unpack = tensor.unpack %in outer_dims_perm = [0, 1]
+///     inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
+///     : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
+/// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
+///     : tensor<?x256xf32> into tensor<?x256x256xf32>
+///
+/// can be transformed into:
+///
+/// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
+///     : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
+/// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
+///     inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
+///     : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
+static LogicalResult
+pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
+                                   tensor::ExpandShapeOp expandOp,
+                                   PatternRewriter &rewriter) {
+  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
+  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
+  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
+
+  ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
+  SmallVector<ReassociationIndices> reassocIndices =
+      expandOp.getReassociationIndices();
+  // Project inner tile pos to the dim pos after expanding. For example, if
+  // dim z is expanded into [x, y], unpacking on dim z can be projected to
+  // unpack on dim y.
+  //
+  // Project to inner-most non-unit dims to increase the chance that they can
+  // be divided by the inner tile sizes. This is correct because for
+  // [..., x, 1], unpacking on dim 1 is equivalent to unpacking on dim x.
+  SmallVector<int64_t> projectedInnerDimsPos =
+      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
+
+  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
+                                  innerTileSizes)) {
+    return failure();
+  }
+  // Expand the outer dims permutation with the associated expanded dims for
+  // the new permutation after pushing. This is because moving a source dim is
+  // equivalent to moving the associated expanded dims together.
+  SmallVector<int64_t> newOuterDimsPerm;
+  for (auto outerPos : outerDimsPerm) {
+    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
+                            reassocIndices[outerPos].begin(),
+                            reassocIndices[outerPos].end());
+  }
+
+  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
+  // First apply the permutation on the reassociations of the outer dims.
+  // For example, given the permutation [1, 0], the reassociations
+  // [[0, 1], [2]] become [[0], [1, 2]].
+  int64_t nextPos =
+      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
+  // Then add direct mapping for the inner tile dims.
+  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
+    newReassocIndices.push_back({nextPos});
+    nextPos += 1;
+  }
+
+  RankedTensorType newExpandType =
+      tensor::PackOp::inferPackedType(expandOp.getType(), innerTileSizes,
+                                      projectedInnerDimsPos, newOuterDimsPerm);
+  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+      expandOp.getLoc(), newExpandType, unPackOp.getSource(),
+      newReassocIndices);
+
+  auto emptyOp = tensor::UnPackOp::createDestinationTensor(
+      rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
+      projectedInnerDimsPos, newOuterDimsPerm);
+  auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
+      unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
+      projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
+  rewriter.replaceOp(expandOp, newUnPackOp);
+
+  return success();
+}
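The same kind of hand-computed trace for the push-down direction (illustrative only):

```cpp
// Hand-computed trace of the example above:
//   dstShape              = [?, 256, 256]  // expand_shape result
//   reassocIndices        = [[0, 1], [2]]
//   innerDimsPos          = [0, 1], innerTileSizes = [8, 8]
//   projectedInnerDimsPos = [1, 2]
//   divisibility check    : 256 % 8 == 0 on both dims -> proceed
//   newOuterDimsPerm      = [0, 1, 2]
//   newReassocIndices     = [[0, 1], [2], [3], [4]]
//   newExpandType         = tensor<?x32x32x8x8xf32>  // via inferPackedType
```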
+
+class PushDownUnPackOpThroughReshapeOp final
+    : public OpRewritePattern<tensor::UnPackOp> {
+public:
+  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
+                                   ControlPropagationFn fun)
+      : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
+  }
+
+  LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
+                                PatternRewriter &rewriter) const override {
+    Value result = unPackOp.getResult();
+    // Currently only supports unpack ops with a single user.
+    if (!result.hasOneUse()) {
+      return failure();
+    }
+    // Currently only supports static inner tile sizes.
+    if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
+          return ShapedType::isDynamic(size);
+        })) {
+      return failure();
+    }
+
+    Operation *consumerOp = *result.user_begin();
+    // User-controlled propagation function.
+    if (!controlFn(consumerOp))
+      return failure();
+
+    return TypeSwitch<Operation *, LogicalResult>(consumerOp)
+        .Case([&](tensor::ExpandShapeOp op) {
+          return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
+        })
+        .Default([](Operation *) { return failure(); });
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
 // TODO: Relax this restriction. We should unpack a generic op also
 // in the presence of multiple unpack ops as producers.
 /// Return the unpacked operand, if present, for the current generic op.
@@ -774,6 +1074,7 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
|
774 | 1074 | const ControlPropagationFn &controlPackUnPackPropagation) {
|
775 | 1075 | patterns
|
776 | 1076 | .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
|
777 |
| - PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>( |
| 1077 | + BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp, |
| 1078 | + PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>( |
778 | 1079 | patterns.getContext(), controlPackUnPackPropagation);
 }
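For context on how these patterns get exercised: a minimal sketch of a pass body that registers them and runs the greedy driver. The wrapper function and the always-true control callback are hypothetical; `populateDataLayoutPropagationPatterns` and `applyPatternsAndFoldGreedily` are the real entry points this file already uses.

```cpp
// Hypothetical driver, assuming this file's includes are available.
void runDataLayoutPropagation(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  // Allow propagation through every candidate op; a real control function
  // could filter by op kind or a profitability heuristic.
  linalg::populateDataLayoutPropagationPatterns(
      patterns, [](Operation *op) { return true; });
  if (failed(applyPatternsAndFoldGreedily(root, std::move(patterns))))
    root->emitError("data layout propagation did not converge");
}
```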