
Commit 357e380

[mlir][vector] Prevent folding non memref-type gather into maskedload (#135371)
This patch fixes an issue in the FoldContiguousGather pattern, which was incorrectly folding vector.gather operations with contiguous indices into vector.maskedload operations regardless of the base operand type.

While vector.gather operations can work on both tensor and memref types, vector.maskedload operations are only valid for memref types. The pattern was incorrectly rewriting a tensor-based gather into a masked load, which is invalid. This fix adds a type check to ensure the pattern only applies to memref-based gather operations.

Co-authored-by: Sagar Kulkarni <[email protected]>
1 parent 54e70ac commit 357e380

2 files changed: 16 additions, 0 deletions


mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 3 additions & 0 deletions
@@ -5348,6 +5348,9 @@ class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(GatherOp op,
                                 PatternRewriter &rewriter) const override {
+    if (!op.getBase().getType().isa<MemRefType>())
+      return rewriter.notifyMatchFailure(op, "base must be of memref type");
+
     if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
       return failure();
 
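
The order of the checks in the hunk above can be illustrated outside MLIR. What follows is a toy model only, not MLIR code: `BaseKind`, `ToyGather`, `isZeroBasedContiguous`, `foldContiguousGather`, and `checkExamples` are hypothetical names invented for this sketch, and the string result stands in for the rewrite that the real pattern performs. It assumes nothing about the rest of `FoldContiguousGather` beyond what the diff shows: the base-type guard runs first, and the contiguity check runs only if that guard passes.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for the kind of the gather's base operand.
enum class BaseKind { MemRef, Tensor };

// Hypothetical stand-in for a gather op: a base kind plus constant indices.
struct ToyGather {
  BaseKind base;
  std::vector<long> indices;
};

// True when the indices are exactly 0, 1, 2, ..., n-1.
bool isZeroBasedContiguous(const std::vector<long> &indices) {
  for (std::size_t i = 0; i < indices.size(); ++i)
    if (indices[i] != static_cast<long>(i))
      return false;
  return true;
}

// Returns the name of the op the gather would be folded into,
// or std::nullopt when the pattern does not apply.
std::optional<std::string> foldContiguousGather(const ToyGather &op) {
  // Guard modeled on the commit: only memref-based gathers may be folded.
  if (op.base != BaseKind::MemRef)
    return std::nullopt;
  // Pre-existing condition: the indices must be a zero-based contiguous run.
  if (!isZeroBasedContiguous(op.indices))
    return std::nullopt;
  return std::string("maskedload");
}

// Small usage example covering the three cases of interest.
void checkExamples() {
  // Memref base with contiguous indices: folded.
  assert(foldContiguousGather({BaseKind::MemRef, {0, 1, 2, 3}}).has_value());
  // Tensor base with contiguous indices: left alone (the case this commit fixes).
  assert(!foldContiguousGather({BaseKind::Tensor, {0, 1, 2, 3}}).has_value());
  // Memref base with non-contiguous indices: left alone.
  assert(!foldContiguousGather({BaseKind::MemRef, {0, 2, 4, 6}}).has_value());
}
```

In this model the tensor case returns early without inspecting the indices, which mirrors how the added guard reports a match failure before `isZeroBasedContiguousSeq` is consulted.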

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 13 additions & 0 deletions
@@ -3198,6 +3198,19 @@ func.func @contiguous_gather_step(%base: memref<?xf32>,
 
 // -----
 
+// CHECK-LABEL: @no_fold_contiguous_gather_tensor
+func.func @no_fold_contiguous_gather_tensor(%base: tensor<8xf32>, %mask: vector<4xi1>, %pass_thru: vector<4xf32>) -> vector<4xf32> {
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+  // CHECK: vector.gather
+  // CHECK-NOT: vector.maskedload
+  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru :
+    tensor<8xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @gather_broadcast(
 // TODO: Broadcast is not supported yet
 // CHECK: %[[R:.*]] = vector.gather
