Skip to content

Commit 616aff1

Browse files
authored
[mlir] IntegerRangeAnalysis: handle vector types in getDestWidth() (#114898)
PR #112292 added support for vectors to the integer range inference interface and analysis, but didn't update the getDestWidth() method. This caused crashes when trying to infer the ranges of `arith.extsi` with vector inputs, as the code would try to sign-extend a N-bit value to a 0-bit one, which would assert and crash. This commit fixes the issue by adding a getElementTypeOrSelf().
1 parent 52624d7 commit 616aff1

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

mlir/lib/Interfaces/InferIntRangeInterface.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/Interfaces/InferIntRangeInterface.h"
1010
#include "mlir/IR/BuiltinTypes.h"
11+
#include "mlir/IR/TypeUtilities.h"
1112
#include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
1213
#include <optional>
1314

@@ -28,6 +29,7 @@ const APInt &ConstantIntRanges::smin() const { return sminVal; }
2829
const APInt &ConstantIntRanges::smax() const { return smaxVal; }
2930

3031
unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
32+
type = getElementTypeOrSelf(type);
3133
if (type.isIndex())
3234
return IndexType::kInternalStorageBitWidth;
3335
if (auto integerType = dyn_cast<IntegerType>(type))

mlir/test/Dialect/Vector/int-range-interface.mlir

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,20 @@ func.func @vector_insertelement() -> vector<4xindex> {
9696

9797
// CHECK-LABEL: func @test_loaded_vector_extract
9898
// No bounds
99-
// CHECK: test.reflect_bounds %{{.*}} : i32
99+
// CHECK: test.reflect_bounds {smax = 2147483647 : si32, smin = -2147483648 : si32, umax = 4294967295 : ui32, umin = 0 : ui32} %{{.*}} : i32
100100
func.func @test_loaded_vector_extract(%memref : memref<16xi32>) -> i32 {
101101
%c0 = arith.constant 0 : index
102102
%v = vector.load %memref[%c0] : memref<16xi32>, vector<4xi32>
103103
%e = vector.extract %v[0] : i32 from vector<4xi32>
104104
%bounds = test.reflect_bounds %e : i32
105105
func.return %bounds : i32
106106
}
107+
108+
// CHECK-LABEL: func @test_vector_extsi
109+
// CHECK: test.reflect_bounds {smax = 5 : si32, smin = 1 : si32, umax = 5 : ui32, umin = 1 : ui32}
110+
func.func @test_vector_extsi() -> vector<2xi32> {
111+
%0 = test.with_bounds {smax = 5 : si8, smin = 1 : si8, umax = 5 : ui8, umin = 1 : ui8 } : vector<2xi8>
112+
%1 = arith.extsi %0 : vector<2xi8> to vector<2xi32>
113+
%2 = test.reflect_bounds %1 : vector<2xi32>
114+
func.return %2 : vector<2xi32>
115+
}

0 commit comments

Comments
 (0)