Skip to content

[mlir][spirv] Implement vector unrolling for convert-to-spirv pass #100138

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
let options = [
Option<"runSignatureConversion", "run-signature-conversion", "bool",
/*default=*/"true",
"Run function signature conversion to convert vector types">
"Run function signature conversion to convert vector types">,
Option<"runVectorUnrolling", "run-vector-unrolling", "bool",
/*default=*/"true",
"Run vector unrolling to convert vector types in function bodies">
];
}

Expand Down
21 changes: 21 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/LogicalResult.h"

namespace mlir {

Expand Down Expand Up @@ -189,6 +191,25 @@ Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
MemRefType baseType, Value basePtr,
ValueRange indices, Location loc, OpBuilder &builder);

// Find the largest factor of size among {2,3,4} for the lowest dimension of
// the target shape.
int getComputeVectorSize(int64_t size);

// GetNativeVectorShape implementation for reduction ops.
SmallVector<int64_t> getNativeVectorShapeImpl(vector::ReductionOp op);

// GetNativeVectorShape implementation for transpose ops.
SmallVector<int64_t> getNativeVectorShapeImpl(vector::TransposeOp op);

// For general ops.
std::optional<SmallVector<int64_t>> getNativeVectorShape(Operation *op);

// Unroll vectors in function signatures to native size.
LogicalResult unrollVectorsInSignatures(Operation *op);

// Unroll vectors in function bodies to native size.
LogicalResult unrollVectorsInFuncBodies(Operation *op);

} // namespace spirv
} // namespace mlir

Expand Down
21 changes: 10 additions & 11 deletions mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
Expand All @@ -42,19 +44,16 @@ struct ConvertToSPIRVPass final
using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;

void runOnOperation() override {
MLIRContext *context = &getContext();
Operation *op = getOperation();
MLIRContext *context = &getContext();

if (runSignatureConversion) {
// Unroll vectors in function signatures to native vector size.
RewritePatternSet patterns(context);
populateFuncOpVectorRewritePatterns(patterns);
populateReturnOpVectorRewritePatterns(patterns);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
return signalPassFailure();
}
// Unroll vectors in function signatures to native size.
if (runSignatureConversion && failed(spirv::unrollVectorsInSignatures(op)))
return signalPassFailure();

// Unroll vectors in function bodies to native size.
if (runVectorUnrolling && failed(spirv::unrollVectorsInFuncBodies(op)))
return signalPassFailure();

spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
std::unique_ptr<ConversionTarget> target =
Expand Down
139 changes: 127 additions & 12 deletions mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,21 @@
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"

#include <functional>
Expand All @@ -46,14 +50,6 @@ namespace {
// Utility functions
//===----------------------------------------------------------------------===//

static int getComputeVectorSize(int64_t size) {
for (int i : {4, 3, 2}) {
if (size % i == 0)
return i;
}
return 1;
}

static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
if (vecType.isScalable()) {
Expand All @@ -62,8 +58,8 @@ static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
return std::nullopt;
}
SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
std::optional<SmallVector<int64_t>> targetShape =
SmallVector<int64_t>(1, getComputeVectorSize(vecType.getShape().back()));
std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(
1, mlir::spirv::getComputeVectorSize(vecType.getShape().back()));
if (!targetShape) {
LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
return std::nullopt;
Expand Down Expand Up @@ -1098,13 +1094,20 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
// the original operand of illegal type.
auto originalShape =
llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
SmallVector<int64_t> strides(targetShape->size(), 1);
SmallVector<int64_t> strides(originalShape.size(), 1);
SmallVector<int64_t> extractShape(originalShape.size(), 1);
extractShape.back() = targetShape->back();
SmallVector<Type> newTypes;
Value returnValue = returnOp.getOperand(origResultNo);
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalShape, *targetShape)) {
Value result = rewriter.create<vector::ExtractStridedSliceOp>(
loc, returnValue, offsets, *targetShape, strides);
loc, returnValue, offsets, extractShape, strides);
if (originalShape.size() > 1) {
SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
result =
rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
}
newOperands.push_back(result);
newTypes.push_back(unrolledType);
}
Expand Down Expand Up @@ -1285,6 +1288,118 @@ Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
builder);
}

//===----------------------------------------------------------------------===//
// Public functions for vector unrolling
//===----------------------------------------------------------------------===//

int mlir::spirv::getComputeVectorSize(int64_t size) {
for (int i : {4, 3, 2}) {
if (size % i == 0)
return i;
}
return 1;
}

SmallVector<int64_t>
mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
VectorType srcVectorType = op.getSourceVectorType();
assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
int64_t vectorSize =
mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
return {vectorSize};
}

SmallVector<int64_t>
mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {
VectorType vectorType = op.getResultVectorType();
SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
nativeSize.back() =
mlir::spirv::getComputeVectorSize(vectorType.getShape().back());
return nativeSize;
}

std::optional<SmallVector<int64_t>>
mlir::spirv::getNativeVectorShape(Operation *op) {
if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
nativeSize.back() =
mlir::spirv::getComputeVectorSize(vecType.getShape().back());
return nativeSize;
}
}

return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
.Case<vector::ReductionOp, vector::TransposeOp>(
[](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
.Default([](Operation *) { return std::nullopt; });
}

LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
MLIRContext *context = op->getContext();
RewritePatternSet patterns(context);
populateFuncOpVectorRewritePatterns(patterns);
populateReturnOpVectorRewritePatterns(patterns);
// We only want to apply signature conversion once to the existing func ops.
// Without specifying strictMode, the greedy pattern rewriter will keep
// looking for newly created func ops.
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
return applyPatternsAndFoldGreedily(op, std::move(patterns), config);
}

LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
MLIRContext *context = op->getContext();

// Unroll vectors in function bodies to native vector size.
{
RewritePatternSet patterns(context);
auto options = vector::UnrollVectorOptions().setNativeShapeFn(
[](auto op) { return mlir::spirv::getNativeVectorShape(op); });
populateVectorUnrollPatterns(patterns, options);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
return failure();
}

// Convert transpose ops into extract and insert pairs, in preparation of
// further transformations to canonicalize/cancel.
{
RewritePatternSet patterns(context);
auto options = vector::VectorTransformsOptions().setVectorTransposeLowering(
vector::VectorTransposeLowering::EltWise);
vector::populateVectorTransposeLoweringPatterns(patterns, options);
vector::populateVectorShapeCastLoweringPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
return failure();
}

// Run canonicalization to cast away leading size-1 dimensions.
{
RewritePatternSet patterns(context);

// We need to pull in casting way leading one dims.
vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
vector::TransposeOp::getCanonicalizationPatterns(patterns, context);

// Decompose different rank insert_strided_slice and n-D
// extract_slided_slice.
vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
patterns);
vector::InsertOp::getCanonicalizationPatterns(patterns, context);
vector::ExtractOp::getCanonicalizationPatterns(patterns, context);

// Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
// them up.
vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);

if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
return failure();
}
return success();
}

//===----------------------------------------------------------------------===//
// SPIR-V TypeConverter
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Conversion/ConvertToSPIRV/arith.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" -split-input-file %s | FileCheck %s
// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s

//===----------------------------------------------------------------------===//
// arithmetic ops
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Conversion/ConvertToSPIRV/combined.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s

// CHECK-LABEL: @combined
// CHECK: %[[C0_F32:.*]] = spirv.Constant 0.000000e+00 : f32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,28 @@ func.func @simple_vector_8(%arg0 : vector<8xi32>) -> vector<8xi32> {

// -----

// CHECK-LABEL: @simple_vector_2d
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
func.func @simple_vector_2d(%arg0 : vector<4x4xi32>) -> vector<4x4xi32> {
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4x4xi32>
// CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
// CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [1, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
// CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [2, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
// CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [3, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
// CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
// CHECK: %[[EXTRACT0_1:.*]] = vector.extract %[[EXTRACT0]][0] : vector<4xi32> from vector<1x4xi32>
// CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
// CHECK: %[[EXTRACT1_1:.*]] = vector.extract %[[EXTRACT1]][0] : vector<4xi32> from vector<1x4xi32>
// CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [2, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
// CHECK: %[[EXTRACT2_1:.*]] = vector.extract %[[EXTRACT2]][0] : vector<4xi32> from vector<1x4xi32>
// CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [3, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
// CHECK: %[[EXTRACT3_1:.*]] = vector.extract %[[EXTRACT3]][0] : vector<4xi32> from vector<1x4xi32>
// CHECK: return %[[EXTRACT0_1]], %[[EXTRACT1_1]], %[[EXTRACT2_1]], %[[EXTRACT3_1]] : vector<4xi32>, vector<4xi32>, vector<4xi32>, vector<4xi32>
return %arg0 : vector<4x4xi32>
}

// -----

// CHECK-LABEL: @vector_6and8
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
func.func @vector_6and8(%arg0 : vector<6xi32>, %arg1 : vector<8xi32>) -> (vector<6xi32>, vector<8xi32>) {
Expand Down Expand Up @@ -113,6 +135,28 @@ func.func @scalar_vector(%arg0 : vector<8xi32>, %arg1 : vector<3xi32>, %arg2 : i

// -----

// CHECK-LABEL: @vector_2dand1d
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: vector<3xi32>, %[[ARG4:.+]]: vector<4xi32>)
func.func @vector_2dand1d(%arg0 : vector<2x6xi32>, %arg1 : vector<4xi32>) -> (vector<2x6xi32>, vector<4xi32>) {
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x6xi32>
// CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<3xi32> into vector<2x6xi32>
// CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [0, 3], strides = [1]} : vector<3xi32> into vector<2x6xi32>
// CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [1, 0], strides = [1]} : vector<3xi32> into vector<2x6xi32>
// CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [1, 3], strides = [1]} : vector<3xi32> into vector<2x6xi32>
// CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
// CHECK: %[[EXTRACT0_1:.*]] = vector.extract %[[EXTRACT0]][0] : vector<3xi32> from vector<1x3xi32>
// CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 3], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
// CHECK: %[[EXTRACT1_1:.*]] = vector.extract %[[EXTRACT1]][0] : vector<3xi32> from vector<1x3xi32>
// CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
// CHECK: %[[EXTRACT2_1:.*]] = vector.extract %[[EXTRACT2]][0] : vector<3xi32> from vector<1x3xi32>
// CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 3], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
// CHECK: %[[EXTRACT3_1:.*]] = vector.extract %[[EXTRACT3]][0] : vector<3xi32> from vector<1x3xi32>
// CHECK: return %[[EXTRACT0_1]], %[[EXTRACT1_1]], %[[EXTRACT2_1]], %[[EXTRACT3_1]], %[[ARG4]] : vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<4xi32>
return %arg0, %arg1 : vector<2x6xi32>, vector<4xi32>
}

// -----

// CHECK-LABEL: @reduction
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>, %[[ARG4:.+]]: i32)
func.func @reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : i32) -> (i32) {
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Conversion/ConvertToSPIRV/index.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-to-spirv="run-signature-conversion=false" | FileCheck %s
// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s

// CHECK-LABEL: @basic
func.func @basic(%a: index, %b: index) {
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Conversion/ConvertToSPIRV/scf.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s

// CHECK-LABEL: @if_yield
// CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<f32, Function>
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Conversion/ConvertToSPIRV/simple.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s

// CHECK-LABEL: @return_scalar
// CHECK-SAME: %[[ARG0:.*]]: i32
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Conversion/ConvertToSPIRV/ub.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s

// CHECK-LABEL: @ub
// CHECK: %[[UNDEF:.*]] = spirv.Undef : i32
Expand Down
Loading
Loading