Skip to content

Commit 7af1229

Browse files
committed
Fixups
1 parent 09b0fcd commit 7af1229

File tree

4 files changed

+21
-14
lines changed

4 files changed

+21
-14
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
116116
MaskingOpInterface maskingOp,
117117
RewriterBase &rewriter);
118118

119-
/// Structure to hold the range [vscaleMin, vscaleMax] `vector.vscale` can take.
119+
// Structure to hold the range of `vector.vscale`.
120120
struct VscaleRange {
121121
unsigned vscaleMin;
122122
unsigned vscaleMax;

mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
//===- VectorMaskElimination.cpp - Eliminate Vector Masks -----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
19
#include "mlir/Dialect/Arith/IR/Arith.h"
210
#include "mlir/Dialect/Utils/StaticValueUtils.h"
311
#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
@@ -105,10 +113,14 @@ void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
105113
return;
106114

107115
OpBuilder::InsertionGuard g(rewriter);
116+
117+
// Build worklist so we can safely insert new ops in
118+
// `resolveAllTrueCreateMaskOp()`.
108119
SmallVector<vector::CreateMaskOp> worklist;
109120
function.walk([&](vector::CreateMaskOp createMaskOp) {
110121
worklist.push_back(createMaskOp);
111122
});
123+
112124
rewriter.setInsertionPointToStart(&function.front());
113125
for (auto mask : worklist)
114126
(void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange);

mlir/test/Dialect/Vector/eliminate-masks.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor
1616
%extracted_slice_0 = tensor.extract_slice %tensor[0, 0] [1, %c4_vscale] [1, 1] : tensor<1x1000xf32> to tensor<1x?xf32>
1717
%output_tensor = scf.for %i = %c0 to %c1000 step %c4_vscale iter_args(%arg = %extracted_slice_0) -> tensor<1x?xf32> {
1818
// 1. Extract a slice.
19-
%extracted_slice_1 = tensor.extract_slice %arg[0, 0] [1, %c4_vscale] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
19+
%extracted_slice_1 = tensor.extract_slice %arg[0, %i] [1, %c4_vscale] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
2020

2121
// 2. Create a mask for the slice.
2222
%dim_1 = tensor.dim %extracted_slice_1, %c0 : tensor<?xf32>
@@ -30,7 +30,7 @@ func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor
3030
%write = vector.transfer_write %new_vec, %extracted_slice_1[%c0], %mask {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32>
3131

3232
// 5. Insert and yield the new tensor value.
33-
%result = tensor.insert_slice %write into %arg[0, 0] [1, %c4_vscale] [1, 1] : tensor<?xf32> into tensor<1x?xf32>
33+
%result = tensor.insert_slice %write into %arg[0, %i] [1, %c4_vscale] [1, 1] : tensor<?xf32> into tensor<1x?xf32>
3434
scf.yield %result : tensor<1x?xf32>
3535
}
3636
"test.some_use"(%output_tensor) : (tensor<1x?xf32>) -> ()
@@ -57,8 +57,8 @@ func.func @negative_extract_slice_size_shrink(%tensor: tensor<1000xf32>) {
5757
%mask = vector.create_mask %dim : vector<[4]xi1>
5858
"test.some_use"(%mask) : (vector<[4]xi1>) -> ()
5959
// !!! Here the size of the mask could shrink in the next iteration.
60-
%next_num_els = affine.min affine_map<(d0)[s0] -> (-d0 + 1000, s0)>(%i)[%c4_vscale]
61-
%new_extracted_slice = tensor.extract_slice %tensor[%c4_vscale] [%next_num_els] [1] : tensor<1000xf32> to tensor<?xf32>
60+
%next_num_elts = affine.min affine_map<(d0)[s0] -> (-d0 + 1000, s0)>(%i)[%c4_vscale]
61+
%new_extracted_slice = tensor.extract_slice %tensor[%c4_vscale] [%next_num_elts] [1] : tensor<1000xf32> to tensor<?xf32>
6262
scf.yield %new_extracted_slice : tensor<?xf32>
6363
}
6464
"test.some_use"(%slice) : (tensor<?xf32>) -> ()
@@ -110,8 +110,8 @@ func.func @negative_value_bounds_fixed_dim_not_all_true(%tensor: tensor<2x?xf32>
110110
%c4 = arith.constant 4 : index
111111
%vscale = vector.vscale
112112
%c4_vscale = arith.muli %vscale, %c4 : index
113-
// This is _very_ simple but since addi is not a constant value bounds will
114-
// be used to resolve it.
113+
// This is _very_ simple but since tensor.dim is not a constant value bounds
114+
// will be used to resolve it.
115115
%dim = tensor.dim %tensor, %c0 : tensor<2x?xf32>
116116
%mask = vector.create_mask %dim, %c4_vscale : vector<3x[4]xi1>
117117
"test.some_use"(%mask) : (vector<3x[4]xi1>) -> ()

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -885,15 +885,10 @@ struct TestEliminateVectorMasks
885885
: PassWrapper(pass) {}
886886

887887
Option<unsigned> vscaleMin{
888-
*this, "vscale-min",
889-
llvm::cl::desc(
890-
"Minimum value `vector.vscale` can possibly be at runtime."),
888+
*this, "vscale-min", llvm::cl::desc("Minimum possible value of vscale."),
891889
llvm::cl::init(1)};
892-
893890
Option<unsigned> vscaleMax{
894-
*this, "vscale-max",
895-
llvm::cl::desc(
896-
"Maximum value `vector.vscale` can possibly be at runtime."),
891+
*this, "vscale-max", llvm::cl::desc("Maximum possible value of vscale."),
897892
llvm::cl::init(16)};
898893

899894
StringRef getArgument() const final { return "test-eliminate-vector-masks"; }

0 commit comments

Comments
 (0)