-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][spirv] Implement vector unrolling for convert-to-spirv
pass
#100138
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-spirv Author: Angel Zhang (angelz913) ChangesDescriptionThis PR depends on #98337. It implements a minimal version of function body vector unrolling to convert vector types into 1D and with a size supported by SPIR-V (2, 3 or 4 depending on the original dimension). The ops that are currently supported include those with elementwise traits (e.g. Future Plans
Patch is 32.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/100138.diff 16 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 748646e605827..b5bb2f42f2961 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -47,7 +47,10 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
let options = [
Option<"runSignatureConversion", "run-signature-conversion", "bool",
/*default=*/"true",
- "Run function signature conversion to convert vector types">
+ "Run function signature conversion to convert vector types">,
+ Option<"runVectorUnrolling", "run-vector-unrolling", "bool",
+ /*default=*/"true",
+ "Run vector unrolling to convert vector types in function bodies">
];
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 9ad3d5fc85dd3..195fbd0d0cd58 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -189,6 +189,17 @@ Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
MemRefType baseType, Value basePtr,
ValueRange indices, Location loc, OpBuilder &builder);
+int getComputeVectorSize(int64_t size);
+
+// GetNativeVectorShape implementation for reduction ops.
+SmallVector<int64_t> getNativeVectorShapeImpl(vector::ReductionOp op);
+
+// GetNativeVectorShape implementation for transpose ops.
+SmallVector<int64_t> getNativeVectorShapeImpl(vector::TransposeOp op);
+
+// For general ops.
+std::optional<SmallVector<int64_t>> getNativeVectorShape(Operation *op);
+
} // namespace spirv
} // namespace mlir
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 003a5feea9e9b..b82a244cfc973 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -17,6 +17,8 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -56,6 +58,78 @@ struct ConvertToSPIRVPass final
return signalPassFailure();
}
+ if (runVectorUnrolling) {
+
+ // Fold transpose ops if possible as we cannot unroll it later.
+ {
+ RewritePatternSet patterns(context);
+ vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+
+ // Unroll vectors to native vector size.
+ {
+ RewritePatternSet patterns(context);
+ auto options = vector::UnrollVectorOptions().setNativeShapeFn(
+ [=](auto op) { return mlir::spirv::getNativeVectorShape(op); });
+ populateVectorUnrollPatterns(patterns, options);
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+ return signalPassFailure();
+ }
+
+ // Convert transpose ops into extract and insert pairs, in preparation
+ // of further transformations to canonicalize/cancel.
+ {
+ RewritePatternSet patterns(context);
+ auto options =
+ vector::VectorTransformsOptions().setVectorTransposeLowering(
+ vector::VectorTransposeLowering::EltWise);
+ vector::populateVectorTransposeLoweringPatterns(patterns, options);
+ vector::populateVectorShapeCastLoweringPatterns(patterns);
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+
+ // Run canonicalization to cast away leading size-1 dimensions.
+ {
+ RewritePatternSet patterns(context);
+
+ // Pull in casting way leading one dims to allow cancelling some
+ // read/write ops.
+ vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+ vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
+
+ // Decompose different rank insert_strided_slice and n-D
+ // extract_slided_slice.
+ vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
+ patterns);
+ vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+
+ // Trimming leading unit dims may generate broadcast/shape_cast ops.
+ // Clean them up.
+ vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
+ vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
+
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+ return signalPassFailure();
+ }
+
+ // Run all sorts of canonicalization patterns to clean up again.
+ {
+ RewritePatternSet patterns(context);
+ vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+ vector::InsertOp::getCanonicalizationPatterns(patterns, context);
+ vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+ vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
+ vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+ return signalPassFailure();
+ }
+ }
+
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index bf5044437fd09..8470c7642e716 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -46,14 +46,6 @@ namespace {
// Utility functions
//===----------------------------------------------------------------------===//
-static int getComputeVectorSize(int64_t size) {
- for (int i : {4, 3, 2}) {
- if (size % i == 0)
- return i;
- }
- return 1;
-}
-
static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
if (vecType.isScalable()) {
@@ -62,8 +54,8 @@ static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
return std::nullopt;
}
SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
- std::optional<SmallVector<int64_t>> targetShape =
- SmallVector<int64_t>(1, getComputeVectorSize(vecType.getShape().back()));
+ std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(
+ 1, mlir::spirv::getComputeVectorSize(vecType.getShape().back()));
if (!targetShape) {
LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
return std::nullopt;
@@ -1098,13 +1090,19 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
// the original operand of illegal type.
auto originalShape =
llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
- SmallVector<int64_t> strides(targetShape->size(), 1);
+ SmallVector<int64_t> strides(originalShape.size(), 1);
+ SmallVector<int64_t> extractShape(originalShape.size(), 1);
+ extractShape.back() = targetShape->back();
SmallVector<Type> newTypes;
Value returnValue = returnOp.getOperand(origResultNo);
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalShape, *targetShape)) {
Value result = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, returnValue, offsets, *targetShape, strides);
+ loc, returnValue, offsets, extractShape, strides);
+ SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
+ if (originalShape.size() > 1)
+ result =
+ rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
newOperands.push_back(result);
newTypes.push_back(unrolledType);
}
@@ -1285,6 +1283,53 @@ Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
builder);
}
+//===----------------------------------------------------------------------===//
+// Public functions for vector unrolling
+//===----------------------------------------------------------------------===//
+
+int mlir::spirv::getComputeVectorSize(int64_t size) {
+ for (int i : {4, 3, 2}) {
+ if (size % i == 0)
+ return i;
+ }
+ return 1;
+}
+
+SmallVector<int64_t>
+mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
+ VectorType srcVectorType = op.getSourceVectorType();
+ assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
+ int64_t vectorSize =
+ mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
+ return {vectorSize};
+}
+
+SmallVector<int64_t>
+mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {
+ VectorType vectorType = op.getResultVectorType();
+ SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
+ nativeSize.back() =
+ mlir::spirv::getComputeVectorSize(vectorType.getShape().back());
+ return nativeSize;
+}
+
+std::optional<SmallVector<int64_t>>
+mlir::spirv::getNativeVectorShape(Operation *op) {
+ if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
+ if (auto vecType = llvm::dyn_cast<VectorType>(op->getResultTypes()[0])) {
+ SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
+ nativeSize.back() =
+ mlir::spirv::getComputeVectorSize(vecType.getShape().back());
+ return nativeSize;
+ }
+ }
+
+ return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
+ .Case<vector::ReductionOp, vector::TransposeOp>(
+ [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
+ .Default([](Operation *) { return std::nullopt; });
+}
+
//===----------------------------------------------------------------------===//
// SPIR-V TypeConverter
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
index 1a844a7cd018b..6418e931f7460 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
//===----------------------------------------------------------------------===//
// arithmetic ops
diff --git a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
index 02b938be775a3..311174bef15ed 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
// CHECK-LABEL: @combined
// CHECK: %[[C0_F32:.*]] = spirv.Constant 0.000000e+00 : f32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
index 347d282f9ee0c..c018ccb924983 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
@@ -66,6 +66,28 @@ func.func @simple_vector_8(%arg0 : vector<8xi32>) -> vector<8xi32> {
// -----
+// CHECK-LABEL: @simple_vector_2d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
+func.func @simple_vector_2d(%arg0 : vector<4x4xi32>) -> vector<4x4xi32> {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4x4xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [1, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+ // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [2, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+ // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [3, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+ // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+ // CHECK: %[[EXTRACT0_1:.*]] = vector.extract %[[EXTRACT0]][0] : vector<4xi32> from vector<1x4xi32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+ // CHECK: %[[EXTRACT1_1:.*]] = vector.extract %[[EXTRACT1]][0] : vector<4xi32> from vector<1x4xi32>
+ // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [2, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+ // CHECK: %[[EXTRACT2_1:.*]] = vector.extract %[[EXTRACT2]][0] : vector<4xi32> from vector<1x4xi32>
+ // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [3, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+ // CHECK: %[[EXTRACT3_1:.*]] = vector.extract %[[EXTRACT3]][0] : vector<4xi32> from vector<1x4xi32>
+ // CHECK: return %[[EXTRACT0_1]], %[[EXTRACT1_1]], %[[EXTRACT2_1]], %[[EXTRACT3_1]] : vector<4xi32>, vector<4xi32>, vector<4xi32>, vector<4xi32>
+ return %arg0 : vector<4x4xi32>
+}
+
+// -----
+
// CHECK-LABEL: @vector_6and8
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
func.func @vector_6and8(%arg0 : vector<6xi32>, %arg1 : vector<8xi32>) -> (vector<6xi32>, vector<8xi32>) {
@@ -113,6 +135,28 @@ func.func @scalar_vector(%arg0 : vector<8xi32>, %arg1 : vector<3xi32>, %arg2 : i
// -----
+// CHECK-LABEL: @vector_2dand1d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: vector<3xi32>, %[[ARG4:.+]]: vector<4xi32>)
+func.func @vector_2dand1d(%arg0 : vector<2x6xi32>, %arg1 : vector<4xi32>) -> (vector<2x6xi32>, vector<4xi32>) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x6xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [0, 3], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+ // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [1, 0], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+ // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [1, 3], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+ // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+ // CHECK: %[[EXTRACT0_1:.*]] = vector.extract %[[EXTRACT0]][0] : vector<3xi32> from vector<1x3xi32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 3], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+ // CHECK: %[[EXTRACT1_1:.*]] = vector.extract %[[EXTRACT1]][0] : vector<3xi32> from vector<1x3xi32>
+ // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+ // CHECK: %[[EXTRACT2_1:.*]] = vector.extract %[[EXTRACT2]][0] : vector<3xi32> from vector<1x3xi32>
+ // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 3], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+ // CHECK: %[[EXTRACT3_1:.*]] = vector.extract %[[EXTRACT3]][0] : vector<3xi32> from vector<1x3xi32>
+ // CHECK: return %[[EXTRACT0_1]], %[[EXTRACT1_1]], %[[EXTRACT2_1]], %[[EXTRACT3_1]], %[[ARG4]] : vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<4xi32>
+ return %arg0, %arg1 : vector<2x6xi32>, vector<4xi32>
+}
+
+// -----
+
// CHECK-LABEL: @reduction
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>, %[[ARG4:.+]]: i32)
func.func @reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : i32) -> (i32) {
diff --git a/mlir/test/Conversion/ConvertToSPIRV/index.mlir b/mlir/test/Conversion/ConvertToSPIRV/index.mlir
index e1cb18aac5d01..f4b116849fa93 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/index.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/index.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-to-spirv="run-signature-conversion=false" | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
// CHECK-LABEL: @basic
func.func @basic(%a: index, %b: index) {
diff --git a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
index 58ec6ac61f6ac..246464928b81c 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
// CHECK-LABEL: @if_yield
// CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<f32, Function>
diff --git a/mlir/test/Conversion/ConvertToSPIRV/simple.mlir b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
index c5e0e6603d94a..00556140c3018 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
// CHECK-LABEL: @return_scalar
// CHECK-SAME: %[[ARG0:.*]]: i32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/ub.mlir b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
index a83bfb6f405a0..f34ca01c94f00 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
// CHECK-LABEL: @ub
// CHECK: %[[UNDEF:.*]] = spirv.Undef : i32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
new file mode 100644
index 0000000000000..54d9875002cb5
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-opt -test-spirv-vector-unrolling -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @vaddi
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: vector<3xi32>)
+func.func @vaddi(%arg0 : vector<6xi32>, %arg1 : vector<6xi32>) -> (vector<6xi32>) {
+ // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG2]] : vector<3xi32>
+ // CHECK: %[[ADD1:.*]] = arith.addi %[[ARG1]], %[[ARG3]] : vector<3xi32>
+ // CHECK: return %[[ADD0]], %[[ADD1]] : vector<3xi32>, vector<3xi32>
+ %0 = arith.addi %arg0, %arg1 : vector<6xi32>
+ return %0 : vector<6xi32>
+}
+
+// CHECK-LABEL: @vaddi_2d
+// CHECK-SAME: (%[[ARG0:...
[truncated]
|
convert-to-spirv
passconvert-to-spirv
pass
mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp
Outdated
Show resolved
Hide resolved
b41a563
to
9c312f1
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % nit
Co-authored-by: Jakub Kuderski <[email protected]>
…100138) Summary: ### Description This PR builds on #99872. It implements a minimal version of function body vector unrolling to convert vector types into 1D and with a size supported by SPIR-V (2, 3 or 4 depending on the original dimension). The ops that are currently supported include those with elementwise traits (e.g. `arith.addi`), `vector.reduction` and `vector.transpose`. This PR also includes new LIT tests that only check for vector unrolling. ### Future Plans - Support more ops --------- Co-authored-by: Jakub Kuderski <[email protected]> Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60250596
Summary: This PR updates CMake and Bazel dependencies for #100138. Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60250565
Description
This PR builds on #99872. It implements a minimal version of function body vector unrolling to convert vector types into 1D and with a size supported by SPIR-V (2, 3 or 4 depending on the original dimension). The ops that are currently supported include those with elementwise traits (e.g.
arith.addi
),vector.reduction
andvector.transpose
. This PR also includes new LIT tests that only check for vector unrolling.Future Plans