Skip to content

[mlir] IntegerRangeAnalysis: add support for vector type #112292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 1, 2024

Conversation

Hardcode84
Copy link
Contributor

Treat integer range for vector type as union of ranges of individual elements. With this semantics, most arith ops on vectors will work out of the box, the only special handling needed for constants and vector elements manipulation ops.

The end goal of these changes is to be able to optimize vectorized index calculations.

@llvmbot
Copy link
Member

llvmbot commented Oct 15, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir-vector

Author: Ivan Butygin (Hardcode84)

Changes

Treat integer range for vector type as union of ranges of individual elements. With this semantics, most arith ops on vectors will work out of the box, the only special handling needed for constants and vector elements manipulation ops.

The end goal of these changes is to be able to optimize vectorized index calculations.


Full diff: https://github.com/llvm/llvm-project/pull/112292.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+7-3)
  • (modified) mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp (+19-2)
  • (modified) mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (+12-6)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+15)
  • (modified) mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir (+1-1)
  • (added) mlir/test/Dialect/Vector/int-range-interface.mlir (+57)
  • (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+3-2)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+1-1)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index b0de7c11b9d436..d890d5017daca7 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -13,20 +13,21 @@
 #ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
 #define MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
 
-include "mlir/Dialect/Vector/IR/Vector.td"
-include "mlir/Dialect/Vector/IR/VectorAttributes.td"
 include "mlir/Dialect/Arith/IR/ArithBase.td"
 include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
 include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"
 include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"
-include "mlir/IR/EnumAttr.td"
+include "mlir/Dialect/Vector/IR/Vector.td"
+include "mlir/Dialect/Vector/IR/VectorAttributes.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 include "mlir/IR/BuiltinAttributes.td"
+include "mlir/IR/EnumAttr.td"
 
 // TODO: Add an attribute to specify a different algebra with operators other
 // than the current set: {*, +}.
@@ -627,6 +628,7 @@ def Vector_DeinterleaveOp :
 
 def Vector_ExtractElementOp :
   Vector_Op<"extractelement", [Pure,
+     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
      TypesMatchWith<"result type matches element type of vector operand",
                     "vector", "result",
                     "::llvm::cast<VectorType>($_self).getElementType()">]>,
@@ -673,6 +675,7 @@ def Vector_ExtractElementOp :
 
 def Vector_ExtractOp :
   Vector_Op<"extract", [Pure,
+     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
      PredOpTrait<"operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
      InferTypeOpAdaptorWithIsCompatible]> {
@@ -2795,6 +2798,7 @@ def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [Pure,
 
 def Vector_SplatOp : Vector_Op<"splat", [
     Pure,
+    DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
     TypesMatchWith<"operand type matches element type of result",
                    "aggregate", "input",
                    "::llvm::cast<VectorType>($_self).getElementType()">
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index bf9eabbedc3a1f..a97e43708d9a37 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
@@ -53,9 +54,10 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
     dialect = parent->getDialect();
   else
     dialect = value.getParentBlock()->getParentOp()->getDialect();
+
+  Type type = getElementTypeOrSelf(value);
   solver->propagateIfChanged(
-      cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
-                                 dialect)));
+      cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
 }
 
 LogicalResult IntegerRangeAnalysis::visitOperation(
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 462044417b5fb8..3df483a4d2ddd0 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -35,10 +35,27 @@ convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
 
 void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                           SetIntRangeFn setResultRange) {
-  auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue());
-  if (constAttr) {
+  if (auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue())) {
     const APInt &value = constAttr.getValue();
     setResultRange(getResult(), ConstantIntRanges::constant(value));
+    return;
+  }
+  if (auto constAttr =
+          llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) {
+    std::optional<ConstantIntRanges> result;
+    for (APInt &&val : constAttr) {
+      auto range = ConstantIntRanges::constant(val);
+      if (!result) {
+        result = range;
+      } else {
+        result = result->rangeUnion(range);
+      }
+    }
+
+    if (result)
+      setResultRange(getResult(), *result);
+
+    return;
   }
 }
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 521138c1f6f4cd..d494bba081f801 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -51,21 +51,27 @@ static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
   if (!maybeConstValue.has_value())
     return failure();
 
+  Type type = value.getType();
+  Location loc = value.getLoc();
   Operation *maybeDefiningOp = value.getDefiningOp();
   Dialect *valueDialect =
       maybeDefiningOp ? maybeDefiningOp->getDialect()
                       : value.getParentRegion()->getParentOp()->getDialect();
-  Attribute constAttr =
-      rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
-  Operation *constOp = valueDialect->materializeConstant(
-      rewriter, constAttr, value.getType(), value.getLoc());
+
+  Attribute constAttr;
+  if (auto shaped = dyn_cast<ShapedType>(type)) {
+    constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue);
+  } else {
+    constAttr = rewriter.getIntegerAttr(type, *maybeConstValue);
+  }
+  Operation *constOp =
+      valueDialect->materializeConstant(rewriter, constAttr, type, loc);
   // Fall back to arith.constant if the dialect materializer doesn't know what
   // to do with an integer constant.
   if (!constOp)
     constOp = rewriter.getContext()
                   ->getLoadedDialect<ArithDialect>()
-                  ->materializeConstant(rewriter, constAttr, value.getType(),
-                                        value.getLoc());
+                  ->materializeConstant(rewriter, constAttr, type, loc);
   if (!constOp)
     return failure();
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a2abe1619454f2..43920cb5cf30d3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1221,6 +1221,11 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // ExtractElementOp
 //===----------------------------------------------------------------------===//
 
+void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                         SetIntRangeFn setResultRanges) {
+  setResultRanges(getResult(), argRanges.front());
+}
+
 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
                                      Value source) {
   result.addOperands({source});
@@ -1273,6 +1278,11 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
 // ExtractOp
 //===----------------------------------------------------------------------===//
 
+void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                  SetIntRangeFn setResultRanges) {
+  setResultRanges(getResult(), argRanges.front());
+}
+
 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                               Value source, int64_t position) {
   build(builder, result, source, ArrayRef<int64_t>{position});
@@ -6423,6 +6433,11 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
   return SplatElementsAttr::get(getType(), {constOperand});
 }
 
+void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                SetIntRangeFn setResultRanges) {
+  setResultRanges(getResult(), argRanges.front());
+}
+
 //===----------------------------------------------------------------------===//
 // StepOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir b/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir
index 49bd74cfe9124a..9f3d575838f320 100644
--- a/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir
+++ b/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir
@@ -100,7 +100,7 @@ func.func @dead_code() {
 // Make sure not crash.
 // CHECK-LABEL: @no_integer_or_index
 func.func @no_integer_or_index() { 
-  // CHECK: arith.cmpi
+  // CHECK: arith.constant dense<false> : vector<1xi1>
   %cst_0 = arith.constant dense<[0]> : vector<1xi32> 
   %cmp = arith.cmpi slt, %cst_0, %cst_0 : vector<1xi32> 
   return
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
new file mode 100644
index 00000000000000..0fac3c417b7bad
--- /dev/null
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
+
+
+// CHECK-LABEL: func @constant_vec
+// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
+func.func @constant_vec() -> vector<8xindex> {
+  %0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+  %1 = test.reflect_bounds %0 : vector<8xindex>
+  func.return %1 : vector<8xindex>
+}
+
+// CHECK-LABEL: func @constant_splat
+// CHECK: test.reflect_bounds {smax = 3 : si32, smin = 3 : si32, umax = 3 : ui32, umin = 3 : ui32}
+func.func @constant_splat() -> vector<8xi32> {
+  %0 = arith.constant dense<3> : vector<8xi32>
+  %1 = test.reflect_bounds %0 : vector<8xi32>
+  func.return %1 : vector<8xi32>
+}
+
+
+// CHECK-LABEL: func @vector_splat
+// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
+func.func @vector_splat() -> vector<4xindex> {
+  %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
+  %1 = vector.splat %0 : vector<4xindex>
+  %2 = test.reflect_bounds %1 : vector<4xindex>
+  func.return %2 : vector<4xindex>
+}
+
+// CHECK-LABEL: func @vector_extract
+// CHECK: test.reflect_bounds {smax = 6 : index, smin = 5 : index, umax = 6 : index, umin = 5 : index}
+func.func @vector_extract() -> index {
+  %0 = test.with_bounds { umin = 5 : index, umax = 6 : index, smin = 5 : index, smax = 6 : index } : vector<4xindex>
+  %1 = vector.extract %0[0] : index from vector<4xindex>
+  %2 = test.reflect_bounds %1 : index
+  func.return %2 : index
+}
+
+// CHECK-LABEL: func @vector_extractelement
+// CHECK: test.reflect_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index}
+func.func @vector_extractelement() -> index {
+  %c0 = arith.constant 0 : index
+  %0 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex>
+  %1 = vector.extractelement %0[%c0 : index] : vector<4xindex>
+  %2 = test.reflect_bounds %1 : index
+  func.return %2 : index
+}
+
+// CHECK-LABEL: func @vector_add
+// CHECK: test.reflect_bounds {smax = 12 : index, smin = 10 : index, umax = 12 : index, umin = 10 : index}
+func.func @vector_add() -> vector<4xindex> {
+  %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<4xindex>
+  %1 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex>
+  %2 = arith.addi %0, %1 : vector<4xindex>
+  %3 = test.reflect_bounds %2 : vector<4xindex>
+  func.return %3 : vector<4xindex>
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 69091fb893fad6..b268e549b93ab6 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -760,12 +760,13 @@ void TestReflectBoundsOp::inferResultRanges(
   Type sIntTy, uIntTy;
   // For plain `IntegerType`s, we can derive the appropriate signed and unsigned
   // Types for the Attributes.
-  if (auto intTy = llvm::dyn_cast<IntegerType>(getType())) {
+  Type type = getElementTypeOrSelf(getType());
+  if (auto intTy = llvm::dyn_cast<IntegerType>(type)) {
     unsigned bitwidth = intTy.getWidth();
     sIntTy = b.getIntegerType(bitwidth, /*isSigned=*/true);
     uIntTy = b.getIntegerType(bitwidth, /*isSigned=*/false);
   } else
-    sIntTy = uIntTy = getType();
+    sIntTy = uIntTy = type;
 
   setUminAttr(b.getIntegerAttr(uIntTy, range.umin()));
   setUmaxAttr(b.getIntegerAttr(uIntTy, range.umax()));
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 9e19966414d1d7..301f55c670d752 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2781,7 +2781,7 @@ def TestGraphLoopOp : TEST_Op<"graph_loop",
 //===----------------------------------------------------------------------===//
 // Test InferIntRangeInterface
 //===----------------------------------------------------------------------===//
-def InferIntRangeType : AnyTypeOf<[AnyInteger, Index]>;
+def InferIntRangeType : AnyTypeOf<[AnyInteger, Index, VectorOf<[AnyInteger, Index]>]>;
 
 def TestWithBoundsOp : TEST_Op<"with_bounds",
                           [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The end goal of these changes is to be able to optimize vectorized index calculations.

Would you mind sharing a bit more? I'm intrigued. And in particular, how does this compare to ValueBoundsAnalysis?

// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
func.func @constant_vec() -> vector<8xindex> {
%0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
%1 = test.reflect_bounds %0 : vector<8xindex>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really know what these test Ops do and I couldn't find any documentation in code. Could add some docs, pls?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These test ops are from existing integer range inference tests - they have an implement of the integer range inference interface that sets attributes on reflect_bounds to match the bounds of the input

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I gathered that much from mlir/test/lib/Dialect/Test/TestOps.td, but it doesn’t quite clarify things for me. The lack of documentation for these operations makes it hard to understand the distinction between test.reflect_bounds and test.with_bounds.

Given that @Hardcode84 is already using these ops for testing, it would be fantastic if some of that expertise could be shared through documentation. This would benefit everyone working with these tests!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added some description to the ops

@Hardcode84
Copy link
Contributor Author

Hardcode84 commented Oct 15, 2024

The end goal of these changes is to be able to optimize vectorized index calculations.

Would you mind sharing a bit more? I'm intrigued. And in particular, how does this compare to ValueBoundsAnalysis?

IntegerRangeAnalysis propagates integer range info through the program using dataflow framework.

ValueBoundsOpInterface was added much later IIRC, and currently completely isolated from IntegerRangeAnalysis.
Ideally, they should be unified at some point, but it's out of scope currently.

In my specific use case, I have code like this:

          %178 = arith.muli %arg3, %c32 : index
          %179 = arith.addi %178, %8 : index
          %180 = vector.splat %179 : vector<8xindex>
          %181 = arith.addi %180, %cst_8 : vector<8xindex>
          %182 = arith.remsi %181, %cst_7 : vector<8xindex>
          %183 = arith.muli %182, %cst_6 : vector<8xindex>
          %184 = arith.remsi %181, %cst_5 : vector<8xindex>
          %185 = arith.divsi %184, %cst_7 : vector<8xindex>
          %186 = arith.muli %185, %cst_4 : vector<8xindex>
          %187 = arith.muli %arg3, %c288 : index
          %188 = arith.addi %187, %19 : index
          %189 = arith.muli %188, %c9 : index
          %190 = vector.splat %189 : vector<8xindex>
          %191 = arith.addi %190, %cst_3 : vector<8xindex>
          %192 = arith.divsi %191, %cst_2 : vector<8xindex>
          %193 = arith.addi %192, %20 : vector<8xindex>
          %194 = arith.addi %193, %21 : vector<8xindex>
          %195 = arith.addi %194, %186 : vector<8xindex>
          %196 = arith.addi %195, %183 : vector<8xindex>
          %197 = arith.addi %196, %22 : vector<8xindex>
          %198 = vector.gather %0[%c0, %c0, %c0, %c0] [%197], %23, %cst_1 : memref<2x66x66x640xf16, strided<[2787840, 42240, 640, 1], offset: ?>>, vector<8xindex>, vector<8xi1>, vector<8xf16> into vector<8xf16>

And a lot of these index calculations can be potentially truncated to i32.

@banach-space
Copy link
Contributor

banach-space commented Oct 16, 2024

Ideally, they should be unified at some point, but it's out of scope currently.

+1 to unification and for it being out of scope.

And a lot of these index calculations can be potentially truncated to i32.

Oh, that would be super handy for us as well, thank you for working on this! Have you already prototyped the remaining parts? I'd be keen to see the full flow :)

EDIT I've just realised that this is connected to #112404

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... while we're here:

  1. insert[element] unions in the range of the inserted element
  2. shape_cast doesn't disturb ranges
  3. There might be some other operations that have trivial implementations here - shufflevector-likes, for example

// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
func.func @constant_vec() -> vector<8xindex> {
%0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
%1 = test.reflect_bounds %0 : vector<8xindex>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These test ops are from existing integer range inference tests - they have an implement of the integer range inference interface that sets attributes on reflect_bounds to match the bounds of the input

@Hardcode84
Copy link
Contributor Author

... while we're here:

1. `insert[element]` unions in the range of the inserted element

2. `shape_cast` doesn't disturb ranges

3. There might be some other operations that have trivial implementations here - `shufflevector`-likes, for example

Added a few more ops, but I was planning to add more ops support on as-needed manner

Comment on lines +20 to +30
include "mlir/Dialect/Vector/IR/Vector.td"
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/EnumAttr.td"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These look like unrelated changes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added mlir/Interfaces/InferIntRangeInterface.td and sorted rests of the includes.

if (auto constAttr =
llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) {
std::optional<ConstantIntRanges> result;
for (APInt &&val : constAttr) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why rvalue ref and not plain ref?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

}
}

if (result)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In what cases would result be empty?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should never trigger as 0-element vectors are not suported. Changed to assert.

// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
func.func @constant_vec() -> vector<8xindex> {
%0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
%1 = test.reflect_bounds %0 : vector<8xindex>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I gathered that much from mlir/test/lib/Dialect/Test/TestOps.td, but it doesn’t quite clarify things for me. The lack of documentation for these operations makes it hard to understand the distinction between test.reflect_bounds and test.with_bounds.

Given that @Hardcode84 is already using these ops for testing, it would be fantastic if some of that expertise could be shared through documentation. This would benefit everyone working with these tests!

@banach-space
Copy link
Contributor

I was planning to add more ops support on as-needed manner

That would make more sense to me. Adding code that we have no use for is not different to adding dead code and maintenance burden (e.g. more tests to maintain).

@krzysz00
Copy link
Contributor

To the question: test.with_bounds summons a value from the void with a given set of minima and maxima. test.reflect_bounds is used to make analysis results visible to FileCheck

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One request, but approved otherwise

(and it's making me realize something needs to be done about the possibility of 32-bit index in a more fundamental way (ex. teaching rangeUnion() to check the truncated version) but that's not this PR. Would you mind taking a bit to chat?)

%2 = vector.insertelement %1, %0[%c0 : index] : vector<4xindex>
%3 = test.reflect_bounds %2 : vector<4xindex>
func.return %3 : vector<4xindex>
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, just because I can't remember exactly rangeUnion() does on un-annotated values, could I get a test that goes something like

func.func @test_loaded_vector_extract(%memref : memref<16xi32>) -> i32 {
  %c0 = arith.constant 0 : index 
  %v = vector.load %memref[%c0] : vector<4xi32>
  %e = vector.extract %v[0]
  %bounds = test.reflect_bounds %e : i32
  func.return %bounds : i32
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Treat integer range for vector type as union of ranges of individual elements.
With this semantics, most arith ops on vectors will work out of the box, the only special handling needed for constants and vector elements manipulation ops.

The end goal of these changes is to optimize vectorized index calculations.
@Hardcode84 Hardcode84 merged commit f54cdc5 into llvm:main Nov 1, 2024
8 checks passed
@Hardcode84 Hardcode84 deleted the vector-range branch November 1, 2024 20:58
@banach-space
Copy link
Contributor

banach-space commented Nov 2, 2024

This change looks good to me, and I appreciate the contribution—thank you!

I noticed this was merged just a few hours after approval by one of the two active reviewers. In cases like this, I like to quote LLVM's Code Review policy (emphasis added):

If approval is received very quickly, a patch author may also elect to wait before committing (and this is certainly considered polite for non-trivial patches). Especially given the global nature of our community, this waiting time should be at least 24 hours. Please also be mindful of weekends and major holidays.

While all rules were followed here, in cases where a patch has been in review for multiple weeks, I would certainly appreciate authors electing to wait that 24-hour period. When a review has taken this long, it often signals that there’s no immediate urgency to finalize the change.

Kind request for @krzysz00: when giving an unqualified LGTM, could you add a quick note, like "LGTM, I believe that all comments from reviewers have been addressed"? This would be helpful in ensuring everyone’s feedback is fully considered before merging. Thanks!

P.S. I am trying to add some clarifications to our coding review docs, but things are moving slowly: #111735. FYI, so this doesn't look like me moaning on random PRs.

smallp-o-p pushed a commit to smallp-o-p/llvm-project that referenced this pull request Nov 3, 2024
Treat integer range for vector type as union of ranges of individual
elements. With this semantics, most arith ops on vectors will work out
of the box, the only special handling needed for constants and vector
elements manipulation ops.

The end goal of these changes is to be able to optimize vectorized index
calculations.
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
Treat integer range for vector type as union of ranges of individual
elements. With this semantics, most arith ops on vectors will work out
of the box, the only special handling needed for constants and vector
elements manipulation ops.

The end goal of these changes is to be able to optimize vectorized index
calculations.
krzysz00 added a commit to krzysz00/llvm-project that referenced this pull request Nov 5, 2024
PR llvm#112292 added support for vectors to the integer range inference
interface and analysis, but didn't update the getDestWidth() method.
This caused crashes when trying to infer the ranges of `arith.extsi`
with vector inputs, as the code would try to sign-extend a N-bit value
to a 0-bit one, which would assert and crash.

This commit fixes the issue by adding a getElementTypeOrSelf().
krzysz00 added a commit that referenced this pull request Nov 5, 2024
…114898)

PR #112292 added support for vectors to the integer range inference
interface and analysis, but didn't update the getDestWidth() method.
This caused crashes when trying to infer the ranges of `arith.extsi`
with vector inputs, as the code would try to sign-extend a N-bit value
to a 0-bit one, which would assert and crash.

This commit fixes the issue by adding a getElementTypeOrSelf().
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants