Skip to content

Commit a00cceb

Browse files
committed
add some more ops
1 parent 33a9734 commit a00cceb

File tree

3 files changed

+48
-0
lines changed

3 files changed

+48
-0
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ def Vector_MultiDimReductionOp :
347347

348348
def Vector_BroadcastOp :
349349
Vector_Op<"broadcast", [Pure,
350+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
350351
PredOpTrait<"source operand and result have same element type",
351352
TCresVTEtIsSameAsOpBase<0, 0>>]>,
352353
Arguments<(ins AnyType:$source)>,
@@ -813,6 +814,7 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
813814

814815
def Vector_InsertElementOp :
815816
Vector_Op<"insertelement", [Pure,
817+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
816818
TypesMatchWith<"source operand type matches element type of result",
817819
"result", "source",
818820
"::llvm::cast<VectorType>($_self).getElementType()">,
@@ -861,6 +863,7 @@ def Vector_InsertElementOp :
861863

862864
def Vector_InsertOp :
863865
Vector_Op<"insert", [Pure,
866+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
864867
PredOpTrait<"source operand and result have same element type",
865868
TCresVTEtIsSameAsOpBase<0, 0>>,
866869
AllTypesMatch<["dest", "result"]>]> {

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2262,6 +2262,11 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
22622262
// BroadcastOp
22632263
//===----------------------------------------------------------------------===//
22642264

2265+
void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2266+
SetIntRangeFn setResultRanges) {
2267+
setResultRanges(getResult(), argRanges.front());
2268+
}
2269+
22652270
/// Return the dimensions of the result vector that were formerly ones in the
22662271
/// source tensor and thus correspond to "dim-1" broadcasting.
22672272
static llvm::SetVector<int64_t>
@@ -2723,6 +2728,11 @@ void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
27232728
// InsertElementOp
27242729
//===----------------------------------------------------------------------===//
27252730

2731+
void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2732+
SetIntRangeFn setResultRanges) {
2733+
setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2734+
}
2735+
27262736
void InsertElementOp::build(OpBuilder &builder, OperationState &result,
27272737
Value source, Value dest) {
27282738
build(builder, result, source, dest, {});
@@ -2772,6 +2782,11 @@ OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
27722782
// InsertOp
27732783
//===----------------------------------------------------------------------===//
27742784

2785+
void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2786+
SetIntRangeFn setResultRanges) {
2787+
setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2788+
}
2789+
27752790
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
27762791
Value source, Value dest, int64_t position) {
27772792
build(builder, result, source, dest, ArrayRef<int64_t>{position});

mlir/test/Dialect/Vector/int-range-interface.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@ func.func @vector_splat() -> vector<4xindex> {
2727
func.return %2 : vector<4xindex>
2828
}
2929

30+
// CHECK-LABEL: func @vector_broadcast
31+
// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
32+
func.func @vector_broadcast() -> vector<4x16xindex> {
33+
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<16xindex>
34+
%1 = vector.broadcast %0 : vector<16xindex> to vector<4x16xindex>
35+
%2 = test.reflect_bounds %1 : vector<4x16xindex>
36+
func.return %2 : vector<4x16xindex>
37+
}
38+
3039
// CHECK-LABEL: func @vector_extract
3140
// CHECK: test.reflect_bounds {smax = 6 : index, smin = 5 : index, umax = 6 : index, umin = 5 : index}
3241
func.func @vector_extract() -> index {
@@ -55,3 +64,24 @@ func.func @vector_add() -> vector<4xindex> {
5564
%3 = test.reflect_bounds %2 : vector<4xindex>
5665
func.return %3 : vector<4xindex>
5766
}
67+
68+
// CHECK-LABEL: func @vector_insert
69+
// CHECK: test.reflect_bounds {smax = 8 : index, smin = 5 : index, umax = 8 : index, umin = 5 : index}
70+
func.func @vector_insert() -> vector<4xindex> {
71+
%0 = test.with_bounds { umin = 5 : index, umax = 7 : index, smin = 5 : index, smax = 7 : index } : vector<4xindex>
72+
%1 = test.with_bounds { umin = 6 : index, umax = 8 : index, smin = 6 : index, smax = 8 : index } : index
73+
%2 = vector.insert %1, %0[0] : index into vector<4xindex>
74+
%3 = test.reflect_bounds %2 : vector<4xindex>
75+
func.return %3 : vector<4xindex>
76+
}
77+
78+
// CHECK-LABEL: func @vector_insertelement
79+
// CHECK: test.reflect_bounds {smax = 8 : index, smin = 5 : index, umax = 8 : index, umin = 5 : index}
80+
func.func @vector_insertelement() -> vector<4xindex> {
81+
%c0 = arith.constant 0 : index
82+
%0 = test.with_bounds { umin = 5 : index, umax = 7 : index, smin = 5 : index, smax = 7 : index } : vector<4xindex>
83+
%1 = test.with_bounds { umin = 6 : index, umax = 8 : index, smin = 6 : index, smax = 8 : index } : index
84+
%2 = vector.insertelement %1, %0[%c0 : index] : vector<4xindex>
85+
%3 = test.reflect_bounds %2 : vector<4xindex>
86+
func.return %3 : vector<4xindex>
87+
}

0 commit comments

Comments
 (0)