Skip to content

Commit 03fcf33

Browse files
[mlir][Vector] Clean up populateVectorToLLVMConversionPatterns
1 parent 8cd8b50 commit 03fcf33

File tree

6 files changed

+41
-22
lines changed

6 files changed

+41
-22
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,10 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
292292
int64_t targetRank = 1,
293293
PatternBenefit benefit = 1);
294294

295+
/// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where
296+
/// n > 1.
297+
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
298+
295299
} // namespace vector
296300
} // namespace mlir
297301
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H

mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,12 @@
3232
#include "mlir/Dialect/GPU/Transforms/Passes.h"
3333
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
3434
#include "mlir/Dialect/MemRef/IR/MemRef.h"
35+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
3536
#include "mlir/IR/Attributes.h"
3637
#include "mlir/IR/Builders.h"
3738
#include "mlir/IR/BuiltinOps.h"
3839
#include "mlir/IR/BuiltinTypes.h"
40+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
3941

4042
#include "llvm/ADT/STLExtras.h"
4143
#include "llvm/Support/Error.h"
@@ -522,6 +524,18 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
522524

523525
void GpuToLLVMConversionPass::runOnOperation() {
524526
MLIRContext *context = &getContext();
527+
528+
// Perform progressive lowering of vector transfer operations.
529+
{
530+
RewritePatternSet patterns(&getContext());
531+
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
532+
vector::populateVectorTransferLoweringPatterns(patterns,
533+
/*maxTransferRank=*/1);
534+
if (failed(
535+
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
536+
return signalPassFailure();
537+
}
538+
525539
LowerToLLVMOptions options(context);
526540
options.useBarePtrCallConv = hostBarePtrCallConv;
527541
RewritePatternSet patterns(context);

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1475,16 +1475,17 @@ class VectorTypeCastOpConversion
14751475

14761476
/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
14771477
/// Non-scalable versions of this operation are handled in Vector Transforms.
1478-
class VectorCreateMaskOpRewritePattern
1479-
: public OpRewritePattern<vector::CreateMaskOp> {
1478+
class VectorCreateMaskOpConversion
1479+
: public OpConversionPattern<vector::CreateMaskOp> {
14801480
public:
1481-
explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
1482-
bool enableIndexOpt)
1483-
: OpRewritePattern<vector::CreateMaskOp>(context),
1481+
explicit VectorCreateMaskOpConversion(MLIRContext *context,
1482+
bool enableIndexOpt)
1483+
: OpConversionPattern<vector::CreateMaskOp>(context),
14841484
force32BitVectorIndices(enableIndexOpt) {}
14851485

1486-
LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1487-
PatternRewriter &rewriter) const override {
1486+
LogicalResult
1487+
matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
1488+
ConversionPatternRewriter &rewriter) const override {
14881489
auto dstType = op.getType();
14891490
if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
14901491
return failure();
@@ -1495,7 +1496,7 @@ class VectorCreateMaskOpRewritePattern
14951496
loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
14961497
/*isScalable=*/true));
14971498
auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1498-
op.getOperand(0));
1499+
adaptor.getOperands()[0]);
14991500
Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
15001501
Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
15011502
indices, bounds);
@@ -1896,16 +1897,19 @@ struct VectorScalableStepOpLowering
18961897

18971898
} // namespace
18981899

1900+
/// Registers `VectorFMAOpNDRewritePattern` in `patterns`, constructing it with
/// the context returned by `patterns.getContext()`. Per the declaration in
/// LoweringPatterns.h, this is the pattern that rank-reduces n-D FMAs into
/// (n-1)-D FMAs where n > 1.
void mlir::vector::populateVectorRankReducingFMAPattern(
1901+
RewritePatternSet &patterns) {
1902+
patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
1903+
}
1904+
18991905
/// Populate the given list with patterns that convert from Vector to LLVM.
19001906
void mlir::populateVectorToLLVMConversionPatterns(
19011907
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
19021908
bool reassociateFPReductions, bool force32BitVectorIndices) {
1909+
// This function populates only ConversionPatterns, not RewritePatterns.
19031910
MLIRContext *ctx = converter.getDialect()->getContext();
1904-
patterns.add<VectorFMAOpNDRewritePattern>(ctx);
1905-
populateVectorInsertExtractStridedSliceTransforms(patterns);
1906-
populateVectorStepLoweringPatterns(patterns);
19071911
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1908-
patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
1912+
patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
19091913
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
19101914
VectorExtractElementOpConversion, VectorExtractOpConversion,
19111915
VectorFMAOp1DConversion, VectorInsertElementOpConversion,
@@ -1922,8 +1926,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
19221926
MaskedReductionOpConversion, VectorInterleaveOpLowering,
19231927
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
19241928
VectorScalableStepOpLowering>(converter);
1925-
// Transfer ops with rank > 1 are handled by VectorToSCF.
1926-
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
19271929
}
19281930

19291931
void mlir::populateVectorToLLVMMatrixConversionPatterns(

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ struct ConvertVectorToLLVMPass
6262

6363
void ConvertVectorToLLVMPass::runOnOperation() {
6464
// Perform progressive lowering of operations on slices and all contraction
65-
// operations. Also materializes masks, applies folding and DCE.
65+
// operations. Also materializes masks, lowers vector.step, rank-reduces FMA,
66+
// applies folding and DCE.
6667
{
6768
RewritePatternSet patterns(&getContext());
6869
populateVectorToVectorCanonicalizationPatterns(patterns);
@@ -78,6 +79,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
7879
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
7980
populateVectorMaskMaterializationPatterns(patterns,
8081
force32BitVectorIndices);
82+
populateVectorInsertExtractStridedSliceTransforms(patterns);
83+
populateVectorStepLoweringPatterns(patterns);
84+
populateVectorRankReducingFMAPattern(patterns);
8185
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
8286
}
8387

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
// RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
22

33
module {
4-
func.func @func(%arg: vector<11xf32>) {
4+
func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
55
%cst_41 = arith.constant dense<true> : vector<11xi1>
66
// CHECK: vector.mask
77
// CHECK-SAME: vector.yield %arg0
88
%127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
9-
return
9+
return %127 : vector<11xf32>
1010
}
1111
}

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2046,7 +2046,6 @@ func.func @extract_strided_slice_f32_2d_from_2d_scalable(%arg0: vector<4x[8]xf32
20462046
// CHECK-LABEL: @extract_strided_slice_f32_2d_from_2d_scalable(
20472047
// CHECK-SAME: %[[ARG:.*]]: vector<4x[8]xf32>)
20482048
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
2049-
// CHECK: %[[T2:.*]] = arith.constant 0.000000e+00 : f32
20502049
// CHECK: %[[T3:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
20512050
// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
20522051
// CHECK: %[[T5:.*]] = llvm.extractvalue %[[T1]][2] : !llvm.array<4 x vector<[8]xf32>>
@@ -2067,7 +2066,6 @@ func.func @insert_strided_slice_f32_2d_into_3d(%b: vector<4x4xf32>, %c: vector<4
20672066
return %0 : vector<4x4x4xf32>
20682067
}
20692068
// CHECK-LABEL: @insert_strided_slice_f32_2d_into_3d
2070-
// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xf32>>>
20712069
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xf32>>>
20722070

20732071
// -----
@@ -2077,7 +2075,6 @@ func.func @insert_strided_slice_f32_2d_into_3d_scalable(%b: vector<4x[4]xf32>, %
20772075
return %0 : vector<4x4x[4]xf32>
20782076
}
20792077
// CHECK-LABEL: @insert_strided_slice_f32_2d_into_3d_scalable
2080-
// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xf32>>>
20812078
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xf32>>>
20822079

20832080
// -----
@@ -2087,7 +2084,6 @@ func.func @insert_strided_index_slice_index_2d_into_3d(%b: vector<4x4xindex>, %c
20872084
return %0 : vector<4x4x4xindex>
20882085
}
20892086
// CHECK-LABEL: @insert_strided_index_slice_index_2d_into_3d
2090-
// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>>
20912087
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>>
20922088

20932089
// -----
@@ -2097,7 +2093,6 @@ func.func @insert_strided_index_slice_index_2d_into_3d_scalable(%b: vector<4x[4]
20972093
return %0 : vector<4x4x[4]xindex>
20982094
}
20992095
// CHECK-LABEL: @insert_strided_index_slice_index_2d_into_3d_scalable
2100-
// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xi64>>>
21012096
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xi64>>>
21022097

21032098
// -----

0 commit comments

Comments
 (0)