Skip to content

[mlir][vector] Add pattern to drop unit dims from vector.transpose #102017

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 8, 2024

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Aug 5, 2024

Example:

BEFORE:

%transpose = vector.transpose %vector, [3, 0, 1, 2]
  : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>

AFTER:

%dropDims = vector.shape_cast %vector
  : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
%transpose = vector.transpose %0, [1, 0]
  : vector<4x[4]xf32> to vector<[4]x4xf32>
%restoreDims = vector.shape_cast %transpose
  : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>

@llvmbot
Copy link
Member

llvmbot commented Aug 5, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

Example:

BEFORE:

%transpose = vector.transpose %vector, [3, 0, 1, 2]
  : vector&lt;1x1x4x[4]xf32&gt; to vector&lt;[4]x1x1x4xf32&gt;

AFTER:

%dropDims = vector.shape_cast %vector
  : vector&lt;1x1x4x[4]xf32&gt; to vector&lt;4x[4]xf32&gt;
%transpose = vector.transpose %0, [1, 0]
  : vector&lt;4x[4]xf32&gt; to vector&lt;[4]x4xf32&gt;
%restoreDims = vector.shape_cast %transpose
  : vector&lt;[4]x4xf32&gt; to vector&lt;[4]x1x1x4xf32&gt;

Full diff: https://github.com/llvm/llvm-project/pull/102017.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h (+5)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+68-2)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+33)
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 40e04b76593a0..67c36bfa06ded 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -120,6 +120,11 @@ inline auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc) {
   };
 }
 
+/// Returns an iterator over the dims (inc scalability) of a VectorType.
+inline auto getDims(VectorType vType) {
+  return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
+}
+
 /// A wrapper for getMixedSizes for vector.transfer_read and
 /// vector.transfer_write Ops (for source and destination, respectively).
 ///
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 6777e589795c8..6b39cef7899d9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1720,6 +1720,72 @@ struct DropUnitDimFromElementwiseOps final
   }
 };
 
+/// A pattern to drop unit dims from vector.transpose.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %transpose = vector.transpose %vector, [3, 0, 1, 2]
+///    : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %dropDims = vector.shape_cast %vector
+///    : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+///  %transpose = vector.transpose %0, [1, 0]
+///    : vector<4x[4]xf32> to vector<[4]x4xf32>
+///  %restoreDims = vector.shape_cast %transpose
+///    : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+///  ```
+struct DropUnitDimsFromTransposeOp final
+    : OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = op.getSourceVectorType();
+    VectorType sourceTypeWithoutUnitDims =
+        dropNonScalableUnitDimFromType(sourceType);
+
+    if (sourceType == sourceTypeWithoutUnitDims)
+      return failure();
+
+    // Construct a map from dimIdx -> number of dims dropped before dimIdx.
+    auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
+    SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
+    int64_t droppedDims = 0;
+    for (auto [i, dim] : llvm::enumerate(sourceDims)) {
+      droppedDimsBefore[i] = droppedDims;
+      if (dim == std::make_tuple(1, false))
+        ++droppedDims;
+    }
+
+    // Drop unit dims from transpose permutation.
+    ArrayRef<int64_t> perm = op.getPermutation();
+    SmallVector<int64_t> newPerm;
+    for (int64_t idx : perm) {
+      if (sourceDims[idx] == std::make_tuple(1, false))
+        continue;
+      newPerm.push_back(idx - droppedDimsBefore[idx]);
+    }
+
+    auto loc = op.getLoc();
+    // Drop the unit dims via shape_cast.
+    auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
+        loc, sourceTypeWithoutUnitDims, op.getVector());
+    // Create the new transpose.
+    auto tranposeWithoutUnitDims =
+        rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
+    // Restore the unit dims via shape cast.
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        op, op.getResultVectorType(), tranposeWithoutUnitDims);
+
+    return failure();
+  }
+};
+
 /// Pattern to eliminate redundant zero-constants added to reduction operands.
 /// It's enough for there to be one initial zero value, so we can eliminate the
 /// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1924,8 +1990,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
 
 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
-      patterns.getContext(), benefit);
+  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
+               ShapeCastOpFolder>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 9d16aa46a9f2a..222a05ff70d02 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -700,3 +700,36 @@ func.func @negative_out_of_bound_transfer_write(
 }
 // CHECK:     func.func @negative_out_of_bound_transfer_write
 // CHECK-NOT:   memref.collapse_shape
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// [Pattern: DropUnitDimsFromTransposeOp]
+/// TODO: Move to a dedicated file - there's no "flattening" in the following tests
+///----------------------------------------------------------------------------------------
+
+func.func @transpose_with_internal_unit_dims(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x1x1x4xf32> {
+  %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+  return %0 : vector<[4]x1x1x4xf32>
+}
+
+// CHECK-LABEL: func.func @transpose_with_internal_unit_dims(
+// CHECK-SAME:                                               %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
+// CHECK-NEXT:    %[[DROP_DIMS:.*]] = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+// CHECK-NEXT:    %[[TRANSPOSE:.*]] = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+// CHECK-NEXT:    %[[RESTORE_DIMS:.*]] = vector.shape_cast %1 : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+// CHECK-NEXT:    return %[[RESTORE_DIMS]] : vector<[4]x1x1x4xf32>
+
+// -----
+
+func.func @transpose_with_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x1x1x1x4x1xf32> {
+  %0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
+  return %0 : vector<[4]x1x1x1x4x1xf32>
+}
+
+// CHECK-LABEL: func.func @transpose_with_units_dims_before_and_after(
+// CHECK-SAME:                                                        %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>)
+// CHECK-NEXT:    %[[DROP_DIMS:.*]] = vector.shape_cast %arg0 : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32>
+// CHECK-NEXT:    %[[TRANSPOSE:.*]] = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+// CHECK-NEXT:    %[[RESTORE_DIMS:.*]] = vector.shape_cast %1 : vector<[4]x4xf32> to vector<[4]x1x1x1x4x1xf32>
+// CHECK-NEXT:    return %[[RESTORE_DIMS]] : vector<[4]x1x1x1x4x1xf32>

Example:

BEFORE:
```mlir
%transpose = vector.transpose %vector, [3, 0, 1, 2]
  : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
```

AFTER:
```mlir
%dropDims = vector.shape_cast %vector
  : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
%transpose = vector.transpose %0, [1, 0]
  : vector<4x[4]xf32> to vector<[4]x4xf32>
%restoreDims = vector.shape_cast %transpose
  : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
```
@banach-space
Copy link
Contributor

Adding DropUnitDimsFromTransposeOp to populateDropUnitDimWithShapeCastPatterns makes a lot of sense to me in terms of grouping these patterns.

How is this related to #100933? Also, let's make sure that somebody with SPIR-V lenses on takes a look :) CC @kuhar

@banach-space banach-space requested a review from kuhar August 6, 2024 12:22
@MacDue
Copy link
Member Author

MacDue commented Aug 7, 2024

This can supersede #100933

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Argh, sorry, forgot to send yesterday.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't foresee this causing any issues related to the SPIR-V lowering.

@MacDue MacDue merged commit da8778e into llvm:main Aug 8, 2024
7 checks passed
@MacDue MacDue deleted the drop_tr_dims branch August 8, 2024 10:42
MacDue added a commit that referenced this pull request Aug 9, 2024
…spose(shape_cast) (#100731)" (#102457)

This reverts commit 88accd9.

This change can be dropped in favor of just #102017.
Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! Thanks for bearing with me. Appreciate it

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants