-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][spirv] Fix function signature legalization for n-D vectors #99872
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Angel Zhang (angelz913) ChangesFull diff: https://github.com/llvm/llvm-project/pull/99872.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index bf5044437fd09..c146589612b5e 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1098,13 +1098,19 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
// the original operand of illegal type.
auto originalShape =
llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
- SmallVector<int64_t> strides(targetShape->size(), 1);
+ SmallVector<int64_t> strides(originalShape.size(), 1);
+ SmallVector<int64_t> extractShape(originalShape.size(), 1);
+ extractShape.back() = targetShape->back();
SmallVector<Type> newTypes;
Value returnValue = returnOp.getOperand(origResultNo);
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalShape, *targetShape)) {
Value result = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, returnValue, offsets, *targetShape, strides);
+ loc, returnValue, offsets, extractShape, strides);
+ SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
+ if (originalShape.size() > 1)
+ result =
+ rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
newOperands.push_back(result);
newTypes.push_back(unrolledType);
}
diff --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
index 347d282f9ee0c..c018ccb924983 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
@@ -66,6 +66,28 @@ func.func @simple_vector_8(%arg0 : vector<8xi32>) -> vector<8xi32> {
// -----
+// CHECK-LABEL: @simple_vector_2d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
+func.func @simple_vector_2d(%arg0 : vector<4x4xi32>) -> vector<4x4xi32> {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4x4xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [1, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+ // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [2, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+ // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [3, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+ // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+ // CHECK: %[[EXTRACT0_1:.*]] = vector.extract %[[EXTRACT0]][0] : vector<4xi32> from vector<1x4xi32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+ // CHECK: %[[EXTRACT1_1:.*]] = vector.extract %[[EXTRACT1]][0] : vector<4xi32> from vector<1x4xi32>
+ // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [2, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+ // CHECK: %[[EXTRACT2_1:.*]] = vector.extract %[[EXTRACT2]][0] : vector<4xi32> from vector<1x4xi32>
+ // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [3, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+ // CHECK: %[[EXTRACT3_1:.*]] = vector.extract %[[EXTRACT3]][0] : vector<4xi32> from vector<1x4xi32>
+ // CHECK: return %[[EXTRACT0_1]], %[[EXTRACT1_1]], %[[EXTRACT2_1]], %[[EXTRACT3_1]] : vector<4xi32>, vector<4xi32>, vector<4xi32>, vector<4xi32>
+ return %arg0 : vector<4x4xi32>
+}
+
+// -----
+
// CHECK-LABEL: @vector_6and8
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
func.func @vector_6and8(%arg0 : vector<6xi32>, %arg1 : vector<8xi32>) -> (vector<6xi32>, vector<8xi32>) {
@@ -113,6 +135,28 @@ func.func @scalar_vector(%arg0 : vector<8xi32>, %arg1 : vector<3xi32>, %arg2 : i
// -----
+// CHECK-LABEL: @vector_2dand1d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: vector<3xi32>, %[[ARG4:.+]]: vector<4xi32>)
+func.func @vector_2dand1d(%arg0 : vector<2x6xi32>, %arg1 : vector<4xi32>) -> (vector<2x6xi32>, vector<4xi32>) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x6xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [0, 3], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+ // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [1, 0], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+ // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [1, 3], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+ // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+ // CHECK: %[[EXTRACT0_1:.*]] = vector.extract %[[EXTRACT0]][0] : vector<3xi32> from vector<1x3xi32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 3], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+ // CHECK: %[[EXTRACT1_1:.*]] = vector.extract %[[EXTRACT1]][0] : vector<3xi32> from vector<1x3xi32>
+ // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+ // CHECK: %[[EXTRACT2_1:.*]] = vector.extract %[[EXTRACT2]][0] : vector<3xi32> from vector<1x3xi32>
+ // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 3], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+ // CHECK: %[[EXTRACT3_1:.*]] = vector.extract %[[EXTRACT3]][0] : vector<3xi32> from vector<1x3xi32>
+ // CHECK: return %[[EXTRACT0_1]], %[[EXTRACT1_1]], %[[EXTRACT2_1]], %[[EXTRACT3_1]], %[[ARG4]] : vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<4xi32>
+ return %arg0, %arg1 : vector<2x6xi32>, vector<4xi32>
+}
+
+// -----
+
// CHECK-LABEL: @reduction
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>, %[[ARG4:.+]]: i32)
func.func @reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : i32) -> (i32) {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you explain what the issue is and what this PR does to fix that?
Updated description |
…100138) ### Description This PR builds on #99872. It implements a minimal version of function body vector unrolling to convert vector types into 1D and with a size supported by SPIR-V (2, 3 or 4 depending on the original dimension). The ops that are currently supported include those with elementwise traits (e.g. `arith.addi`), `vector.reduction` and `vector.transpose`. This PR also includes new LIT tests that only check for vector unrolling. ### Future Plans - Support more ops --------- Co-authored-by: Jakub Kuderski <[email protected]>
The changes are included in #100138. |
…100138) Summary: ### Description This PR builds on #99872. It implements a minimal version of function body vector unrolling to convert vector types into 1D and with a size supported by SPIR-V (2, 3 or 4 depending on the original dimension). The ops that are currently supported include those with elementwise traits (e.g. `arith.addi`), `vector.reduction` and `vector.transpose`. This PR also includes new LIT tests that only check for vector unrolling. ### Future Plans - Support more ops --------- Co-authored-by: Jakub Kuderski <[email protected]> Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60250596
This PR makes a fix to the function signature vector unrolling for SPIR-V (introduced in #98337).
Issue
The `vector.extract_strided_slice` op requires its `offsets`, `sizes` and `strides` operands to be of the same length. The original implementation did not consider the case where the source vector has rank greater than 1, so the `offsets` created using `StaticTileOffsetRange` will have length > 1, while the `sizes` and `strides` still have a length of 1. This caused assertion failures when creating the `vector.extract_strided_slice` ops for vectors in function results.
Fix
This PR addresses the issue by making the `sizes` and `strides` operands match the length of `offsets` (equal to the rank of the original vector). An extra `vector.extract` op is also created to trim potential leading ones in the result from the `vector.extract_strided_slice` ops.