Skip to content

Commit e09aea1

Browse files
committed
[mlir] IntegerRangeAnalysis: add support for vector type
Treat integer range for vector type as union of ranges of individual elements. With this semantics, most arith ops on vectors will work out of the box, the only special handling needed for constants and vector elements manipulation ops. The end goal of these changes is to optimize vectorized index calculations.
1 parent 3484ed9 commit e09aea1

File tree

9 files changed

+119
-17
lines changed

9 files changed

+119
-17
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,21 @@
1313
#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
1414
#define MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
1515

16-
include "mlir/Dialect/Vector/IR/Vector.td"
17-
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
1816
include "mlir/Dialect/Arith/IR/ArithBase.td"
1917
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
2018
include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"
2119
include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"
22-
include "mlir/IR/EnumAttr.td"
20+
include "mlir/Dialect/Vector/IR/Vector.td"
21+
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
2322
include "mlir/Interfaces/ControlFlowInterfaces.td"
2423
include "mlir/Interfaces/DestinationStyleOpInterface.td"
24+
include "mlir/Interfaces/InferIntRangeInterface.td"
2525
include "mlir/Interfaces/InferTypeOpInterface.td"
2626
include "mlir/Interfaces/SideEffectInterfaces.td"
2727
include "mlir/Interfaces/VectorInterfaces.td"
2828
include "mlir/Interfaces/ViewLikeInterface.td"
2929
include "mlir/IR/BuiltinAttributes.td"
30+
include "mlir/IR/EnumAttr.td"
3031

3132
// TODO: Add an attribute to specify a different algebra with operators other
3233
// than the current set: {*, +}.
@@ -627,6 +628,7 @@ def Vector_DeinterleaveOp :
627628

628629
def Vector_ExtractElementOp :
629630
Vector_Op<"extractelement", [Pure,
631+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
630632
TypesMatchWith<"result type matches element type of vector operand",
631633
"vector", "result",
632634
"::llvm::cast<VectorType>($_self).getElementType()">]>,
@@ -673,6 +675,7 @@ def Vector_ExtractElementOp :
673675

674676
def Vector_ExtractOp :
675677
Vector_Op<"extract", [Pure,
678+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
676679
PredOpTrait<"operand and result have same element type",
677680
TCresVTEtIsSameAsOpBase<0, 0>>,
678681
InferTypeOpAdaptorWithIsCompatible]> {
@@ -2795,6 +2798,7 @@ def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [Pure,
27952798

27962799
def Vector_SplatOp : Vector_Op<"splat", [
27972800
Pure,
2801+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
27982802
TypesMatchWith<"operand type matches element type of result",
27992803
"aggregate", "input",
28002804
"::llvm::cast<VectorType>($_self).getElementType()">

mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/IR/BuiltinAttributes.h"
2020
#include "mlir/IR/Dialect.h"
2121
#include "mlir/IR/OpDefinition.h"
22+
#include "mlir/IR/TypeUtilities.h"
2223
#include "mlir/IR/Value.h"
2324
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2425
#include "mlir/Interfaces/InferIntRangeInterface.h"
@@ -53,9 +54,10 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
5354
dialect = parent->getDialect();
5455
else
5556
dialect = value.getParentBlock()->getParentOp()->getDialect();
57+
58+
Type type = getElementTypeOrSelf(value);
5659
solver->propagateIfChanged(
57-
cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
58-
dialect)));
60+
cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
5961
}
6062

6163
LogicalResult IntegerRangeAnalysis::visitOperation(

mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,27 @@ convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
3535

3636
void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3737
SetIntRangeFn setResultRange) {
38-
auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue());
39-
if (constAttr) {
38+
if (auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue())) {
4039
const APInt &value = constAttr.getValue();
4140
setResultRange(getResult(), ConstantIntRanges::constant(value));
41+
return;
42+
}
43+
if (auto constAttr =
44+
llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) {
45+
std::optional<ConstantIntRanges> result;
46+
for (APInt &&val : constAttr) {
47+
auto range = ConstantIntRanges::constant(val);
48+
if (!result) {
49+
result = range;
50+
} else {
51+
result = result->rangeUnion(range);
52+
}
53+
}
54+
55+
if (result)
56+
setResultRange(getResult(), *result);
57+
58+
return;
4259
}
4360
}
4461

mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,21 +51,27 @@ static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
5151
if (!maybeConstValue.has_value())
5252
return failure();
5353

54+
Type type = value.getType();
55+
Location loc = value.getLoc();
5456
Operation *maybeDefiningOp = value.getDefiningOp();
5557
Dialect *valueDialect =
5658
maybeDefiningOp ? maybeDefiningOp->getDialect()
5759
: value.getParentRegion()->getParentOp()->getDialect();
58-
Attribute constAttr =
59-
rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
60-
Operation *constOp = valueDialect->materializeConstant(
61-
rewriter, constAttr, value.getType(), value.getLoc());
60+
61+
Attribute constAttr;
62+
if (auto shaped = dyn_cast<ShapedType>(type)) {
63+
constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue);
64+
} else {
65+
constAttr = rewriter.getIntegerAttr(type, *maybeConstValue);
66+
}
67+
Operation *constOp =
68+
valueDialect->materializeConstant(rewriter, constAttr, type, loc);
6269
// Fall back to arith.constant if the dialect materializer doesn't know what
6370
// to do with an integer constant.
6471
if (!constOp)
6572
constOp = rewriter.getContext()
6673
->getLoadedDialect<ArithDialect>()
67-
->materializeConstant(rewriter, constAttr, value.getType(),
68-
value.getLoc());
74+
->materializeConstant(rewriter, constAttr, type, loc);
6975
if (!constOp)
7076
return failure();
7177

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,11 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
12211221
// ExtractElementOp
12221222
//===----------------------------------------------------------------------===//
12231223

1224+
void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1225+
SetIntRangeFn setResultRanges) {
1226+
setResultRanges(getResult(), argRanges.front());
1227+
}
1228+
12241229
void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
12251230
Value source) {
12261231
result.addOperands({source});
@@ -1273,6 +1278,11 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
12731278
// ExtractOp
12741279
//===----------------------------------------------------------------------===//
12751280

1281+
void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1282+
SetIntRangeFn setResultRanges) {
1283+
setResultRanges(getResult(), argRanges.front());
1284+
}
1285+
12761286
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
12771287
Value source, int64_t position) {
12781288
build(builder, result, source, ArrayRef<int64_t>{position});
@@ -6423,6 +6433,11 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
64236433
return SplatElementsAttr::get(getType(), {constOperand});
64246434
}
64256435

6436+
void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6437+
SetIntRangeFn setResultRanges) {
6438+
setResultRanges(getResult(), argRanges.front());
6439+
}
6440+
64266441
//===----------------------------------------------------------------------===//
64276442
// StepOp
64286443
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func.func @dead_code() {
100100
// Make sure not crash.
101101
// CHECK-LABEL: @no_integer_or_index
102102
func.func @no_integer_or_index() {
103-
// CHECK: arith.cmpi
103+
// CHECK: arith.constant dense<false> : vector<1xi1>
104104
%cst_0 = arith.constant dense<[0]> : vector<1xi32>
105105
%cmp = arith.cmpi slt, %cst_0, %cst_0 : vector<1xi32>
106106
return
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
2+
3+
4+
// CHECK-LABEL: func @constant_vec
5+
// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
6+
func.func @constant_vec() -> vector<8xindex> {
7+
%0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
8+
%1 = test.reflect_bounds %0 : vector<8xindex>
9+
func.return %1 : vector<8xindex>
10+
}
11+
12+
// CHECK-LABEL: func @constant_splat
13+
// CHECK: test.reflect_bounds {smax = 3 : si32, smin = 3 : si32, umax = 3 : ui32, umin = 3 : ui32}
14+
func.func @constant_splat() -> vector<8xi32> {
15+
%0 = arith.constant dense<3> : vector<8xi32>
16+
%1 = test.reflect_bounds %0 : vector<8xi32>
17+
func.return %1 : vector<8xi32>
18+
}
19+
20+
21+
// CHECK-LABEL: func @vector_splat
22+
// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
23+
func.func @vector_splat() -> vector<4xindex> {
24+
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
25+
%1 = vector.splat %0 : vector<4xindex>
26+
%2 = test.reflect_bounds %1 : vector<4xindex>
27+
func.return %2 : vector<4xindex>
28+
}
29+
30+
// CHECK-LABEL: func @vector_extract
31+
// CHECK: test.reflect_bounds {smax = 6 : index, smin = 5 : index, umax = 6 : index, umin = 5 : index}
32+
func.func @vector_extract() -> index {
33+
%0 = test.with_bounds { umin = 5 : index, umax = 6 : index, smin = 5 : index, smax = 6 : index } : vector<4xindex>
34+
%1 = vector.extract %0[0] : index from vector<4xindex>
35+
%2 = test.reflect_bounds %1 : index
36+
func.return %2 : index
37+
}
38+
39+
// CHECK-LABEL: func @vector_extractelement
40+
// CHECK: test.reflect_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index}
41+
func.func @vector_extractelement() -> index {
42+
%c0 = arith.constant 0 : index
43+
%0 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex>
44+
%1 = vector.extractelement %0[%c0 : index] : vector<4xindex>
45+
%2 = test.reflect_bounds %1 : index
46+
func.return %2 : index
47+
}
48+
49+
// CHECK-LABEL: func @vector_add
50+
// CHECK: test.reflect_bounds {smax = 12 : index, smin = 10 : index, umax = 12 : index, umin = 10 : index}
51+
func.func @vector_add() -> vector<4xindex> {
52+
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<4xindex>
53+
%1 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex>
54+
%2 = arith.addi %0, %1 : vector<4xindex>
55+
%3 = test.reflect_bounds %2 : vector<4xindex>
56+
func.return %3 : vector<4xindex>
57+
}

mlir/test/lib/Dialect/Test/TestOpDefs.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -760,12 +760,13 @@ void TestReflectBoundsOp::inferResultRanges(
760760
Type sIntTy, uIntTy;
761761
// For plain `IntegerType`s, we can derive the appropriate signed and unsigned
762762
// Types for the Attributes.
763-
if (auto intTy = llvm::dyn_cast<IntegerType>(getType())) {
763+
Type type = getElementTypeOrSelf(getType());
764+
if (auto intTy = llvm::dyn_cast<IntegerType>(type)) {
764765
unsigned bitwidth = intTy.getWidth();
765766
sIntTy = b.getIntegerType(bitwidth, /*isSigned=*/true);
766767
uIntTy = b.getIntegerType(bitwidth, /*isSigned=*/false);
767768
} else
768-
sIntTy = uIntTy = getType();
769+
sIntTy = uIntTy = type;
769770

770771
setUminAttr(b.getIntegerAttr(uIntTy, range.umin()));
771772
setUmaxAttr(b.getIntegerAttr(uIntTy, range.umax()));

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2781,7 +2781,7 @@ def TestGraphLoopOp : TEST_Op<"graph_loop",
27812781
//===----------------------------------------------------------------------===//
27822782
// Test InferIntRangeInterface
27832783
//===----------------------------------------------------------------------===//
2784-
def InferIntRangeType : AnyTypeOf<[AnyInteger, Index]>;
2784+
def InferIntRangeType : AnyTypeOf<[AnyInteger, Index, VectorOf<[AnyInteger, Index]>]>;
27852785

27862786
def TestWithBoundsOp : TEST_Op<"with_bounds",
27872787
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,

0 commit comments

Comments
 (0)