-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[TOSA] bug fix infer shape for slice #108306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1 added tests to check: - size = -1 - size is out of bound - start is out of bound Signed-off-by: Tai Ly <[email protected]> Change-Id: I8b59502a93cb332fe5c9a9f87970b83742538126
@llvm/pr-subscribers-mlir-tosa Author: Tai Ly (Tai78641) ChangesThis fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1 added tests to check:
Full diff: https://github.com/llvm/llvm-project/pull/108306.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0d0241fea5152c..4ca42cc99a507a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -842,8 +842,40 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
SliceOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- inferredReturnShapes.push_back(
- ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
+ auto start = adaptor.getStart();
+ auto size = adaptor.getSize();
+
+ // if size[i] is -1, all remaining elements in dimension i are included
+ // in the slice, similar to TF.
+ ShapeAdaptor inputShape(adaptor.getInput().getType());
+ // initialize outputShape to all unknown
+ SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
+ if (inputShape.hasRank()) {
+ for (size_t i = 0; i < size.size(); i++) {
+ if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
+ (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
+ start[i] < inputShape.getDimSize(i))) {
+ // size[i] is not 0 and not < -1, and start[i] is in valid range
+ if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
+ // input shape has unknown dim[i] - only valid if size[i] > 0
+ if (size[i] > 0) {
+ outputShape[i] = size[i];
+ }
+ } else {
+ // input shape has known dim[i]
+ if (size[i] == -1) {
+ outputShape[i] = inputShape.getDimSize(i) - start[i];
+ } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
+ // start[i] + size[i] is within bound of input shape's dim[i]
+ outputShape[i] = size[i];
+ }
+ }
+ }
+ }
+ } else {
+ outputShape = convertToMlirShape(size);
+ }
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index d46de740800e93..d2314698afa925 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -532,6 +532,48 @@ func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
// -----
+// CHECK-LABEL: @test_slice_size_minus_one
+func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
+ // CHECK: tosa.slice %arg0 {size = array<i64: -1, -1, -1, -1>, start = array<i64: 0, 1, -1, 8>} : (tensor<?x8x8x8xi32>) -> tensor<?x7x?x?xi32>
+ // this checks following
+ // dim 0: size=-1, input dim=? => inferred output dim is ?
+ // dim 1: size=-1 => inferred output dim is input_dim - start
+ // dim 2: size=-1, start=-1 => inferred output dim is ?
+ // dim 3: size=-1, start=8 => inferred output dim is ? because start is out of bound
+ %2= tosa.slice %arg0 { start = array<i64: 0, 1, -1, 8>, size = array<i64: -1, -1, -1, -1> } : (tensor<?x8x8x8xi32>) -> tensor<?x?x?x?xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @test_slice_size_out_of_bound
+func.func @test_slice_size_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
+ // CHECK: tosa.slice %arg0 {size = array<i64: 0, -2, 9, 4>, start = array<i64: 0, 0, 0, 0>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
+ // this checks following
+ // dim 0: size=0 => inferred output dim is ?
+ // dim 1: size=-2 => inferred output dim is ?
+ // dim 3: start+size out of bound because size too big: inferred output dim is ?
+ // dim 4: size=4, input dim=? => inferred output dim is 4
+ %2= tosa.slice %arg0 { start = array<i64: 0, 0, 0, 0>, size = array<i64: 0, -2, 9, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @test_slice_start_out_of_bound
+func.func @test_slice_start_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
+ // CHECK: tosa.slice %arg0 {size = array<i64: 1, 1, 3, 4>, start = array<i64: -1, 8, 6, 8000000>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
+ // this checks following
+ // dim 0: start=-1 => inferred output dim is ?
+ // dim 1: start=8 => inferred output dim is ?
+ // dim 2: start+size out of bound: inferred output dim is ?
+ // dim 3: start=8000000, size=4, input dim=? => inferred output dim is 4
+ %2= tosa.slice %arg0 { start = array<i64: -1, 8, 6, 8000000>, size = array<i64: 1, 1, 3, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @test_slice_dynamic
func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
// CHECK: tosa.slice %arg0 {size = array<i64: 7, -1, 1>, start = array<i64: 1, 0, 0>} : (tensor<10x?x2xf32>) -> tensor<7x?x1xf32>
|
@llvm/pr-subscribers-mlir Author: Tai Ly (Tai78641) ChangesThis fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1 added tests to check:
Full diff: https://github.com/llvm/llvm-project/pull/108306.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0d0241fea5152c..4ca42cc99a507a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -842,8 +842,40 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
SliceOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- inferredReturnShapes.push_back(
- ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
+ auto start = adaptor.getStart();
+ auto size = adaptor.getSize();
+
+ // if size[i] is -1, all remaining elements in dimension i are included
+ // in the slice, similar to TF.
+ ShapeAdaptor inputShape(adaptor.getInput().getType());
+ // initialize outputShape to all unknown
+ SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
+ if (inputShape.hasRank()) {
+ for (size_t i = 0; i < size.size(); i++) {
+ if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
+ (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
+ start[i] < inputShape.getDimSize(i))) {
+ // size[i] is not 0 and not < -1, and start[i] is in valid range
+ if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
+ // input shape has unknown dim[i] - only valid if size[i] > 0
+ if (size[i] > 0) {
+ outputShape[i] = size[i];
+ }
+ } else {
+ // input shape has known dim[i]
+ if (size[i] == -1) {
+ outputShape[i] = inputShape.getDimSize(i) - start[i];
+ } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
+ // start[i] + size[i] is within bound of input shape's dim[i]
+ outputShape[i] = size[i];
+ }
+ }
+ }
+ }
+ } else {
+ outputShape = convertToMlirShape(size);
+ }
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index d46de740800e93..d2314698afa925 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -532,6 +532,48 @@ func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
// -----
+// CHECK-LABEL: @test_slice_size_minus_one
+func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
+ // CHECK: tosa.slice %arg0 {size = array<i64: -1, -1, -1, -1>, start = array<i64: 0, 1, -1, 8>} : (tensor<?x8x8x8xi32>) -> tensor<?x7x?x?xi32>
+ // this checks following
+ // dim 0: size=-1, input dim=? => inferred output dim is ?
+ // dim 1: size=-1 => inferred output dim is input_dim - start
+ // dim 2: size=-1, start=-1 => inferred output dim is ?
+ // dim 3: size=-1, start=8 => inferred output dim is ? because start is out of bound
+ %2= tosa.slice %arg0 { start = array<i64: 0, 1, -1, 8>, size = array<i64: -1, -1, -1, -1> } : (tensor<?x8x8x8xi32>) -> tensor<?x?x?x?xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @test_slice_size_out_of_bound
+func.func @test_slice_size_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
+ // CHECK: tosa.slice %arg0 {size = array<i64: 0, -2, 9, 4>, start = array<i64: 0, 0, 0, 0>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
+ // this checks following
+ // dim 0: size=0 => inferred output dim is ?
+ // dim 1: size=-2 => inferred output dim is ?
+ // dim 3: start+size out of bound because size too big: inferred output dim is ?
+ // dim 4: size=4, input dim=? => inferred output dim is 4
+ %2= tosa.slice %arg0 { start = array<i64: 0, 0, 0, 0>, size = array<i64: 0, -2, 9, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @test_slice_start_out_of_bound
+func.func @test_slice_start_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
+ // CHECK: tosa.slice %arg0 {size = array<i64: 1, 1, 3, 4>, start = array<i64: -1, 8, 6, 8000000>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
+ // this checks following
+ // dim 0: start=-1 => inferred output dim is ?
+ // dim 1: start=8 => inferred output dim is ?
+ // dim 2: start+size out of bound: inferred output dim is ?
+ // dim 3: start=8000000, size=4, input dim=? => inferred output dim is 4
+ %2= tosa.slice %arg0 { start = array<i64: -1, 8, 6, 8000000>, size = array<i64: 1, 1, 3, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @test_slice_dynamic
func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
// CHECK: tosa.slice %arg0 {size = array<i64: 7, -1, 1>, start = array<i64: 1, 0, 0>} : (tensor<10x?x2xf32>) -> tensor<7x?x1xf32>
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/53/builds/6665 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/153/builds/12586 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/129/builds/8103 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/157/builds/10880 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/5352 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/117/builds/3056 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/177/builds/7045 Here is the relevant piece of the build log for the reference
|
This reverts commit 3b9526b.
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/89/builds/8964 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/5341 Here is the relevant piece of the build log for the reference
|
@Tai78641 I merged this but it caused issues and had to revert. It needs a rebase. Apologies for any inconvenience. |
This fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1
added tests to check: