Skip to content

[mlir][spirv] Fix function signature legalization for n-D vectors #99872

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

angelz913
Copy link
Contributor

@angelz913 angelz913 commented Jul 22, 2024

This PR makes a fix to the function signature vector unrolling for SPIR-V (introduced in #98337).

Issue

vector.extract_strided_slice op requires its offsets, sizes and strides operands to be of the same length. The original implementation did not consider when the source vector has rank greater than 1, so the offsets created using StaticTileOffsetRange will have length > 1, while the size and strides still have length of 1. This caused assertion failures when creating the vector.extract_strided_slice ops for vectors in function results.

Fix

This PR addresses the issue by making the sizes and strides operands to match the length of offsets (equal to the rank of the original vector). An extract vector.extract op is also created to trim potential leading ones in the result from the vector.extract_strided_slice ops.

@llvmbot
Copy link
Member

llvmbot commented Jul 22, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Angel Zhang (angelz913)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/99872.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+8-2)
  • (modified) mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir (+44)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index bf5044437fd09..c146589612b5e 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1098,13 +1098,19 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
       // the original operand of illegal type.
       auto originalShape =
           llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
-      SmallVector<int64_t> strides(targetShape->size(), 1);
+      SmallVector<int64_t> strides(originalShape.size(), 1);
+      SmallVector<int64_t> extractShape(originalShape.size(), 1);
+      extractShape.back() = targetShape->back();
       SmallVector<Type> newTypes;
       Value returnValue = returnOp.getOperand(origResultNo);
       for (SmallVector<int64_t> offsets :
            StaticTileOffsetRange(originalShape, *targetShape)) {
         Value result = rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, returnValue, offsets, *targetShape, strides);
+            loc, returnValue, offsets, extractShape, strides);
+        SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
+        if (originalShape.size() > 1)
+          result =
+              rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
         newOperands.push_back(result);
         newTypes.push_back(unrolledType);
       }
diff --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
index 347d282f9ee0c..c018ccb924983 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
@@ -66,6 +66,28 @@ func.func @simple_vector_8(%arg0 : vector<8xi32>) -> vector<8xi32> {
 
 // -----
 
+// CHECK-LABEL: @simple_vector_2d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
+func.func @simple_vector_2d(%arg0 : vector<4x4xi32>) -> vector<4x4xi32> {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4x4xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [1, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [2, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+  // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [3, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+  // CHECK: %[[EXTRACT0_1:.*]] = vector.extract %[[EXTRACT0]][0] : vector<4xi32> from vector<1x4xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+  // CHECK: %[[EXTRACT1_1:.*]] = vector.extract %[[EXTRACT1]][0] : vector<4xi32> from vector<1x4xi32>
+  // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [2, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+  // CHECK: %[[EXTRACT2_1:.*]] = vector.extract %[[EXTRACT2]][0] : vector<4xi32> from vector<1x4xi32>
+  // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [3, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+  // CHECK: %[[EXTRACT3_1:.*]] = vector.extract %[[EXTRACT3]][0] : vector<4xi32> from vector<1x4xi32>
+  // CHECK: return %[[EXTRACT0_1]], %[[EXTRACT1_1]], %[[EXTRACT2_1]], %[[EXTRACT3_1]] : vector<4xi32>, vector<4xi32>, vector<4xi32>, vector<4xi32>
+  return %arg0 : vector<4x4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @vector_6and8
 // CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
 func.func @vector_6and8(%arg0 : vector<6xi32>, %arg1 : vector<8xi32>) -> (vector<6xi32>, vector<8xi32>) {
@@ -113,6 +135,28 @@ func.func @scalar_vector(%arg0 : vector<8xi32>, %arg1 : vector<3xi32>, %arg2 : i
 
 // -----
 
+// CHECK-LABEL: @vector_2dand1d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: vector<3xi32>, %[[ARG4:.+]]: vector<4xi32>)
+func.func @vector_2dand1d(%arg0 : vector<2x6xi32>, %arg1 : vector<4xi32>) -> (vector<2x6xi32>, vector<4xi32>) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x6xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [0, 3], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [1, 0], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+  // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [1, 3], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+  // CHECK: %[[EXTRACT0:.*]]  = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+  // CHECK: %[[EXTRACT0_1:.*]]  = vector.extract %[[EXTRACT0]][0] : vector<3xi32> from vector<1x3xi32>
+  // CHECK: %[[EXTRACT1:.*]]  = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 3], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+  // CHECK: %[[EXTRACT1_1:.*]]  = vector.extract %[[EXTRACT1]][0] : vector<3xi32> from vector<1x3xi32>
+  // CHECK: %[[EXTRACT2:.*]]  = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+  // CHECK: %[[EXTRACT2_1:.*]]  = vector.extract %[[EXTRACT2]][0] : vector<3xi32> from vector<1x3xi32>
+  // CHECK: %[[EXTRACT3:.*]]  = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 3], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+  // CHECK: %[[EXTRACT3_1:.*]]  = vector.extract %[[EXTRACT3]][0] : vector<3xi32> from vector<1x3xi32>
+  // CHECK: return %[[EXTRACT0_1]], %[[EXTRACT1_1]], %[[EXTRACT2_1]], %[[EXTRACT3_1]], %[[ARG4]] : vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<4xi32>
+  return %arg0, %arg1 : vector<2x6xi32>, vector<4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @reduction
 // CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>, %[[ARG4:.+]]: i32)
 func.func @reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : i32) -> (i32) {

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain what the issue is and what this PR does to fix that?

@angelz913
Copy link
Contributor Author

Can you explain what the issue is and what this PR does to fix that?

Updated description

kuhar added a commit that referenced this pull request Jul 24, 2024
…100138)

### Description
This PR builds on #99872. It implements a minimal version of function
body vector unrolling to convert vector types into 1D and with a size
supported by SPIR-V (2, 3 or 4 depending on the original dimension). The
ops that are currently supported include those with elementwise traits
(e.g. `arith.addi`), `vector.reduction` and `vector.transpose`. This PR
also includes new LIT tests that only check for vector unrolling.

### Future Plans
- Support more ops

---------

Co-authored-by: Jakub Kuderski <[email protected]>
@angelz913
Copy link
Contributor Author

The changes are included in #100138.

@angelz913 angelz913 closed this Jul 24, 2024
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
…100138)

Summary:
### Description
This PR builds on #99872. It implements a minimal version of function
body vector unrolling to convert vector types into 1D and with a size
supported by SPIR-V (2, 3 or 4 depending on the original dimension). The
ops that are currently supported include those with elementwise traits
(e.g. `arith.addi`), `vector.reduction` and `vector.transpose`. This PR
also includes new LIT tests that only check for vector unrolling.

### Future Plans
- Support more ops

---------

Co-authored-by: Jakub Kuderski <[email protected]>

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60250596
@angelz913 angelz913 deleted the func-signature-2d-vector-unrolling branch July 25, 2024 21:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants