Skip to content

Revert "[TOSA] bug fix infer shape for slice" #113413

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 23, 2024

Conversation

GeorgeARM
Copy link
Contributor

Reverts #108306

@llvmbot
Copy link
Member

llvmbot commented Oct 23, 2024

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Georgios Pinitas (GeorgeARM)

Changes

Reverts llvm/llvm-project#108306


Full diff: https://github.com/llvm/llvm-project/pull/113413.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+2-34)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (-42)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 01312584652049..631d3c48f2df02 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -844,40 +844,8 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     SliceOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  auto start = adaptor.getStart();
-  auto size = adaptor.getSize();
-
-  // if size[i] is -1, all remaining elements in dimension i are included
-  // in the slice, similar to TF.
-  ShapeAdaptor inputShape(adaptor.getInput().getType());
-  // initialize outputShape to all unknown
-  SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
-  if (inputShape.hasRank()) {
-    for (size_t i = 0; i < size.size(); i++) {
-      if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
-          (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
-           start[i] < inputShape.getDimSize(i))) {
-        // size[i] is not 0 and not < -1, and start[i] is in valid range
-        if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
-          // input shape has unknown dim[i] - only valid if size[i] > 0
-          if (size[i] > 0) {
-            outputShape[i] = size[i];
-          }
-        } else {
-          // input shape has known dim[i]
-          if (size[i] == -1) {
-            outputShape[i] = inputShape.getDimSize(i) - start[i];
-          } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
-            // start[i] + size[i] is within bound of input shape's dim[i]
-            outputShape[i] = size[i];
-          }
-        }
-      }
-    }
-  } else {
-    outputShape = convertToMlirShape(size);
-  }
-  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(
+      ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index d2314698afa925..d46de740800e93 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -532,48 +532,6 @@ func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
 
 // -----
 
-// CHECK-LABEL: @test_slice_size_minus_one
-func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
-  // CHECK: tosa.slice %arg0 {size = array<i64: -1, -1, -1, -1>, start = array<i64: 0, 1, -1, 8>} : (tensor<?x8x8x8xi32>) -> tensor<?x7x?x?xi32>
-  // this checks following
-  //  dim 0: size=-1, input dim=? => inferred output dim is ?
-  //  dim 1: size=-1 => inferred output dim is input_dim - start
-  //  dim 2: size=-1, start=-1 => inferred output dim is ?
-  //  dim 3: size=-1, start=8 => inferred output dim is ? because start is out of bound
-  %2= tosa.slice %arg0 { start = array<i64: 0, 1, -1, 8>, size = array<i64: -1, -1, -1, -1> } : (tensor<?x8x8x8xi32>) -> tensor<?x?x?x?xi32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @test_slice_size_out_of_bound
-func.func @test_slice_size_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
-  // CHECK: tosa.slice %arg0 {size = array<i64: 0, -2, 9, 4>, start = array<i64: 0, 0, 0, 0>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
-  // this checks following
-  //  dim 0: size=0 => inferred output dim is ?
-  //  dim 1: size=-2 => inferred output dim is ?
-  //  dim 3: start+size out of bound because size too big: inferred output dim is ?
-  //  dim 4: size=4, input dim=? => inferred output dim is 4
-  %2= tosa.slice %arg0 { start = array<i64: 0, 0, 0, 0>, size = array<i64: 0, -2, 9, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @test_slice_start_out_of_bound
-func.func @test_slice_start_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
-  // CHECK: tosa.slice %arg0 {size = array<i64: 1, 1, 3, 4>, start = array<i64: -1, 8, 6, 8000000>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
-  // this checks following
-  //  dim 0: start=-1 => inferred output dim is ?
-  //  dim 1: start=8 => inferred output dim is ?
-  //  dim 2: start+size out of bound: inferred output dim is ?
-  //  dim 3: start=8000000, size=4, input dim=? => inferred output dim is 4
-  %2= tosa.slice %arg0 { start = array<i64: -1, 8, 6, 8000000>, size = array<i64: 1, 1, 3, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
-  return
-}
-
-// -----
-
 // CHECK-LABEL: @test_slice_dynamic
 func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
   // CHECK: tosa.slice %arg0 {size = array<i64: 7, -1, 1>, start = array<i64: 1, 0, 0>} : (tensor<10x?x2xf32>) -> tensor<7x?x1xf32>

@GeorgeARM GeorgeARM merged commit 8ad8db9 into main Oct 23, 2024
11 checks passed
@GeorgeARM GeorgeARM deleted the revert-108306-pr_fix_slice_infer_shape branch October 23, 2024 03:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants