Skip to content

[mlir] IntegerRangeAnalysis: handle vector types in getDestWidth() #114898

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 5, 2024

Conversation

krzysz00
Copy link
Contributor

@krzysz00 krzysz00 commented Nov 5, 2024

PR #112292 added support for vectors to the integer range inference interface and analysis, but didn't update the getDestWidth() method. This caused crashes when trying to infer the ranges of arith.extsi with vector inputs, as the code would try to sign-extend a N-bit value to a 0-bit one, which would assert and crash.

This commit fixes the issue by adding a getElementTypeOrSelf().

PR llvm#112292 added support for vectors to the integer range inference
interface and analysis, but didn't update the getDestWidth() method.
This caused crashes when trying to infer the ranges of `arith.extsi`
with vector inputs, as the code would try to sign-extend a N-bit value
to a 0-bit one, which would assert and crash.

This commit fixes the issue by adding a getElementTypeOrSelf().
@llvmbot
Copy link
Member

llvmbot commented Nov 5, 2024

@llvm/pr-subscribers-mlir-vector

Author: Krzysztof Drewniak (krzysz00)

Changes

PR #112292 added support for vectors to the integer range inference interface and analysis, but didn't update the getDestWidth() method. This caused crashes when trying to infer the ranges of arith.extsi with vector inputs, as the code would try to sign-extend a N-bit value to a 0-bit one, which would assert and crash.

This commit fixes the issue by adding a getElementTypeOrSelf().


Full diff: https://github.com/llvm/llvm-project/pull/114898.diff

2 Files Affected:

  • (modified) mlir/lib/Interfaces/InferIntRangeInterface.cpp (+2)
  • (modified) mlir/test/Dialect/Vector/int-range-interface.mlir (+10-1)
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index d879b93586899b..63658518dd4a3b 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
 #include <optional>
 
@@ -28,6 +29,7 @@ const APInt &ConstantIntRanges::smin() const { return sminVal; }
 const APInt &ConstantIntRanges::smax() const { return smaxVal; }
 
 unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
+  type = getElementTypeOrSelf(type);
   if (type.isIndex())
     return IndexType::kInternalStorageBitWidth;
   if (auto integerType = dyn_cast<IntegerType>(type))
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index 29282423089ba6..09dfe932a52323 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -96,7 +96,7 @@ func.func @vector_insertelement() -> vector<4xindex> {
 
 // CHECK-LABEL: func @test_loaded_vector_extract
 // No bounds
-// CHECK: test.reflect_bounds %{{.*}} : i32
+// CHECK: test.reflect_bounds {smax = 2147483647 : si32, smin = -2147483648 : si32, umax = 4294967295 : ui32, umin = 0 : ui32} %{{.*}} : i32
 func.func @test_loaded_vector_extract(%memref : memref<16xi32>) -> i32 {
   %c0 = arith.constant 0 : index
   %v = vector.load %memref[%c0] : memref<16xi32>, vector<4xi32>
@@ -104,3 +104,12 @@ func.func @test_loaded_vector_extract(%memref : memref<16xi32>) -> i32 {
   %bounds = test.reflect_bounds %e : i32
   func.return %bounds : i32
 }
+
+// CHECK-LABEL: func @test_vector_extsi
+// CHECK: test.reflect_bounds {smax = 5 : si32, smin = 1 : si32, umax = 5 : ui32, umin = 1 : ui32}
+func.func @test_vector_extsi() -> vector<2xi32> {
+  %0 = test.with_bounds {smax = 5 : si8, smin = 1 : si8, umax = 5 : ui8, umin = 1 : ui8 } : vector<2xi8>
+  %1 = arith.extsi %0 : vector<2xi8> to vector<2xi32>
+  %2 = test.reflect_bounds %1 : vector<2xi32>
+  func.return %2 : vector<2xi32>
+}

@llvmbot
Copy link
Member

llvmbot commented Nov 5, 2024

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

Changes

PR #112292 added support for vectors to the integer range inference interface and analysis, but didn't update the getDestWidth() method. This caused crashes when trying to infer the ranges of arith.extsi with vector inputs, as the code would try to sign-extend a N-bit value to a 0-bit one, which would assert and crash.

This commit fixes the issue by adding a getElementTypeOrSelf().


Full diff: https://github.com/llvm/llvm-project/pull/114898.diff

2 Files Affected:

  • (modified) mlir/lib/Interfaces/InferIntRangeInterface.cpp (+2)
  • (modified) mlir/test/Dialect/Vector/int-range-interface.mlir (+10-1)
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index d879b93586899b..63658518dd4a3b 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
 #include <optional>
 
@@ -28,6 +29,7 @@ const APInt &ConstantIntRanges::smin() const { return sminVal; }
 const APInt &ConstantIntRanges::smax() const { return smaxVal; }
 
 unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
+  type = getElementTypeOrSelf(type);
   if (type.isIndex())
     return IndexType::kInternalStorageBitWidth;
   if (auto integerType = dyn_cast<IntegerType>(type))
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index 29282423089ba6..09dfe932a52323 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -96,7 +96,7 @@ func.func @vector_insertelement() -> vector<4xindex> {
 
 // CHECK-LABEL: func @test_loaded_vector_extract
 // No bounds
-// CHECK: test.reflect_bounds %{{.*}} : i32
+// CHECK: test.reflect_bounds {smax = 2147483647 : si32, smin = -2147483648 : si32, umax = 4294967295 : ui32, umin = 0 : ui32} %{{.*}} : i32
 func.func @test_loaded_vector_extract(%memref : memref<16xi32>) -> i32 {
   %c0 = arith.constant 0 : index
   %v = vector.load %memref[%c0] : memref<16xi32>, vector<4xi32>
@@ -104,3 +104,12 @@ func.func @test_loaded_vector_extract(%memref : memref<16xi32>) -> i32 {
   %bounds = test.reflect_bounds %e : i32
   func.return %bounds : i32
 }
+
+// CHECK-LABEL: func @test_vector_extsi
+// CHECK: test.reflect_bounds {smax = 5 : si32, smin = 1 : si32, umax = 5 : ui32, umin = 1 : ui32}
+func.func @test_vector_extsi() -> vector<2xi32> {
+  %0 = test.with_bounds {smax = 5 : si8, smin = 1 : si8, umax = 5 : ui8, umin = 1 : ui8 } : vector<2xi8>
+  %1 = arith.extsi %0 : vector<2xi8> to vector<2xi32>
+  %2 = test.reflect_bounds %1 : vector<2xi32>
+  func.return %2 : vector<2xi32>
+}

@krzysz00 krzysz00 merged commit 616aff1 into llvm:main Nov 5, 2024
12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants