diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 2186510e7db1e..dc4497dc971d8 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1671,8 +1671,8 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> { let arguments = (ins Tosa_Tensor:$input1, - DenseI64ArrayAttr:$start, - DenseI64ArrayAttr:$size + Tosa_Shape:$start, + Tosa_Shape:$size ); let results = (outs diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp index 5aa0269a675cb..c4b787d5c865b 100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -268,12 +268,28 @@ class SliceConverter : public OpConversionPattern { ShapedType resultType = cast(sliceOp.getType()); if (llvm::isa(resultType)) return failure(); + + ElementsAttr startElems; + ElementsAttr sizeElems; + + if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems))) + return rewriter.notifyMatchFailure( + sliceOp, "start of slice must be a static ranked shape"); + + if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) + return rewriter.notifyMatchFailure( + sliceOp, "size of slice must be a static ranked shape"); + + llvm::SmallVector sliceStarts = + llvm::to_vector(startElems.getValues()); + llvm::SmallVector sliceSizes = + llvm::to_vector(sizeElems.getValues()); + SmallVector strides, sizes; - ArrayRef starts = sliceOp.getStart(); strides.resize(cast(sliceOp.getType()).getRank(), 1); SmallVector dynSizes; - for (const auto &i : llvm::enumerate(sliceOp.getSize())) { + for (const auto &i : llvm::enumerate(sliceSizes)) { int64_t size = i.value(); size_t index = i.index(); sizes.push_back(size == -1 ? ShapedType::kDynamic : size); @@ -282,17 +298,27 @@ class SliceConverter : public OpConversionPattern { auto dim = rewriter.create(loc, input, index); auto offset = rewriter.create( - loc, rewriter.getIndexAttr(starts[index])); + loc, rewriter.getIndexAttr(sliceStarts[index])); dynSizes.push_back(rewriter.create(loc, dim, offset)); } auto newSliceOp = rewriter.create( sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes, - ValueRange({}), rewriter.getDenseI64ArrayAttr(starts), + ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts), rewriter.getDenseI64ArrayAttr(sizes), rewriter.getDenseI64ArrayAttr(strides)); rewriter.replaceOp(sliceOp, newSliceOp.getResult()); + + // Remove const_shape ops when it no longer has use point. + Operation *startConstShape = sliceOp.getStart().getDefiningOp(); + if (startConstShape->getResult(0).hasOneUse()) + rewriter.eraseOp(startConstShape); + + Operation *sizeConstShape = sliceOp.getSize().getDefiningOp(); + if (sizeConstShape->getResult(0).hasOneUse()) + rewriter.eraseOp(sizeConstShape); + return success(); } }; diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index b8e0005dc1bc0..6923f0a5ca4d9 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -444,8 +444,21 @@ struct ConcatSliceOptimization : public OpRewritePattern { sliceOp, "slice input must be a static ranked tensor"); int32_t axis = concatOp.getAxis(); - llvm::SmallVector sliceStart(sliceOp.getStart()); - llvm::ArrayRef sliceSize = sliceOp.getSize(); + DenseElementsAttr startElems; + DenseElementsAttr sizeElems; + + if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems))) + return rewriter.notifyMatchFailure( + sliceOp, "start of slice must be a static ranked shape"); + + if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) + return rewriter.notifyMatchFailure( + sliceOp, "size of slice must be a static ranked shape"); + + llvm::SmallVector sliceStarts = + llvm::to_vector(startElems.getValues()); + llvm::SmallVector sliceSizes = + llvm::to_vector(sizeElems.getValues()); // Validate slice on the concatenated axis. Slicing along this // axis should span only one of the inputs to the concatenate @@ -457,17 +470,20 @@ struct ConcatSliceOptimization : public OpRewritePattern { return rewriter.notifyMatchFailure( sliceOp, "concat input must be a static ranked tensor"); - if (sliceStart[axis] >= 0 && - (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) { - replaceWithSlice = rewriter - .create( - sliceOp.getLoc(), sliceOp.getType(), input, - rewriter.getDenseI64ArrayAttr(sliceStart), - rewriter.getDenseI64ArrayAttr(sliceSize)) - .getResult(); + if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <= + inputType.getDimSize(axis)) { + auto start_op = + getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts); + auto size_op = + getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes); + replaceWithSlice = + rewriter + .create(sliceOp.getLoc(), sliceOp.getType(), + input, start_op, size_op) + .getResult(); break; } - sliceStart[axis] -= inputType.getDimSize(axis); + sliceStarts[axis] -= inputType.getDimSize(axis); } if (!replaceWithSlice) @@ -1014,7 +1030,12 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) { if (inputTy.hasStaticShape() && outputTy.hasStaticShape() && outputTy.getNumElements() == 1) { - llvm::SmallVector indices(getStart()); + DenseElementsAttr startElems; + if (!matchPattern(getStart(), m_Constant(&startElems))) + return {}; + + llvm::SmallVector indices = + llvm::to_vector(startElems.getValues()); auto value = operand.getValues()[indices]; return SplatElementsAttr::get(outputTy, value); } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index fdccce60fe1d8..60555a6188cfe 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -891,8 +891,18 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, SliceOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { - auto start = adaptor.getStart(); - auto size = adaptor.getSize(); + + Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType()); + SmallVector start; + SmallVector size; + + if (!tosa::getConstShapeValue(adaptor.getStart().getDefiningOp(), start) || + !tosa::getConstShapeValue(adaptor.getSize().getDefiningOp(), size)) { + auto rank = cast(adaptor.getSize().getType()).getRank(); + SmallVector fallback(rank, ShapedType::kDynamic); + inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType)); + return success(); + } // if size[i] is -1, all remaining elements in dimension i are included // in the slice, similar to TF. @@ -933,11 +943,15 @@ LogicalResult tosa::SliceOp::verify() { if (!inputType) return success(); - if (static_cast(inputType.getRank()) != getStart().size()) + auto startShapeRank = + llvm::cast(getStart().getType()).getRank(); + if (inputType.getRank() != startShapeRank) return emitOpError( "length of start attribute is not equal rank of input shape"); - if (static_cast(inputType.getRank()) != getSize().size()) + auto sizeShapeRank = + llvm::cast(getSize().getType()).getRank(); + if (inputType.getRank() != sizeShapeRank) return emitOpError( "length of size attribute is not equal rank of input shape"); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp index 1b97f0b245d9b..807f9cd683bb8 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -302,8 +302,8 @@ class TransposeConvStridedConverter auto slice = CreateOpAndInferShape( rewriter, loc, UnrankedTensorType::get(resultETy), conv2d, - rewriter.getDenseI64ArrayAttr(sliceBegin), - rewriter.getDenseI64ArrayAttr(sliceSize)) + getTosaConstShape(rewriter, loc, sliceBegin), + getTosaConstShape(rewriter, loc, sliceSize)) .getResult(); llvm::SmallVector resultPadding = {0, 0, 0, 0, 0, 0, 0, 0}; diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor-invalid.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor-invalid.mlir index 36eb4d4669b07..a72d6b333f7ea 100644 --- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor-invalid.mlir +++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor-invalid.mlir @@ -2,7 +2,9 @@ // CHECK-LABEL: @slice_resultType_unranked func.func @slice_resultType_unranked(%arg0: tensor) -> (tensor<*xf32>) { + %0 = tosa.const_shape {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1> + %1 = tosa.const_shape {value = dense<0> : tensor<1xindex>} : () -> !tosa.shape<1> // expected-error@+1 {{failed to legalize operation 'tosa.slice'}} - %0 = "tosa.slice"(%arg0) {start = array, size = array} : (tensor) -> (tensor<*xf32>) - return %0 : tensor<*xf32> + %2 = tosa.slice %arg0, %0, %1 : (tensor, !tosa.shape<1>, !tosa.shape<1>) -> tensor<*xf32> + return %2 : tensor<*xf32> } diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir index 27018fb79f60d..f95de79847464 100644 --- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir +++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir @@ -437,7 +437,9 @@ func.func @test_reshape_samerank_unsigned(%arg0: tensor<3x2xui8>) -> tensor<2x3x // CHECK-LABEL: func @slice func.func @slice(%arg0: tensor<6xf32>) ->() { // CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1] - %0 = "tosa.slice"(%arg0) {start = array, size = array} : (tensor<6xf32>) -> (tensor<1xf32>) + %0 = tosa.const_shape {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1> + %1 = tosa.const_shape {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1> + %2 = tosa.slice %arg0, %0, %1 : (tensor<6xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<1xf32> return } @@ -450,8 +452,10 @@ func.func @slice_dyn(%arg0: tensor) -> (tensor) { // CHECK: %[[C2:.+]] = arith.constant 2 : index // CHECK: %[[SUB:.+]] = arith.subi %[[DIM]], %[[C2]] // CHECK: tensor.extract_slice %arg0[2] [%[[SUB]]] [1] - %0 = "tosa.slice"(%arg0) {start = array, size = array} : (tensor) -> (tensor) - return %0 : tensor + %0 = tosa.const_shape {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1> + %1 = tosa.const_shape {value = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1> + %2 = tosa.slice %arg0, %0, %1 : (tensor, !tosa.shape<1>, !tosa.shape<1>) -> tensor + return %2 : tensor } // ----- diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 6f47f041b9199..094939ccf018b 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -624,18 +624,22 @@ func.func @transpose_canonicalize_strip_quant() -> (tensor<2x1x3x!quant.uniform< // CHECK-LABEL: @slice_fold func.func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> { + %0 = tosa.const_shape {value = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2> + %1 = tosa.const_shape {value = dense<[3, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK: return %arg0 - %0 = tosa.slice %arg0 { size = array, start = array}: (tensor<3x4xf32>) -> tensor<3x4xf32> - return %0 : tensor<3x4xf32> + %3 = tosa.slice %arg0, %0, %1 : (tensor<3x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x4xf32> + return %3 : tensor<3x4xf32> } // ----- // CHECK-LABEL: @slice_nofold func.func @slice_nofold(%arg0: tensor) -> tensor { + %0 = tosa.const_shape {value = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2> + %1 = tosa.const_shape {value = dense<[3, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK: tosa.slice - %0 = tosa.slice %arg0 { size = array, start = array}: (tensor) -> tensor - return %0 : tensor + %3 = tosa.slice %arg0, %0, %1 : (tensor, !tosa.shape<2>, !tosa.shape<2>) -> tensor + return %3 : tensor } // ----- @@ -715,9 +719,12 @@ func.func @fold_resize_bilinear(%arg0 : tensor<1x15x13x1xi8>) -> tensor<1x15x13x // CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32> func.func @canonicalize_concat_slice_final_axis(%arg0 : tensor<1x12x12x1xf32>, %arg1 : tensor<1x12x12x1xf32>) -> (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) { %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) -> tensor<1x12x12x2xf32> - %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32> - %2 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32> - return %1, %2 : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32> + %1 = tosa.const_shape {value = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4> + %2 = tosa.const_shape {value = dense<[0, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> + %3 = tosa.const_shape {value = dense<[1, 12, 12, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> + %4 = tosa.slice %0, %1, %3 : (tensor<1x12x12x2xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x12x12x1xf32> + %5 = tosa.slice %0, %2, %3 : (tensor<1x12x12x2xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x12x12x1xf32> + return %4, %5 : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32> } // ----- @@ -727,38 +734,56 @@ func.func @canonicalize_concat_slice_final_axis(%arg0 : tensor<1x12x12x1xf32>, % // CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12xf32>, tensor<1x12x12xf32> func.func @canonicalize_concat_slice_middle_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x12xf32>, tensor<1x12x12xf32>) { %0 = tosa.concat %arg0, %arg1 {axis = 1 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x24x12xf32> - %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32> - %2 = tosa.slice %0 {size = array, start = array} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32> - return %1, %2 : tensor<1x12x12xf32>, tensor<1x12x12xf32> + %1 = tosa.const_shape {value = dense<[0, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3> + %2 = tosa.const_shape {value = dense<[0, 12, 0]> : tensor<3xindex>} : () -> !tosa.shape<3> + %3 = tosa.const_shape {value = dense<[1, 12, 12]> : tensor<3xindex>} : () -> !tosa.shape<3> + %4 = tosa.slice %0, %1, %3 : (tensor<1x24x12xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x12x12xf32> + %5 = tosa.slice %0, %2, %3 : (tensor<1x24x12xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x12x12xf32> + return %4, %5 : tensor<1x12x12xf32>, tensor<1x12x12xf32> } // ----- // CHECK-LABEL: @canonicalize_cross_concat_inputs // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32> -// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]] {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32> -// CHECK: %[[VAL_3:.*]] = tosa.slice %[[VAL_2]] {size = array, start = array} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32> -// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_2]] {size = array, start = array} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32> -// CHECK: return %[[VAL_3]], %[[VAL_4]] : tensor<1x12x15xf32>, tensor<1x12x20xf32> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[1, 12, 20]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[1, 12, 15]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[0, 0, 4]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<0> : tensor<3xindex>} +// CHECK: %[[VAL_6:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]] {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32> +// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_6]], %[[VAL_5]], %[[VAL_3]] +// CHECK: %[[VAL_8:.*]] = tosa.slice %[[VAL_6]], %[[VAL_4]], %[[VAL_2]] +// CHECK: return %[[VAL_7]], %[[VAL_8]] : tensor<1x12x15xf32>, tensor<1x12x20xf32> func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x15xf32>, tensor<1x12x20xf32>) { %0 = tosa.concat %arg0, %arg1 {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32> - %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32> - %2 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32> - return %1, %2 : tensor<1x12x15xf32>, tensor<1x12x20xf32> + %1 = tosa.const_shape {value = dense<[0, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3> + %2 = tosa.const_shape {value = dense<[0, 0, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> + %3 = tosa.const_shape {value = dense<[1, 12, 15]> : tensor<3xindex>} : () -> !tosa.shape<3> + %4 = tosa.const_shape {value = dense<[1, 12, 20]> : tensor<3xindex>} : () -> !tosa.shape<3> + %5 = tosa.slice %0, %1, %3 : (tensor<1x12x24xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x12x15xf32> + %6 = tosa.slice %0, %2, %4 : (tensor<1x12x24xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x12x20xf32> + return %5, %6 : tensor<1x12x15xf32>, tensor<1x12x20xf32> } // ----- // CHECK-LABEL: @canonicalize_concat_slice_on_non_concat_axis // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32> -// CHECK: %[[VAL_2:.*]] = tosa.slice %[[VAL_0]] {size = array, start = array} : (tensor<1x12x12xf32>) -> tensor<1x6x12xf32> -// CHECK: %[[VAL_3:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32> -// CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x6x12xf32>, tensor<1x3x12xf32> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[1, 3, 0]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[1, 3, 12]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<0> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[1, 6, 12]> : tensor<3xindex>} +// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_0]], %[[VAL_4]], %[[VAL_5]] +// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]], %[[VAL_2]], %[[VAL_3]] +// CHECK: return %[[VAL_6]], %[[VAL_7]] : tensor<1x6x12xf32>, tensor<1x3x12xf32> func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x6x12xf32>, tensor<1x3x12xf32>) { %0 = tosa.concat %arg0, %arg1 {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32> - %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x24xf32>) -> tensor<1x6x12xf32> - %2 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x24xf32>) -> tensor<1x3x12xf32> - return %1, %2 : tensor<1x6x12xf32>, tensor<1x3x12xf32> + %1 = tosa.const_shape {value = dense<[0, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3> + %2 = tosa.const_shape {value = dense<[1, 6, 12]> : tensor<3xindex>} : () -> !tosa.shape<3> + %3 = tosa.const_shape {value = dense<[1, 3, 12]> : tensor<3xindex>} : () -> !tosa.shape<3> + %4 = tosa.slice %0, %1, %2 : (tensor<1x12x24xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x6x12xf32> + %5 = tosa.slice %0, %3, %3 : (tensor<1x12x24xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x3x12xf32> + return %4, %5 : tensor<1x6x12xf32>, tensor<1x3x12xf32> } // ----- diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir index 8198903b78ac0..32dbe3315ca0b 100644 --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -506,7 +506,10 @@ func.func @reshape_splat() -> tensor<6x5x4xi32> { func.func @slice_splat() -> tensor<1x1x1xi32> { // CHECK: %[[SLICE:.+]] = "tosa.const"() <{value = dense<42> : tensor<1x1x1xi32>} %splat = "tosa.const"() {value = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32> - %slice = tosa.slice %splat { size = array, start = array } : (tensor<4x5x6xi32>) -> tensor<1x1x1xi32> + %start = tosa.const_shape {value = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> + %size = tosa.const_shape {value = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> + %slice= tosa.slice %splat, %start, %size : (tensor<4x5x6xi32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x1x1xi32> + // CHECK: return %[[SLICE]] return %slice : tensor<1x1x1xi32> } @@ -517,7 +520,9 @@ func.func @slice_splat() -> tensor<1x1x1xi32> { func.func @slice_singleton() -> tensor<1x1xi32> { %splat = "tosa.const"() {value = dense<[[0, 1, 2], [3, 4, 5], [6, 7 ,8]]> : tensor<3x3xi32>} : () -> tensor<3x3xi32> // CHECK: %[[SLICE:.+]] = "tosa.const"() <{value = dense<4> : tensor<1x1xi32>} - %slice = tosa.slice %splat { size = array, start = array } : (tensor<3x3xi32>) -> tensor<1x1xi32> + %start = tosa.const_shape {value = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> + %size = tosa.const_shape {value = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> + %slice= tosa.slice %splat, %start, %size : (tensor<3x3xi32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x1xi32> // CHECK: return %[[SLICE]] return %slice : tensor<1x1xi32> } diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 4808867b28bb9..3358660d89fc1 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -607,8 +607,10 @@ func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () { func.func @test_slice_invalid_start() { %0 = tensor.empty() : tensor<4x31x31xf32> + %start = tosa.const_shape {value = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> + %size = tosa.const_shape {value = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> // expected-error@+1 {{'tosa.slice' op length of start attribute is not equal rank of input shape}} - %1 = tosa.slice %0 {size = array, start = array} : (tensor<4x31x31xf32>) -> tensor<*xf32> + %3 = tosa.slice %0, %start, %size : (tensor<4x31x31xf32>, !tosa.shape<2>, !tosa.shape<3>) -> tensor<*xf32> return } @@ -616,8 +618,10 @@ func.func @test_slice_invalid_start() { func.func @test_slice_invalid_size() { %0 = tensor.empty() : tensor<4x31x31xf32> + %start = tosa.const_shape {value = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> + %size = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1> // expected-error@+1 {{'tosa.slice' op length of size attribute is not equal rank of input shape}} - %1 = tosa.slice %0 {size = array, start = array} : (tensor<4x31x31xf32>) -> tensor<*xf32> + %3 = tosa.slice %0, %start, %size : (tensor<4x31x31xf32>, !tosa.shape<3>, !tosa.shape<1>) -> tensor<*xf32> return } diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 0fe35d88f0e73..26bebdd898a0d 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -86,10 +86,11 @@ func.func @test_reverse(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13 // ----- // CHECK-LABEL: slice func.func @test_slice(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x4x11x1xf32> { + %0 = tosa.const_shape {value = dense<[0, 0, 0, 0, 6, 8, 0]> : tensor<7xindex>} : () -> !tosa.shape<7> + %1 = tosa.const_shape {value = dense<[1, 1, 1, 1, 4, 11, 1]> : tensor<7xindex>} : () -> !tosa.shape<7> // expected-error@+1 {{'tosa.slice' op failed level check: operand rank(shape) <= MAX_RANK}} - %0 = "tosa.slice"(%arg0) {start = array, size = array} : - (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x4x11x1xf32> - return %0 : tensor<1x1x1x1x4x11x1xf32> + %2= tosa.slice %arg0, %0, %1 : (tensor<1x1x1x1x13x21x3xf32>, !tosa.shape<7>, !tosa.shape<7>) -> tensor<1x1x1x1x4x11x1xf32> + return %2 : tensor<1x1x1x1x4x11x1xf32> } // ----- @@ -736,8 +737,10 @@ func.func @test_custom(%arg0: tensor<1x1x1x1x1x1x10xi32>) -> tensor<1x1x1x1x1x1x // CHECK-LABEL: unranked_tensor func.func @test_unranked_tensor(%arg0: tensor<*xf32>) { + %0 = tosa.const_shape {value = dense<[0]> : tensor<1xindex>} : () -> !tosa.shape<1> + %1 = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1> + // expected-error@+1 {{'tosa.slice' op failed level check: unranked tensor}} - %0 = "tosa.slice"(%arg0) {start = array, size = array} : - (tensor<*xf32>) -> tensor<*xf32> + %2= tosa.slice %arg0, %0, %1 : (tensor<*xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<*xf32> return } diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 19b93d7611854..58000c2f041be 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -571,8 +571,19 @@ func.func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: slice func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> { - %0 = tosa.slice %arg0 {size = array, start = array} : (tensor<13x21x3xf32>) -> tensor<4x11x1xf32> - return %0 : tensor<4x11x1xf32> + %0 = tosa.const_shape {value = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> + %1 = tosa.const_shape {value = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3> + %2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf32> + return %2 : tensor<4x11x1xf32> +} + +// ----- +// CHECK-LABEL: slice_size +func.func @test_slice_size(%arg0: tensor<13x21x3xf32>) -> tensor<7x11x1xf32> { + %0 = tosa.const_shape {value = dense<[-1, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> + %1 = tosa.const_shape {value = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3> + %2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<7x11x1xf32> + return %2 : tensor<7x11x1xf32> } // ----- diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir index 96f71c349938b..12691f2e325a2 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir @@ -43,6 +43,7 @@ func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1: // ----- // CHECK-LABEL: @transpose_conv2d_strided + func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> { // Manipulate the weight matrix to handle striding. // CHECK-DAG: %[[PADV:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8> @@ -64,9 +65,11 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor< // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {acc_type = f32, dilation = array, pad = array, stride = array} // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]] {new_shape = array} // CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]] - // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]] {new_shape = array} - // CHECK-DAG: %[[SLICE:.+]] = tosa.slice %[[RESHAPE_OUT_2]] {size = array, start = array} - // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 {new_shape = array} + // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]] + // CHECK-DAG: %[[START:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> + // CHECK-DAG: %[[SIZE:.*]] = tosa.const_shape {value = dense<[2, 35, 47, 5]> : tensor<4xindex>} : () -> !tosa.shape<4> + // CHECK-DAG: %[[SLICE:.*]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]] + // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 // CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]] %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32> %1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32> @@ -76,6 +79,7 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor< // ----- // CHECK-LABEL: @transpose_conv2d_strided_quantized + func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1: tensor<5x3x5x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x35x47x5xi32>) { // Manipulate the weight matrix to handle striding. // CHECK-DAG: %[[PADV:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8> @@ -97,9 +101,11 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1 // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {acc_type = i32, dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]] {new_shape = array} // CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]] - // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]] {new_shape = array} - // CHECK-DAG: %[[SLICE:.+]] = tosa.slice %[[RESHAPE_OUT_2]] {size = array, start = array} - // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 {new_shape = array} + // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]] + // CHECK-DAG: %[[START:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} + // CHECK-DAG: %[[SIZE:.*]] = tosa.const_shape {value = dense<[2, 35, 47, 5]> : tensor<4xindex>} + // CHECK-DAG: %[[SLICE:.*]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]] + // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 // CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]] %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = i32, out_pad = array, quantization_info = #tosa.conv_quant, out_shape = array, stride = array} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32> return %0 : tensor<2x35x47x5xi32> diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index 6beb1ad629613..0c59cb089b935 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -517,8 +517,12 @@ func.func @test_padding_simple(%arg0 : tensor<1x2xf32>) -> () { // CHECK-LABEL: @test_slice func.func @test_slice(%arg0 : tensor) -> () { - // CHECK: tosa.slice %arg0 {size = array, start = array} : (tensor) -> tensor<2xi32> - %0 = tosa.slice %arg0 { size = array, start = array } : (tensor) -> tensor + // CHECK: %0 = tosa.const_shape {value = dense<1> : tensor<1xindex>} + // CHECK: %1 = tosa.const_shape {value = dense<2> : tensor<1xindex>} + // CHECK: %2 = tosa.slice %arg0, %0, %1 : (tensor, !tosa.shape<1>, !tosa.shape<1>) -> tensor<2xi32> + %0 = tosa.const_shape {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1> + %1 = tosa.const_shape {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1> + %2= tosa.slice %arg0, %0, %1 : (tensor, !tosa.shape<1>, !tosa.shape<1>) -> tensor return } @@ -526,13 +530,17 @@ func.func @test_slice(%arg0 : tensor) -> () { // CHECK-LABEL: @test_slice_size_minus_one func.func @test_slice_size_minus_one(%arg0 : tensor) -> () { - // CHECK: tosa.slice %arg0 {size = array, start = array} : (tensor) -> tensor + // CHECK: %[[Start:.+]] = tosa.const_shape + // CHECK: %[[Size:.+]] = tosa.const_shape + // CHECK: %[[VAL:.+]] = tosa.slice %arg0, %[[Start]], %[[Size]] : (tensor, !tosa.shape<4>, !tosa.shape<4>) -> tensor // this checks following // dim 0: size=-1, input dim=? => inferred output dim is ? // dim 1: size=-1 => inferred output dim is input_dim - start // dim 2: size=-1, start=-1 => inferred output dim is ? // dim 3: size=-1, start=8 => inferred output dim is ? because start is out of bound - %2= tosa.slice %arg0 { start = array, size = array } : (tensor) -> tensor + %start = tosa.const_shape {value = dense<[0, 1, -1, 8]> : tensor<4xindex>} : () -> !tosa.shape<4> + %size = tosa.const_shape {value = dense<[-1, -1, -1, -1]> : tensor<4xindex>} : () -> !tosa.shape<4> + %2= tosa.slice %arg0, %start, %size : (tensor, !tosa.shape<4>, !tosa.shape<4>) -> tensor return } @@ -540,13 +548,17 @@ func.func @test_slice_size_minus_one(%arg0 : tensor) -> () { // CHECK-LABEL: @test_slice_size_out_of_bound func.func @test_slice_size_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () { - // CHECK: tosa.slice %arg0 {size = array, start = array} : (tensor<8x8x8x?xi32>) -> tensor + // CHECK: %[[Start:.+]] = tosa.const_shape + // CHECK: %[[Size:.+]] = tosa.const_shape + // CHECK: %[[VAL:.+]] = tosa.slice %arg0, %[[Start]], %[[Size]] : (tensor<8x8x8x?xi32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor // this checks following // dim 0: size=0 => inferred output dim is ? // dim 1: size=-2 => inferred output dim is ? // dim 3: start+size out of bound because size too big: inferred output dim is ? // dim 4: size=4, input dim=? => inferred output dim is 4 - %2= tosa.slice %arg0 { start = array, size = array } : (tensor<8x8x8x?xi32>) -> tensor + %start = tosa.const_shape {value = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4> + %size = tosa.const_shape {value = dense<[0, -2, 9, 4]> : tensor<4xindex>} : () -> !tosa.shape<4> + %2= tosa.slice %arg0, %start, %size : (tensor<8x8x8x?xi32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor return } @@ -554,13 +566,17 @@ func.func @test_slice_size_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () { // CHECK-LABEL: @test_slice_start_out_of_bound func.func @test_slice_start_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () { - // CHECK: tosa.slice %arg0 {size = array, start = array} : (tensor<8x8x8x?xi32>) -> tensor + // CHECK: %[[Start:.+]] = tosa.const_shape + // CHECK: %[[Size:.+]] = tosa.const_shape + // CHECK: %[[VAL:.+]] = tosa.slice %arg0, %[[Start]], %[[Size]] : (tensor<8x8x8x?xi32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor // this checks following // dim 0: start=-1 => inferred output dim is ? // dim 1: start=8 => inferred output dim is ? // dim 2: start+size out of bound: inferred output dim is ? // dim 3: start=8000000, size=4, input dim=? => inferred output dim is 4 - %2= tosa.slice %arg0 { start = array, size = array } : (tensor<8x8x8x?xi32>) -> tensor + %start = tosa.const_shape {value = dense<[-1, 8, 6, 8000000]> : tensor<4xindex>} : () -> !tosa.shape<4> + %size = tosa.const_shape {value = dense<[1, 1, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4> + %2= tosa.slice %arg0, %start, %size : (tensor<8x8x8x?xi32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor return } @@ -568,8 +584,12 @@ func.func @test_slice_start_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () { // CHECK-LABEL: @test_slice_dynamic func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () { - // CHECK: tosa.slice %arg0 {size = array, start = array} : (tensor<10x?x2xf32>) -> tensor<7x?x1xf32> - %0 = tosa.slice %arg0 {size = array, start = array} : (tensor<10x?x2xf32>) -> tensor + // CHECK: %0 = tosa.const_shape {value = dense<[1, 0, 0]> : tensor<3xindex>} + // CHECK: %1 = tosa.const_shape {value = dense<[7, -1, 1]> : tensor<3xindex>} + // CHECK: %2 = tosa.slice %arg0, %0, %1 : (tensor<10x?x2xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<7x?x1xf32> + %0 = tosa.const_shape {value = dense<[1, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3> + %1 = tosa.const_shape {value = dense<[7, -1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> + %2= tosa.slice %arg0, %0, %1 : (tensor<10x?x2xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor return }