Skip to content

Commit 34b56da

Browse files
committed
[mlir][vector] Prevent folding non memref-type gather into maskedload
This patch fixes an issue in the FoldContiguousGather pattern which was incorrectly folding vector.gather operations with contiguous indices into vector.maskedload operations regardless of the base operand type. While vector.gather operations can work on both tensor and memref types, vector.maskedload operations are only valid for memref types. The pattern was incorrectly lowering a tensor-based gather into a masked-load, which is invalid. This fix adds a type check to ensure the pattern only applies to memref-based gather operations.
1 parent 04c3898 commit 34b56da

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5340,6 +5340,9 @@ class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
53405340
using OpRewritePattern::OpRewritePattern;
53415341
LogicalResult matchAndRewrite(GatherOp op,
53425342
PatternRewriter &rewriter) const override {
5343+
if (!op.getBase().getType().isa<MemRefType>())
5344+
return rewriter.notifyMatchFailure(op, "base must be of memref type");
5345+
53435346
if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
53445347
return failure();
53455348

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3149,6 +3149,19 @@ func.func @contiguous_gather_step(%base: memref<?xf32>,
31493149

31503150
// -----
31513151

3152+
// CHECK-LABEL: @no_fold_contiguous_gather_tensor
3153+
func.func @no_fold_contiguous_gather_tensor(%base: tensor<8xf32>, %mask: vector<4xi1>, %pass_thru: vector<4xf32>) -> vector<4xf32> {
3154+
%c0 = arith.constant 0 : index
3155+
%indices = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
3156+
// CHECK: vector.gather
3157+
// CHECK-NOT: vector.maskedload
3158+
%0 = vector.gather %base[%c0][%indices], %mask, %pass_thru :
3159+
tensor<8xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
3160+
return %0 : vector<4xf32>
3161+
}
3162+
3163+
// -----
3164+
31523165
// CHECK-LABEL: @gather_broadcast(
31533166
// TODO: Broadcast is not supported yet
31543167
// CHECK: %[[R:.*]] = vector.gather

0 commit comments

Comments
 (0)