diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 10ba808cd26c2..f670614806dbd 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1734,10 +1734,23 @@ struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
     // Operand 'shape' of 'tensor.reshape' may now be used as the result of
     // 'shape.shape_of'. While its type is guaranteed to be compatible in well-
     // formed IR, it may not be identical (dynamically vs statically shaped),
-    // in which case it needs to be cast first.
+    // in which case it needs to be cast first using 'tensor.cast'.
+    // Additionally, it may not have identical element type (i32 vs index)
+    // while it has identical shaped type (dynamic vs static), in which case it
+    // needs to be cast first using 'arith.index_cast'. Note: 'shape.shape_of'
+    // op result must be shape or extent tensor.
     Value shape = tensorReshapeOp.getShape();
-    if (op.getType() != shape.getType())
-      shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), shape);
+
+    auto opTensorTy = cast<RankedTensorType>(op.getType());
+    auto shapeTensorTy = cast<RankedTensorType>(shape.getType());
+
+    if (opTensorTy != shapeTensorTy) {
+      if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
+        shape = rewriter.create<tensor::CastOp>(op.getLoc(), opTensorTy, shape);
+      else if (!isExtentTensorType(shapeTensorTy))
+        shape =
+            rewriter.create<arith::IndexCastOp>(op.getLoc(), opTensorTy, shape);
+    }
 
     rewriter.replaceOp(op, shape);
     return success();
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index cf439c9c1b854..b42fa75e4112d 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1389,10 +1389,25 @@ func.func @shape_of_from_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -
 
 // -----
 
-// CHECK-LABEL: func @shape_of_from_reshape_compatible_types
+// Check statically shaped types, with element types i32 to index.
+// CHECK-LABEL: func @shape_of_from_reshape_int_to_index
+// CHECK-SAME: %[[INPUT:.*]]: tensor<?x1xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32>
+func.func @shape_of_from_reshape_int_to_index(%arg0: tensor<?x1xf32>, %arg1: tensor<3xi32>) -> tensor<3xindex> {
+  // CHECK: %[[CAST_SHAPE:.*]] = arith.index_cast %[[SHAPE]] : tensor<3xi32> to tensor<3xindex>
+  // CHECK: return %[[CAST_SHAPE]] : tensor<3xindex>
+  %0 = tensor.reshape %arg0(%arg1) : (tensor<?x1xf32>, tensor<3xi32>) -> tensor<?x1x?xf32>
+  %1 = shape.shape_of %0 : tensor<?x1x?xf32> -> tensor<3xindex>
+  return %1 : tensor<3xindex>
+}
+
+// -----
+
+// Check similar element types, with statically shaped to dynamically shaped.
+// CHECK-LABEL: func @shape_of_from_reshape_static_to_dynamic
 // CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
 // CHECK-SAME: %[[SHAPE:.*]]: tensor<5xindex>
-func.func @shape_of_from_reshape_compatible_types(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<?xindex> {
+func.func @shape_of_from_reshape_static_to_dynamic(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<?xindex> {
   // CHECK: %[[CAST_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor<5xindex> to tensor<?xindex>
   // CHECK: return %[[CAST_SHAPE]] : tensor<?xindex>
   %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<5xindex>) -> tensor<*xf32>
@@ -1402,6 +1417,33 @@ func.func @shape_of_from_reshape_compatible_types(%arg0: tensor<*xf32>, %arg1: t
 
 // -----
 
+// Check similar element types, with dynamically shaped to statically shaped.
+// CHECK-LABEL: func @shape_of_from_reshape_dynamic_to_static
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
+func.func @shape_of_from_reshape_dynamic_to_static(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<5xindex> {
+  // CHECK: %[[CAST_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor<?xindex> to tensor<5xindex>
+  // CHECK: return %[[CAST_SHAPE]] : tensor<5xindex>
+  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<5xindex>
+  return %1 : tensor<5xindex>
+}
+
+// -----
+
+// Check similar element types and similar static shape.
+// CHECK-LABEL: func @shape_of_from_reshape_identical_types
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<5xindex>
+func.func @shape_of_from_reshape_identical_types(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<5xindex> {
+  // CHECK: return %[[SHAPE]] : tensor<5xindex>
+  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<5xindex>) -> tensor<*xf32>
+  %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<5xindex>
+  return %1 : tensor<5xindex>
+}
+
+// -----
+
 // CHECK-LABEL: func @shape_of_from_reshape_nofold
 // CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
 // CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
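
For illustration only (not part of the patch), a minimal sketch of the rewrite the new 'arith.index_cast' path produces, mirroring the @shape_of_from_reshape_int_to_index test above; the value names and the -canonicalize invocation are assumptions:

  // Input IR: shape.shape_of over a tensor.reshape whose shape operand is i32.
  %0 = tensor.reshape %arg0(%shape) : (tensor<?x1xf32>, tensor<3xi32>) -> tensor<?x1x?xf32>
  %1 = shape.shape_of %0 : tensor<?x1x?xf32> -> tensor<3xindex>

  // After -canonicalize: shape.shape_of folds to an index cast of the reshape's
  // shape operand; the tensor.reshape becomes dead and can be removed.
  %1 = arith.index_cast %shape : tensor<3xi32> to tensor<3xindex>

With the previous code, this case hit the unconditional 'tensor.cast' and produced an invalid cast between i32 and index element types; the pattern now distinguishes a pure shape mismatch (tensor.cast) from an element-type mismatch (arith.index_cast).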