Skip to content

Commit 956c070

Browse files
[mlir][tosa] Change the start and size of slice to tosa shape type (#124209)
Update to use getConstShapeValue to collect shape information along the graph. Change-Id: Ic6fc2341e3bcfbec06a1d08986e26dd08573bd9c Co-authored-by: TatWai Chong <[email protected]>
1 parent 46f9cdd commit 956c070

File tree

14 files changed

+220
-79
lines changed

14 files changed

+220
-79
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1714,8 +1714,8 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> {
17141714

17151715
let arguments = (ins
17161716
Tosa_Tensor:$input1,
1717-
DenseI64ArrayAttr:$start,
1718-
DenseI64ArrayAttr:$size
1717+
Tosa_Shape:$start,
1718+
Tosa_Shape:$size
17191719
);
17201720

17211721
let results = (outs

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -268,12 +268,28 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
268268
ShapedType resultType = cast<ShapedType>(sliceOp.getType());
269269
if (llvm::isa<UnrankedTensorType>(resultType))
270270
return failure();
271+
272+
ElementsAttr startElems;
273+
ElementsAttr sizeElems;
274+
275+
if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
276+
return rewriter.notifyMatchFailure(
277+
sliceOp, "start of slice must be a static ranked shape");
278+
279+
if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
280+
return rewriter.notifyMatchFailure(
281+
sliceOp, "size of slice must be a static ranked shape");
282+
283+
llvm::SmallVector<int64_t> sliceStarts =
284+
llvm::to_vector(startElems.getValues<int64_t>());
285+
llvm::SmallVector<int64_t> sliceSizes =
286+
llvm::to_vector(sizeElems.getValues<int64_t>());
287+
271288
SmallVector<int64_t> strides, sizes;
272-
ArrayRef<int64_t> starts = sliceOp.getStart();
273289
strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);
274290

275291
SmallVector<Value> dynSizes;
276-
for (const auto &i : llvm::enumerate(sliceOp.getSize())) {
292+
for (const auto &i : llvm::enumerate(sliceSizes)) {
277293
int64_t size = i.value();
278294
size_t index = i.index();
279295
sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
@@ -282,17 +298,27 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
282298

283299
auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
284300
auto offset = rewriter.create<arith::ConstantOp>(
285-
loc, rewriter.getIndexAttr(starts[index]));
301+
loc, rewriter.getIndexAttr(sliceStarts[index]));
286302
dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
287303
}
288304

289305
auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
290306
sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
291-
ValueRange({}), rewriter.getDenseI64ArrayAttr(starts),
307+
ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts),
292308
rewriter.getDenseI64ArrayAttr(sizes),
293309
rewriter.getDenseI64ArrayAttr(strides));
294310

295311
rewriter.replaceOp(sliceOp, newSliceOp.getResult());
312+
313+
// Remove const_shape ops when it no longer has use point.
314+
Operation *startConstShape = sliceOp.getStart().getDefiningOp();
315+
if (startConstShape->getResult(0).hasOneUse())
316+
rewriter.eraseOp(startConstShape);
317+
318+
Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
319+
if (sizeConstShape->getResult(0).hasOneUse())
320+
rewriter.eraseOp(sizeConstShape);
321+
296322
return success();
297323
}
298324
};

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -444,8 +444,21 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
444444
sliceOp, "slice input must be a static ranked tensor");
445445
int32_t axis = concatOp.getAxis();
446446

447-
llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
448-
llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();
447+
DenseElementsAttr startElems;
448+
DenseElementsAttr sizeElems;
449+
450+
if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
451+
return rewriter.notifyMatchFailure(
452+
sliceOp, "start of slice must be a static ranked shape");
453+
454+
if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
455+
return rewriter.notifyMatchFailure(
456+
sliceOp, "size of slice must be a static ranked shape");
457+
458+
llvm::SmallVector<int64_t> sliceStarts =
459+
llvm::to_vector(startElems.getValues<int64_t>());
460+
llvm::SmallVector<int64_t> sliceSizes =
461+
llvm::to_vector(sizeElems.getValues<int64_t>());
449462

450463
// Validate slice on the concatenated axis. Slicing along this
451464
// axis should span only one of the inputs to the concatenate
@@ -457,17 +470,20 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
457470
return rewriter.notifyMatchFailure(
458471
sliceOp, "concat input must be a static ranked tensor");
459472

460-
if (sliceStart[axis] >= 0 &&
461-
(sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
462-
replaceWithSlice = rewriter
463-
.create<tosa::SliceOp>(
464-
sliceOp.getLoc(), sliceOp.getType(), input,
465-
rewriter.getDenseI64ArrayAttr(sliceStart),
466-
rewriter.getDenseI64ArrayAttr(sliceSize))
467-
.getResult();
473+
if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
474+
inputType.getDimSize(axis)) {
475+
auto start_op =
476+
getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
477+
auto size_op =
478+
getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
479+
replaceWithSlice =
480+
rewriter
481+
.create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(),
482+
input, start_op, size_op)
483+
.getResult();
468484
break;
469485
}
470-
sliceStart[axis] -= inputType.getDimSize(axis);
486+
sliceStarts[axis] -= inputType.getDimSize(axis);
471487
}
472488

473489
if (!replaceWithSlice)
@@ -1025,7 +1041,12 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
10251041

10261042
if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
10271043
outputTy.getNumElements() == 1) {
1028-
llvm::SmallVector<uint64_t> indices(getStart());
1044+
DenseElementsAttr startElems;
1045+
if (!matchPattern(getStart(), m_Constant(&startElems)))
1046+
return {};
1047+
1048+
llvm::SmallVector<uint64_t> indices =
1049+
llvm::to_vector(startElems.getValues<uint64_t>());
10291050
auto value = operand.getValues<Attribute>()[indices];
10301051
return SplatElementsAttr::get(outputTy, value);
10311052
}

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -891,8 +891,18 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
891891
MLIRContext *context, ::std::optional<Location> location,
892892
SliceOp::Adaptor adaptor,
893893
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
894-
auto start = adaptor.getStart();
895-
auto size = adaptor.getSize();
894+
895+
Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
896+
SmallVector<int64_t> start;
897+
SmallVector<int64_t> size;
898+
899+
if (!tosa::getConstShapeValue(adaptor.getStart().getDefiningOp(), start) ||
900+
!tosa::getConstShapeValue(adaptor.getSize().getDefiningOp(), size)) {
901+
auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
902+
SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
903+
inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
904+
return success();
905+
}
896906

897907
// if size[i] is -1, all remaining elements in dimension i are included
898908
// in the slice, similar to TF.
@@ -933,11 +943,15 @@ LogicalResult tosa::SliceOp::verify() {
933943
if (!inputType)
934944
return success();
935945

936-
if (static_cast<size_t>(inputType.getRank()) != getStart().size())
946+
auto startShapeRank =
947+
llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
948+
if (inputType.getRank() != startShapeRank)
937949
return emitOpError(
938950
"length of start attribute is not equal rank of input shape");
939951

940-
if (static_cast<size_t>(inputType.getRank()) != getSize().size())
952+
auto sizeShapeRank =
953+
llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
954+
if (inputType.getRank() != sizeShapeRank)
941955
return emitOpError(
942956
"length of size attribute is not equal rank of input shape");
943957

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,8 @@ class TransposeConvStridedConverter
302302

303303
auto slice = CreateOpAndInferShape<tosa::SliceOp>(
304304
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
305-
rewriter.getDenseI64ArrayAttr(sliceBegin),
306-
rewriter.getDenseI64ArrayAttr(sliceSize))
305+
getTosaConstShape(rewriter, loc, sliceBegin),
306+
getTosaConstShape(rewriter, loc, sliceSize))
307307
.getResult();
308308

309309
llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};

mlir/test/Conversion/TosaToTensor/tosa-to-tensor-invalid.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
// CHECK-LABEL: @slice_resultType_unranked
44
func.func @slice_resultType_unranked(%arg0: tensor<?xf32>) -> (tensor<*xf32>) {
5+
%0 = tosa.const_shape {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
6+
%1 = tosa.const_shape {value = dense<0> : tensor<1xindex>} : () -> !tosa.shape<1>
57
// expected-error@+1 {{failed to legalize operation 'tosa.slice'}}
6-
%0 = "tosa.slice"(%arg0) {start = array<i64: 2>, size = array<i64: 0>} : (tensor<?xf32>) -> (tensor<*xf32>)
7-
return %0 : tensor<*xf32>
8+
%2 = tosa.slice %arg0, %0, %1 : (tensor<?xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<*xf32>
9+
return %2 : tensor<*xf32>
810
}

mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,9 @@ func.func @test_reshape_samerank_unsigned(%arg0: tensor<3x2xui8>) -> tensor<2x3x
437437
// CHECK-LABEL: func @slice
438438
func.func @slice(%arg0: tensor<6xf32>) ->() {
439439
// CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1]
440-
%0 = "tosa.slice"(%arg0) {start = array<i64: 2>, size = array<i64: 1>} : (tensor<6xf32>) -> (tensor<1xf32>)
440+
%0 = tosa.const_shape {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
441+
%1 = tosa.const_shape {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
442+
%2 = tosa.slice %arg0, %0, %1 : (tensor<6xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<1xf32>
441443
return
442444
}
443445

@@ -450,8 +452,10 @@ func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
450452
// CHECK: %[[C2:.+]] = arith.constant 2 : index
451453
// CHECK: %[[SUB:.+]] = arith.subi %[[DIM]], %[[C2]]
452454
// CHECK: tensor.extract_slice %arg0[2] [%[[SUB]]] [1]
453-
%0 = "tosa.slice"(%arg0) {start = array<i64: 2>, size = array<i64: -1>} : (tensor<?xf32>) -> (tensor<?xf32>)
454-
return %0 : tensor<?xf32>
455+
%0 = tosa.const_shape {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
456+
%1 = tosa.const_shape {value = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
457+
%2 = tosa.slice %arg0, %0, %1 : (tensor<?xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<?xf32>
458+
return %2 : tensor<?xf32>
455459
}
456460

457461
// -----

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 48 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -638,18 +638,22 @@ func.func @transpose_canonicalize_strip_quant() -> (tensor<2x1x3x!quant.uniform<
638638

639639
// CHECK-LABEL: @slice_fold
640640
func.func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
641+
%0 = tosa.const_shape {value = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
642+
%1 = tosa.const_shape {value = dense<[3, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
641643
// CHECK: return %arg0
642-
%0 = tosa.slice %arg0 { size = array<i64: 3, 4>, start = array<i64: 0, 0>}: (tensor<3x4xf32>) -> tensor<3x4xf32>
643-
return %0 : tensor<3x4xf32>
644+
%3 = tosa.slice %arg0, %0, %1 : (tensor<3x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x4xf32>
645+
return %3 : tensor<3x4xf32>
644646
}
645647

646648
// -----
647649

648650
// CHECK-LABEL: @slice_nofold
649651
func.func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
652+
%0 = tosa.const_shape {value = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
653+
%1 = tosa.const_shape {value = dense<[3, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
650654
// CHECK: tosa.slice
651-
%0 = tosa.slice %arg0 { size = array<i64: 3, 4>, start = array<i64: 0, 0>}: (tensor<?x4xf32>) -> tensor<?x4xf32>
652-
return %0 : tensor<?x4xf32>
655+
%3 = tosa.slice %arg0, %0, %1 : (tensor<?x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x4xf32>
656+
return %3 : tensor<?x4xf32>
653657
}
654658

655659
// -----
@@ -729,9 +733,12 @@ func.func @fold_resize_bilinear(%arg0 : tensor<1x15x13x1xi8>) -> tensor<1x15x13x
729733
// CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
730734
func.func @canonicalize_concat_slice_final_axis(%arg0 : tensor<1x12x12x1xf32>, %arg1 : tensor<1x12x12x1xf32>) -> (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) {
731735
%0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) -> tensor<1x12x12x2xf32>
732-
%1 = tosa.slice %0 {size = array<i64: 1, 12, 12, 1>, start = array<i64: 0, 0, 0, 0>} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32>
733-
%2 = tosa.slice %0 {size = array<i64: 1, 12, 12, 1>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32>
734-
return %1, %2 : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
736+
%1 = tosa.const_shape {value = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
737+
%2 = tosa.const_shape {value = dense<[0, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
738+
%3 = tosa.const_shape {value = dense<[1, 12, 12, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
739+
%4 = tosa.slice %0, %1, %3 : (tensor<1x12x12x2xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x12x12x1xf32>
740+
%5 = tosa.slice %0, %2, %3 : (tensor<1x12x12x2xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x12x12x1xf32>
741+
return %4, %5 : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
735742
}
736743

737744
// -----
@@ -741,38 +748,56 @@ func.func @canonicalize_concat_slice_final_axis(%arg0 : tensor<1x12x12x1xf32>, %
741748
// CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12xf32>, tensor<1x12x12xf32>
742749
func.func @canonicalize_concat_slice_middle_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x12xf32>, tensor<1x12x12xf32>) {
743750
%0 = tosa.concat %arg0, %arg1 {axis = 1 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x24x12xf32>
744-
%1 = tosa.slice %0 {size = array<i64: 1, 12, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32>
745-
%2 = tosa.slice %0 {size = array<i64: 1, 12, 12>, start = array<i64: 0, 12, 0>} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32>
746-
return %1, %2 : tensor<1x12x12xf32>, tensor<1x12x12xf32>
751+
%1 = tosa.const_shape {value = dense<[0, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
752+
%2 = tosa.const_shape {value = dense<[0, 12, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
753+
%3 = tosa.const_shape {value = dense<[1, 12, 12]> : tensor<3xindex>} : () -> !tosa.shape<3>
754+
%4 = tosa.slice %0, %1, %3 : (tensor<1x24x12xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x12x12xf32>
755+
%5 = tosa.slice %0, %2, %3 : (tensor<1x24x12xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x12x12xf32>
756+
return %4, %5 : tensor<1x12x12xf32>, tensor<1x12x12xf32>
747757
}
748758

749759
// -----
750760

751761
// CHECK-LABEL: @canonicalize_cross_concat_inputs
752762
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
753-
// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]] {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
754-
// CHECK: %[[VAL_3:.*]] = tosa.slice %[[VAL_2]] {size = array<i64: 1, 12, 15>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32>
755-
// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_2]] {size = array<i64: 1, 12, 20>, start = array<i64: 0, 0, 4>} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32>
756-
// CHECK: return %[[VAL_3]], %[[VAL_4]] : tensor<1x12x15xf32>, tensor<1x12x20xf32>
763+
// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[1, 12, 20]> : tensor<3xindex>}
764+
// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[1, 12, 15]> : tensor<3xindex>}
765+
// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[0, 0, 4]> : tensor<3xindex>}
766+
// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<0> : tensor<3xindex>}
767+
// CHECK: %[[VAL_6:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]] {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
768+
// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_6]], %[[VAL_5]], %[[VAL_3]]
769+
// CHECK: %[[VAL_8:.*]] = tosa.slice %[[VAL_6]], %[[VAL_4]], %[[VAL_2]]
770+
// CHECK: return %[[VAL_7]], %[[VAL_8]] : tensor<1x12x15xf32>, tensor<1x12x20xf32>
757771
func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x15xf32>, tensor<1x12x20xf32>) {
758772
%0 = tosa.concat %arg0, %arg1 {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
759-
%1 = tosa.slice %0 {size = array<i64: 1, 12, 15>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32>
760-
%2 = tosa.slice %0 {size = array<i64: 1, 12, 20>, start = array<i64: 0, 0, 4>} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32>
761-
return %1, %2 : tensor<1x12x15xf32>, tensor<1x12x20xf32>
773+
%1 = tosa.const_shape {value = dense<[0, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
774+
%2 = tosa.const_shape {value = dense<[0, 0, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
775+
%3 = tosa.const_shape {value = dense<[1, 12, 15]> : tensor<3xindex>} : () -> !tosa.shape<3>
776+
%4 = tosa.const_shape {value = dense<[1, 12, 20]> : tensor<3xindex>} : () -> !tosa.shape<3>
777+
%5 = tosa.slice %0, %1, %3 : (tensor<1x12x24xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x12x15xf32>
778+
%6 = tosa.slice %0, %2, %4 : (tensor<1x12x24xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x12x20xf32>
779+
return %5, %6 : tensor<1x12x15xf32>, tensor<1x12x20xf32>
762780
}
763781

764782
// -----
765783

766784
// CHECK-LABEL: @canonicalize_concat_slice_on_non_concat_axis
767785
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
768-
// CHECK: %[[VAL_2:.*]] = tosa.slice %[[VAL_0]] {size = array<i64: 1, 6, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x12x12xf32>) -> tensor<1x6x12xf32>
769-
// CHECK: %[[VAL_3:.*]] = tosa.slice %[[VAL_1]] {size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 0>} : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32>
770-
// CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x6x12xf32>, tensor<1x3x12xf32>
786+
// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[1, 3, 0]> : tensor<3xindex>}
787+
// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[1, 3, 12]> : tensor<3xindex>}
788+
// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<0> : tensor<3xindex>}
789+
// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[1, 6, 12]> : tensor<3xindex>}
790+
// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_0]], %[[VAL_4]], %[[VAL_5]]
791+
// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]], %[[VAL_2]], %[[VAL_3]]
792+
// CHECK: return %[[VAL_6]], %[[VAL_7]] : tensor<1x6x12xf32>, tensor<1x3x12xf32>
771793
func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x6x12xf32>, tensor<1x3x12xf32>) {
772794
%0 = tosa.concat %arg0, %arg1 {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
773-
%1 = tosa.slice %0 {size = array<i64: 1, 6, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x6x12xf32>
774-
%2 = tosa.slice %0 {size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 12>} : (tensor<1x12x24xf32>) -> tensor<1x3x12xf32>
775-
return %1, %2 : tensor<1x6x12xf32>, tensor<1x3x12xf32>
795+
%1 = tosa.const_shape {value = dense<[0, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
796+
%2 = tosa.const_shape {value = dense<[1, 6, 12]> : tensor<3xindex>} : () -> !tosa.shape<3>
797+
%3 = tosa.const_shape {value = dense<[1, 3, 12]> : tensor<3xindex>} : () -> !tosa.shape<3>
798+
%4 = tosa.slice %0, %1, %2 : (tensor<1x12x24xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x6x12xf32>
799+
%5 = tosa.slice %0, %3, %3 : (tensor<1x12x24xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x3x12xf32>
800+
return %4, %5 : tensor<1x6x12xf32>, tensor<1x3x12xf32>
776801
}
777802

778803
// -----

0 commit comments

Comments
 (0)