diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td index f11162dc0d95e..48764580d526d 100644 --- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td +++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td @@ -1205,6 +1205,7 @@ def hlfir_ShapeOfOp : hlfir_Op<"shape_of", [Pure]> { }]; let builders = [OpBuilder<(ins "mlir::Value":$expr)>]; + let hasFolder = 1; } def hlfir_GetExtentOp : hlfir_Op<"get_extent", [Pure]> { diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp index ad53527f43441..82aac7cafa1d0 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp @@ -1704,6 +1704,15 @@ hlfir::ShapeOfOp::canonicalize(ShapeOfOp shapeOf, return llvm::LogicalResult::success(); } +mlir::OpFoldResult hlfir::ShapeOfOp::fold(FoldAdaptor adaptor) { + if (matchPattern(getExpr(), mlir::m_Op())) { + auto elementalOp = + mlir::cast(getExpr().getDefiningOp()); + return elementalOp.getShape(); + } + return {}; +} + //===----------------------------------------------------------------------===// // GetExtent //===----------------------------------------------------------------------===// diff --git a/flang/test/HLFIR/shapeof.fir b/flang/test/HLFIR/shapeof.fir index b91efc276b62e..43e22dd320c18 100644 --- a/flang/test/HLFIR/shapeof.fir +++ b/flang/test/HLFIR/shapeof.fir @@ -27,3 +27,21 @@ func.func @shapeof2(%arg0: !hlfir.expr) -> !fir.shape<2> { // CHECK-ALL: %[[EXPR:.*]]: !hlfir.expr // CHECK-ALL-NEXT: %[[SHAPE:.*]] = hlfir.shape_of %[[EXPR]] : (!hlfir.expr) -> !fir.shape<2> // CHECK-ALL-NEXT: return %[[SHAPE]] + +// Checks hlfir.elemental -> hlfir.shape_of folding +func.func @shapeof_fold1(%extent: index) -> !fir.shape<1> { + %shape1 = fir.shape %extent : (index) -> !fir.shape<1> + %elem = hlfir.elemental %shape1 : (!fir.shape<1>) -> !hlfir.expr { + hlfir.yield_element %extent : index + } + %shape2 = hlfir.shape_of %elem : (!hlfir.expr) -> !fir.shape<1> + return %shape2 : !fir.shape<1> +} +// CHECK-ALL-LABEL: func.func @shapeof_fold1( +// CHECK-ALL-SAME: %[[VAL_0:.*]]: index) -> !fir.shape<1> { +// CHECK-CANON-NEXT: %[[VAL_1:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1> +// CHECK-CANON-NEXT: %[[VAL_2:.*]] = hlfir.elemental %[[VAL_1]] : (!fir.shape<1>) -> !hlfir.expr { +// CHECK-CANON-NEXT: hlfir.yield_element %[[VAL_0]] : index +// CHECK-CANON-NEXT: } +// CHECK-CANON-NEXT: return %[[VAL_1]] : !fir.shape<1> +// CHECK-CANON-NEXT: }