Skip to content

[mlir][tosa] Change the start and size of slice to tosa shape type #124209

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 29, 2025

Conversation

Jerry-Ge
Copy link
Member

Update to use getConstShapeValue to collect shape information along the graph.

Change-Id: Ic6fc2341e3bcfbec06a1d08986e26dd08573bd9c

@llvmbot
Copy link
Member

llvmbot commented Jan 23, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Jerry-Ge (Jerry-Ge)

Changes

Update to use getConstShapeValue to collect shape information along the graph.

Change-Id: Ic6fc2341e3bcfbec06a1d08986e26dd08573bd9c


Patch is 38.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/124209.diff

14 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2-2)
  • (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp (+30-4)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+31-12)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+18-4)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+2-2)
  • (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor-invalid.mlir (+4-2)
  • (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+7-3)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+48-23)
  • (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+7-2)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+6-2)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+8-5)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+13-2)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+12-6)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+30-10)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 2953e006bbe8d1..d2ac206263bb13 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1664,8 +1664,8 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> {
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    DenseI64ArrayAttr:$start,
-    DenseI64ArrayAttr:$size
+    Tosa_Shape:$start,
+    Tosa_Shape:$size
   );
 
   let results = (outs
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 5aa0269a675cbe..d6319c1c3fd891 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -268,12 +268,28 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
     ShapedType resultType = cast<ShapedType>(sliceOp.getType());
     if (llvm::isa<UnrankedTensorType>(resultType))
       return failure();
+
+    ElementsAttr StartElems;
+    ElementsAttr SizeElems;
+
+    if (!matchPattern(sliceOp.getStart(), m_Constant(&StartElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "start of slice must be a static ranked shape");
+
+    if (!matchPattern(sliceOp.getSize(), m_Constant(&SizeElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "size of slice must be a static ranked shape");
+
+    llvm::SmallVector<int64_t> SliceStarts =
+        llvm::to_vector(StartElems.getValues<int64_t>());
+    llvm::SmallVector<int64_t> SliceSizes =
+        llvm::to_vector(SizeElems.getValues<int64_t>());
+
     SmallVector<int64_t> strides, sizes;
-    ArrayRef<int64_t> starts = sliceOp.getStart();
     strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);
 
     SmallVector<Value> dynSizes;
-    for (const auto &i : llvm::enumerate(sliceOp.getSize())) {
+    for (const auto &i : llvm::enumerate(SliceSizes)) {
       int64_t size = i.value();
       size_t index = i.index();
       sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
@@ -282,17 +298,27 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
 
       auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
       auto offset = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getIndexAttr(starts[index]));
+          loc, rewriter.getIndexAttr(SliceStarts[index]));
       dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
     }
 
     auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
         sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
-        ValueRange({}), rewriter.getDenseI64ArrayAttr(starts),
+        ValueRange({}), rewriter.getDenseI64ArrayAttr(SliceStarts),
         rewriter.getDenseI64ArrayAttr(sizes),
         rewriter.getDenseI64ArrayAttr(strides));
 
     rewriter.replaceOp(sliceOp, newSliceOp.getResult());
+
+    auto removeIfRedundant = [&](Operation *op) {
+      if (op->getResults().size() == 1 && op->getResult(0).hasOneUse())
+        rewriter.eraseOp(op);
+    };
+
+    // Remove const_shape ops when it no longer has use point.
+    removeIfRedundant(sliceOp.getStart().getDefiningOp());
+    removeIfRedundant(sliceOp.getSize().getDefiningOp());
+
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index f7a596f1ccb192..7f3b929d276024 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -393,8 +393,21 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
           sliceOp, "slice input must be a static ranked tensor");
     int32_t axis = concatOp.getAxis();
 
-    llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
-    llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();
+    DenseElementsAttr StartElems;
+    DenseElementsAttr SizeElems;
+
+    if (!matchPattern(sliceOp.getStart(), m_Constant(&StartElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "start of slice must be a static ranked shape");
+
+    if (!matchPattern(sliceOp.getSize(), m_Constant(&SizeElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "size of slice must be a static ranked shape");
+
+    llvm::SmallVector<int64_t> SliceStarts =
+        llvm::to_vector(StartElems.getValues<int64_t>());
+    llvm::SmallVector<int64_t> SliceSizes =
+        llvm::to_vector(SizeElems.getValues<int64_t>());
 
     // Validate slice on the concatenated axis. Slicing along this
     // axis should span only one of the inputs to the concatenate
@@ -406,17 +419,19 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
         return rewriter.notifyMatchFailure(
             sliceOp, "concat input must be a static ranked tensor");
 
-      if (sliceStart[axis] >= 0 &&
-          (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
-        replaceWithSlice = rewriter
-                               .create<tosa::SliceOp>(
-                                   sliceOp.getLoc(), sliceOp.getType(), input,
-                                   rewriter.getDenseI64ArrayAttr(sliceStart),
-                                   rewriter.getDenseI64ArrayAttr(sliceSize))
-                               .getResult();
+      if (SliceStarts[axis] >= 0 &&
+          (SliceStarts[axis] + SliceSizes[axis]) <= inputType.getDimSize(axis)) {
+        auto start_op =
+            getTosaConstShape(rewriter, sliceOp.getLoc(), SliceStarts);
+        auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), SliceSizes);
+        replaceWithSlice =
+            rewriter
+                .create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(),
+                                       input, start_op, size_op)
+                .getResult();
         break;
       }
-      sliceStart[axis] -= inputType.getDimSize(axis);
+      SliceStarts[axis] -= inputType.getDimSize(axis);
     }
 
     if (!replaceWithSlice)
@@ -963,7 +978,11 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
 
   if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
       outputTy.getNumElements() == 1) {
-    llvm::SmallVector<uint64_t> indices(getStart());
+    DenseElementsAttr StartElems;
+    if (!matchPattern(getStart(), m_Constant(&StartElems)))
+      return {};
+
+    llvm::SmallVector<uint64_t> indices = llvm::to_vector(StartElems.getValues<uint64_t>());
     auto value = operand.getValues<Attribute>()[indices];
     return SplatElementsAttr::get(outputTy, value);
   }
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index fdccce60fe1d86..2f7bd75974b79b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -891,8 +891,18 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     SliceOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  auto start = adaptor.getStart();
-  auto size = adaptor.getSize();
+
+  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
+  SmallVector<int64_t> start;
+  SmallVector<int64_t> size;
+
+  if (!tosa::getConstShapeValue(adaptor.getStart().getDefiningOp(), start) ||
+      !tosa::getConstShapeValue(adaptor.getSize().getDefiningOp(), size)) {
+    auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
+    SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
+    inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
+    return success();
+  }
 
   // if size[i] is -1, all remaining elements in dimension i are included
   // in the slice, similar to TF.
@@ -933,11 +943,15 @@ LogicalResult tosa::SliceOp::verify() {
   if (!inputType)
     return success();
 
-  if (static_cast<size_t>(inputType.getRank()) != getStart().size())
+  auto StartShapeRank =
+      llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
+  if (inputType.getRank() != StartShapeRank)
     return emitOpError(
         "length of start attribute is not equal rank of input shape");
 
-  if (static_cast<size_t>(inputType.getRank()) != getSize().size())
+  auto SizeShapeRank =
+      llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
+  if (inputType.getRank() != SizeShapeRank)
     return emitOpError(
         "length of size attribute is not equal rank of input shape");
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 1b97f0b245d9ba..807f9cd683bb8c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -302,8 +302,8 @@ class TransposeConvStridedConverter
 
     auto slice = CreateOpAndInferShape<tosa::SliceOp>(
                      rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
-                     rewriter.getDenseI64ArrayAttr(sliceBegin),
-                     rewriter.getDenseI64ArrayAttr(sliceSize))
+                     getTosaConstShape(rewriter, loc, sliceBegin),
+                     getTosaConstShape(rewriter, loc, sliceSize))
                      .getResult();
 
     llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor-invalid.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor-invalid.mlir
index 36eb4d4669b07a..a72d6b333f7ea4 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor-invalid.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor-invalid.mlir
@@ -2,7 +2,9 @@
 
 // CHECK-LABEL:  @slice_resultType_unranked
 func.func @slice_resultType_unranked(%arg0: tensor<?xf32>) -> (tensor<*xf32>) {
+  %0 = tosa.const_shape  {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %1 = tosa.const_shape  {value = dense<0> : tensor<1xindex>} : () -> !tosa.shape<1>
   // expected-error@+1 {{failed to legalize operation 'tosa.slice'}}
-  %0 = "tosa.slice"(%arg0) {start = array<i64: 2>, size = array<i64: 0>} : (tensor<?xf32>)  -> (tensor<*xf32>)
-  return %0 : tensor<*xf32>
+  %2 = tosa.slice %arg0, %0, %1 : (tensor<?xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<*xf32>
+  return %2 : tensor<*xf32>
 }
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index 2f11b31aad2307..dd97a872a07b46 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -437,7 +437,9 @@ func.func @test_reshape_samerank_unsigned(%arg0: tensor<3x2xui8>) -> tensor<2x3x
 // CHECK-LABEL: func @slice
 func.func @slice(%arg0: tensor<6xf32>) ->() {
   // CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1]
-  %0 = "tosa.slice"(%arg0) {start = array<i64: 2>, size = array<i64: 1>} : (tensor<6xf32>)  -> (tensor<1xf32>)
+  %0 = tosa.const_shape  {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %1 = tosa.const_shape  {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %2 = tosa.slice %arg0, %0, %1 : (tensor<6xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<1xf32>
   return
 }
 
@@ -450,8 +452,10 @@ func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
   // CHECK: %[[C2:.+]] = arith.constant 2 : index
   // CHECK: %[[SUB:.+]] = arith.subi %[[DIM]], %[[C2]]
   // CHECK: tensor.extract_slice %arg0[2] [%[[SUB]]] [1]
-  %0 = "tosa.slice"(%arg0) {start = array<i64: 2>, size = array<i64: -1>} : (tensor<?xf32>)  -> (tensor<?xf32>)
-  return %0 : tensor<?xf32>
+  %0 = tosa.const_shape  {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %1 = tosa.const_shape  {value = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %2 = tosa.slice %arg0, %0, %1 : (tensor<?xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<?xf32>
+  return %2 : tensor<?xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e394188e9a9311..4a8694c4713e4e 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -572,18 +572,22 @@ func.func @transpose_canonicalize_strip_quant() -> (tensor<2x1x3x!quant.uniform<
 
 // CHECK-LABEL: @slice_fold
 func.func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
+  %0 = tosa.const_shape {value = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %1 = tosa.const_shape {value = dense<[3, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
   // CHECK: return %arg0
-  %0 = tosa.slice %arg0 { size = array<i64: 3, 4>, start = array<i64: 0, 0>}: (tensor<3x4xf32>) -> tensor<3x4xf32>
-  return %0 : tensor<3x4xf32>
+  %3 = tosa.slice %arg0, %0, %1 : (tensor<3x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x4xf32>
+  return %3 : tensor<3x4xf32>
 }
 
 // -----
 
 // CHECK-LABEL: @slice_nofold
 func.func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
+  %0 = tosa.const_shape {value = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %1 = tosa.const_shape {value = dense<[3, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
   // CHECK: tosa.slice
-  %0 = tosa.slice %arg0 { size = array<i64: 3, 4>, start = array<i64: 0, 0>}: (tensor<?x4xf32>) -> tensor<?x4xf32>
-  return %0 : tensor<?x4xf32>
+  %3 = tosa.slice %arg0, %0, %1 : (tensor<?x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x4xf32>
+  return %3 : tensor<?x4xf32>
 }
 
 // -----
@@ -663,9 +667,12 @@ func.func @fold_resize_bilinear(%arg0 : tensor<1x15x13x1xi8>) -> tensor<1x15x13x
 // CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
 func.func @canonicalize_concat_slice_final_axis(%arg0 : tensor<1x12x12x1xf32>, %arg1 : tensor<1x12x12x1xf32>) -> (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) {
   %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) -> tensor<1x12x12x2xf32>
-  %1 = tosa.slice %0 {size = array<i64: 1, 12, 12, 1>, start = array<i64: 0, 0, 0, 0>} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32>
-  %2 = tosa.slice %0 {size = array<i64: 1, 12, 12, 1>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32>
-  return %1, %2 : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
+  %1 = tosa.const_shape {value = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %2 = tosa.const_shape {value = dense<[0, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %3 = tosa.const_shape {value = dense<[1, 12, 12, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %4 = tosa.slice %0, %1, %3 : (tensor<1x12x12x2xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x12x12x1xf32>
+  %5 = tosa.slice %0, %2, %3 : (tensor<1x12x12x2xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x12x12x1xf32>
+  return %4, %5 : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
 }
 
 // -----
@@ -675,38 +682,56 @@ func.func @canonicalize_concat_slice_final_axis(%arg0 : tensor<1x12x12x1xf32>, %
 // CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12xf32>, tensor<1x12x12xf32>
 func.func @canonicalize_concat_slice_middle_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x12xf32>, tensor<1x12x12xf32>) {
   %0 = tosa.concat %arg0, %arg1 {axis = 1 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x24x12xf32>
-  %1 = tosa.slice %0 {size = array<i64: 1, 12, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32>
-  %2 = tosa.slice %0 {size = array<i64: 1, 12, 12>, start = array<i64: 0, 12, 0>} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32>
-  return %1, %2 : tensor<1x12x12xf32>, tensor<1x12x12xf32>
+  %1 = tosa.const_shape {value = dense<[0, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %2 = tosa.const_shape {value = dense<[0, 12, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %3 = tosa.const_shape {value = dense<[1, 12, 12]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %4 = tosa.slice %0, %1, %3 : (tensor<1x24x12xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x12x12xf32>
+  %5 = tosa.slice %0, %2, %3 : (tensor<1x24x12xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x12x12xf32>
+  return %4, %5 : tensor<1x12x12xf32>, tensor<1x12x12xf32>
 }
 
 // -----
 
 // CHECK-LABEL: @canonicalize_cross_concat_inputs
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
-// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]] {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
-// CHECK: %[[VAL_3:.*]] = tosa.slice %[[VAL_2]] {size = array<i64: 1, 12, 15>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32>
-// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_2]] {size = array<i64: 1, 12, 20>, start = array<i64: 0, 0, 4>} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32>
-// CHECK: return %[[VAL_3]], %[[VAL_4]] : tensor<1x12x15xf32>, tensor<1x12x20xf32>
+// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape  {value = dense<[1, 12, 20]> : tensor<3xindex>}
+// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape  {value = dense<[1, 12, 15]> : tensor<3xindex>}
+// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape  {value = dense<[0, 0, 4]> : tensor<3xindex>}
+// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape  {value = dense<0> : tensor<3xindex>}
+// CHECK: %[[VAL_6:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]] {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
+// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_6]], %[[VAL_5]], %[[VAL_3]]
+// CHECK: %[[VAL_8:.*]] = tosa.slice %[[VAL_6]], %[[VAL_4]], %[[VAL_2]]
+// CHECK: return %[[VAL_7]], %[[VAL_8]] : tensor<1x12x15xf32>, tensor<1x12x20xf32>
 func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x15xf32>, tensor<1x12x20xf32>) {
   %0 = tosa.concat %arg0, %arg1 {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
-  %1 = tosa.slice %0 {size = array<i64: 1, 12, 15>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32>
-  %2 = tosa.slice %0 {size = array<i64: 1, 12, 20>, start = array<i64: 0, 0, 4>} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32>
-  return %1, %2 : tensor<1x12x15xf32>, tensor<1x12x20xf32>
+  %1 = tosa.const_shape {value = dense<[0, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %2 = tosa.const_shape {value = dense<[0, 0, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %3 = tosa.const_shape {value = dense<[1, 12, 15]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %4 = tosa.const_shape {value = dense<[1, 12, 20]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %5 = tosa.slice %0, %1, %3 : (tensor<1x12x24xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x12x15xf32>
+  %6 = tosa.slice %0, %2, %4 : (tensor<1x12x24xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x12x20xf32>
+  return %5, %6 : tensor<1x12x15xf32>, tensor<1x12x20xf32>
 }
 
 // -----
 
 // CHECK-LABEL: @canonicalize_concat_slice_on_non_concat_axis
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
-// CHECK: %[[VAL_2:.*]] = tosa.slice %[[VAL_0]] {size = array<i64: 1, 6, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x12x12xf32>) -> tensor<1x6x12xf32>
-// CHECK: %[[VAL_3:.*]] = tosa.slice %[[VAL_1]] {size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 0>} : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32>
-// CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x6x12xf32>, tensor<1x3x12xf32>
+// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape  {value = dense<[1, 3, 0]> : tensor<3xindex>}
+// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape  {value = dense<[1, 3, 12]> : tensor<3xindex>}
+// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape  {value = dense<0> : tensor<3xindex>}
+// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape  {value = dense<[1, 6, 12]> : tensor<3xindex>}
+// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_0]], %[[VAL_4]], %[[VAL_5]]
+// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]], %[[VAL_2]], %[[VAL_3]]
+// CHECK: return %[[VAL_6]], %[[VAL_7]] : tensor<1x6x12xf32>, tensor<1x3x12xf32>
 func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x6x12xf32>, tensor<1x3x12xf32>) {
   %0 = tosa.concat %arg0, %arg1 {axis = 2 : i32} : (tensor<1x12x12xf3...
[truncated]

Copy link

github-actions bot commented Jan 23, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Update to use getConstShapeValue to collect shape information
along the graph.

Change-Id: Ic6fc2341e3bcfbec06a1d08986e26dd08573bd9c
@Jerry-Ge Jerry-Ge merged commit 956c070 into llvm:main Jan 29, 2025
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants