diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 40e04b76593a0..5f32aca88a273 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -120,6 +120,12 @@ inline auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc) {
   };
 }
 
+/// Returns a range over the dims (size and scalability) of a VectorType.
+/// Each element is a (size, isScalable) tuple, in dimension order.
+inline auto getDims(VectorType vType) {
+  return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
+}
+
 /// A wrapper for getMixedSizes for vector.transfer_read and
 /// vector.transfer_write Ops (for source and destination, respectively).
 ///
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 6777e589795c8..55c1c6bad9f2a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1720,6 +1720,78 @@ struct DropUnitDimFromElementwiseOps final
   }
 };
 
+/// A pattern to drop unit dims from vector.transpose.
+///
+/// Only non-scalable unit dims are dropped; scalable unit dims (e.g. `[1]`)
+/// are kept and still take part in the transpose.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %transpose = vector.transpose %vector, [3, 0, 1, 2]
+///    : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %dropDims = vector.shape_cast %vector
+///    : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+///  %transpose = vector.transpose %dropDims, [1, 0]
+///    : vector<4x[4]xf32> to vector<[4]x4xf32>
+///  %restoreDims = vector.shape_cast %transpose
+///    : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+///  ```
+struct DropUnitDimsFromTransposeOp final
+    : OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = op.getSourceVectorType();
+    VectorType sourceTypeWithoutUnitDims =
+        dropNonScalableUnitDimFromType(sourceType);
+
+    // Nothing to drop: bail out before touching the IR.
+    if (sourceType == sourceTypeWithoutUnitDims)
+      return failure();
+
+    // Construct a map from dimIdx -> number of dims dropped before dimIdx.
+    auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
+    SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
+    int64_t droppedDims = 0;
+    for (auto [i, dim] : llvm::enumerate(sourceDims)) {
+      droppedDimsBefore[i] = droppedDims;
+      if (dim == std::make_tuple(1, false))
+        ++droppedDims;
+    }
+
+    // Drop unit dims from transpose permutation.
+    ArrayRef<int64_t> perm = op.getPermutation();
+    SmallVector<int64_t> newPerm;
+    for (int64_t idx : perm) {
+      if (sourceDims[idx] == std::make_tuple(1, false))
+        continue;
+      // Re-index the surviving dim into the rank-reduced source.
+      newPerm.push_back(idx - droppedDimsBefore[idx]);
+    }
+
+    Location loc = op.getLoc();
+    // Drop the unit dims via shape_cast.
+    auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
+        loc, sourceTypeWithoutUnitDims, op.getVector());
+    // Create the new transpose.
+    auto transposeWithoutUnitDims =
+        rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
+    // Restore the unit dims via shape cast.
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        op, op.getResultVectorType(), transposeWithoutUnitDims);
+
+    // The IR was modified, so the rewrite must be reported as successful.
+    return success();
+  }
+};
+
 /// Pattern to eliminate redundant zero-constants added to reduction operands.
 /// It's enough for there to be one initial zero value, so we can eliminate the
 /// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1924,8 +1996,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
 
 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
-      patterns.getContext(), benefit);
+  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
+               ShapeCastOpFolder>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 9d16aa46a9f2a..937dbf22bb713 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -700,3 +700,47 @@ func.func @negative_out_of_bound_transfer_write(
 }
 // CHECK: func.func @negative_out_of_bound_transfer_write
 // CHECK-NOT: memref.collapse_shape
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// [Pattern: DropUnitDimsFromTransposeOp]
+/// TODO: Move to a dedicated file - there's no "flattening" in the following tests
+///----------------------------------------------------------------------------------------
+
+func.func @transpose_with_internal_unit_dims(%vec: vector<1x1x4x[4]xf32>) -> vector<[4]x1x1x4xf32> {
+  %res = vector.transpose %vec, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+  return %res : vector<[4]x1x1x4xf32>
+}
+
+// CHECK-LABEL: func.func @transpose_with_internal_unit_dims(
+// CHECK-SAME: %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
+// CHECK-NEXT: %[[DROP_DIMS:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[DROP_DIMS]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+// CHECK-NEXT: %[[RESTORE_DIMS:.*]] = vector.shape_cast %[[TRANSPOSE]] : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+// CHECK-NEXT: return %[[RESTORE_DIMS]] : vector<[4]x1x1x4xf32>
+
+// -----
+
+func.func @transpose_with_scalable_unit_dims(%vec: vector<[1]x1x2x4x1xf32>) -> vector<1x1x4x2x[1]xf32>
+{
+  %res = vector.transpose %vec, [4, 1, 3, 2, 0] : vector<[1]x1x2x4x1xf32> to vector<1x1x4x2x[1]xf32>
+  return %res: vector<1x1x4x2x[1]xf32>
+}
+
+// CHECK-LABEL: func.func @transpose_with_scalable_unit_dims(
+// CHECK-SAME: %[[VEC:.*]]: vector<[1]x1x2x4x1xf32>)
+// CHECK-NEXT: %[[DROP_DIMS:.*]] = vector.shape_cast %[[VEC]] : vector<[1]x1x2x4x1xf32> to vector<[1]x2x4xf32>
+// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[DROP_DIMS]], [2, 1, 0] : vector<[1]x2x4xf32> to vector<4x2x[1]xf32>
+// CHECK-NEXT: %[[RESTORE_DIMS:.*]] = vector.shape_cast %[[TRANSPOSE]] : vector<4x2x[1]xf32> to vector<1x1x4x2x[1]xf32>
+// CHECK-NEXT: return %[[RESTORE_DIMS]] : vector<1x1x4x2x[1]xf32>
+
+// -----
+
+func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vector<4x3x2xf32> {
+  %res = vector.transpose %vec, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
+  return %res : vector<4x3x2xf32>
+}
+
+// CHECK-LABEL: func.func @negative_transpose_with_no_unit_dims
+// CHECK-NOT: vector.shape_cast