Skip to content

Commit 2a95fe4

Browse files
authored
[Flang] Allow Intrinsic simpification with min/maxloc dim and scalar result (#81619)
This makes an adjustment to the existing fir minloc/maxloc generation code to handle functions with a dim=1 that produce a scalar result. This should allow us to get the same benefits as the existing generated minmax reductions. This is a recommit of #76194 with an extra alteration to the end of genRuntimeMinMaxlocBody to make sure we convert the output array to the correct type (a `box<heap<i32>>`, not `box<heap<array<1xi32>>>`) to prevent writing the wrong type of box into it. This still allocates the data as a `array<1xi32>`, converting it into a i32 assuming that is safe. An alternative would be to allocate the data as a i32 and change more of the accesses to it throughout genRuntimeMinMaxlocBody.
1 parent ca827d5 commit 2a95fe4

File tree

2 files changed

+92
-27
lines changed

2 files changed

+92
-27
lines changed

flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@ static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
656656
unsigned rank, int maskRank,
657657
mlir::Type elementType,
658658
mlir::Type maskElemType,
659-
mlir::Type resultElemTy) {
659+
mlir::Type resultElemTy, bool isDim) {
660660
auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
661661
mlir::Type elementType) {
662662
if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
@@ -858,16 +858,27 @@ static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
858858
maskElemType, resultArr, maskRank == 0);
859859

860860
// Store newly created output array to the reference passed in
861-
fir::SequenceType::Shape resultShape(1, rank);
862-
mlir::Type outputArrTy = fir::SequenceType::get(resultShape, resultElemTy);
863-
mlir::Type outputHeapTy = fir::HeapType::get(outputArrTy);
864-
mlir::Type outputBoxTy = fir::BoxType::get(outputHeapTy);
865-
mlir::Type outputRefTy = builder.getRefType(outputBoxTy);
866-
mlir::Value outputArr = builder.create<fir::ConvertOp>(
867-
loc, outputRefTy, funcOp.front().getArgument(0));
868-
869-
// Store nearly created array to output array
870-
builder.create<fir::StoreOp>(loc, resultArr, outputArr);
861+
if (isDim) {
862+
mlir::Type resultBoxTy =
863+
fir::BoxType::get(fir::HeapType::get(resultElemTy));
864+
mlir::Value outputArr = builder.create<fir::ConvertOp>(
865+
loc, builder.getRefType(resultBoxTy), funcOp.front().getArgument(0));
866+
mlir::Value resultArrScalar = builder.create<fir::ConvertOp>(
867+
loc, fir::HeapType::get(resultElemTy), resultArrInit);
868+
mlir::Value resultBox =
869+
builder.create<fir::EmboxOp>(loc, resultBoxTy, resultArrScalar);
870+
builder.create<fir::StoreOp>(loc, resultBox, outputArr);
871+
} else {
872+
fir::SequenceType::Shape resultShape(1, rank);
873+
mlir::Type outputArrTy = fir::SequenceType::get(resultShape, resultElemTy);
874+
mlir::Type outputHeapTy = fir::HeapType::get(outputArrTy);
875+
mlir::Type outputBoxTy = fir::BoxType::get(outputHeapTy);
876+
mlir::Type outputRefTy = builder.getRefType(outputBoxTy);
877+
mlir::Value outputArr = builder.create<fir::ConvertOp>(
878+
loc, outputRefTy, funcOp.front().getArgument(0));
879+
builder.create<fir::StoreOp>(loc, resultArr, outputArr);
880+
}
881+
871882
builder.create<mlir::func::ReturnOp>(loc);
872883
}
873884

@@ -1146,11 +1157,14 @@ void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
11461157

11471158
mlir::Operation::operand_range args = call.getArgs();
11481159

1149-
mlir::Value back = args[6];
1160+
mlir::SymbolRefAttr callee = call.getCalleeAttr();
1161+
mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
1162+
bool isDim = funcNameBase.ends_with("Dim");
1163+
mlir::Value back = args[isDim ? 7 : 6];
11501164
if (isTrueOrNotConstant(back))
11511165
return;
11521166

1153-
mlir::Value mask = args[5];
1167+
mlir::Value mask = args[isDim ? 6 : 5];
11541168
mlir::Value maskDef = findMaskDef(mask);
11551169

11561170
// maskDef is set to NULL when the defining op is not one we accept.
@@ -1159,10 +1173,8 @@ void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
11591173
if (maskDef == NULL)
11601174
return;
11611175

1162-
mlir::SymbolRefAttr callee = call.getCalleeAttr();
1163-
mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
11641176
unsigned rank = getDimCount(args[1]);
1165-
if (funcNameBase.ends_with("Dim") || !(rank > 0))
1177+
if ((isDim && rank != 1) || !(rank > 0))
11661178
return;
11671179

11681180
fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
@@ -1203,22 +1215,24 @@ void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
12031215

12041216
llvm::raw_string_ostream nameOS(funcName);
12051217
outType.print(nameOS);
1218+
if (isDim)
1219+
nameOS << '_' << inputType;
12061220
nameOS << '_' << fmfString;
12071221

12081222
auto typeGenerator = [rank](fir::FirOpBuilder &builder) {
12091223
return genRuntimeMinlocType(builder, rank);
12101224
};
12111225
auto bodyGenerator = [rank, maskRank, inputType, logicalElemType, outType,
1212-
isMax](fir::FirOpBuilder &builder,
1213-
mlir::func::FuncOp &funcOp) {
1226+
isMax, isDim](fir::FirOpBuilder &builder,
1227+
mlir::func::FuncOp &funcOp) {
12141228
genRuntimeMinMaxlocBody(builder, funcOp, isMax, rank, maskRank, inputType,
1215-
logicalElemType, outType);
1229+
logicalElemType, outType, isDim);
12161230
};
12171231

12181232
mlir::func::FuncOp newFunc =
12191233
getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
12201234
builder.create<fir::CallOp>(loc, newFunc,
1221-
mlir::ValueRange{args[0], args[1], args[5]});
1235+
mlir::ValueRange{args[0], args[1], mask});
12221236
call->dropAllReferences();
12231237
call->erase();
12241238
}

flang/test/Transforms/simplifyintrinsics.fir

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2098,13 +2098,13 @@ func.func @_QPtestminloc_doesntwork1d_back(%arg0: !fir.ref<!fir.array<10xi32>> {
20982098
// CHECK-NOT: fir.call @_FortranAMinlocInteger4x1_i32_contract_simplified({{.*}}) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>) -> ()
20992099

21002100
// -----
2101-
// Check Minloc is not simplified when DIM arg is set
2101+
// Check Minloc is simplified when DIM arg is set so long as the result is scalar
21022102

2103-
func.func @_QPtestminloc_doesntwork1d_dim(%arg0: !fir.ref<!fir.array<10xi32>> {fir.bindc_name = "a"}) -> !fir.array<1xi32> {
2103+
func.func @_QPtestminloc_1d_dim(%arg0: !fir.ref<!fir.array<10xi32>> {fir.bindc_name = "a"}) -> !fir.array<1xi32> {
21042104
%0 = fir.alloca !fir.box<!fir.heap<i32>>
21052105
%c10 = arith.constant 10 : index
21062106
%c1 = arith.constant 1 : index
2107-
%1 = fir.alloca !fir.array<1xi32> {bindc_name = "testminloc_doesntwork1d_dim", uniq_name = "_QFtestminloc_doesntwork1d_dimEtestminloc_doesntwork1d_dim"}
2107+
%1 = fir.alloca !fir.array<1xi32> {bindc_name = "testminloc_1d_dim", uniq_name = "_QFtestminloc_1d_dimEtestminloc_1d_dim"}
21082108
%2 = fir.shape %c1 : (index) -> !fir.shape<1>
21092109
%3 = fir.array_load %1(%2) : (!fir.ref<!fir.array<1xi32>>, !fir.shape<1>) -> !fir.array<1xi32>
21102110
%4 = fir.shape %c10 : (index) -> !fir.shape<1>
@@ -2139,11 +2139,62 @@ func.func @_QPtestminloc_doesntwork1d_dim(%arg0: !fir.ref<!fir.array<10xi32>> {f
21392139
%21 = fir.load %1 : !fir.ref<!fir.array<1xi32>>
21402140
return %21 : !fir.array<1xi32>
21412141
}
2142-
// CHECK-LABEL: func.func @_QPtestminloc_doesntwork1d_dim(
2142+
// CHECK-LABEL: func.func @_QPtestminloc_1d_dim(
21432143
// CHECK-SAME: %[[ARR:.*]]: !fir.ref<!fir.array<10xi32>> {fir.bindc_name = "a"}) -> !fir.array<1xi32> {
2144-
// CHECK-NOT: fir.call @_FortranAMinlocDimx1_i32_contract_simplified({{.*}}) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>) -> ()
2145-
// CHECK: fir.call @_FortranAMinlocDim({{.*}}) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32, !fir.box<none>, i1) -> none
2146-
// CHECK-NOT: fir.call @_FortranAMinlocDimx1_i32_contract_simplified({{.*}}) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>) -> ()
2144+
// CHECK: fir.call @_FortranAMinlocDimx1_i32_i32_contract_simplified({{.*}}) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>) -> ()
2145+
2146+
// CHECK-LABEL: func.func private @_FortranAMinlocDimx1_i32_i32_contract_simplified(%arg0: !fir.ref<!fir.box<none>>, %arg1: !fir.box<none>, %arg2: !fir.box<none>) attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
2147+
// CHECK-NEXT: %[[V0:.*]] = fir.alloca i32
2148+
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
2149+
// CHECK-NEXT: %c1 = arith.constant 1 : index
2150+
// CHECK-NEXT: %[[V1:.*]] = fir.allocmem !fir.array<1xi32>
2151+
// CHECK-NEXT: %[[V2:.*]] = fir.shape %c1 : (index) -> !fir.shape<1>
2152+
// CHECK-NEXT: %[[V3:.*]] = fir.embox %[[V1]](%[[V2]]) : (!fir.heap<!fir.array<1xi32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<1xi32>>>
2153+
// CHECK-NEXT: %c0 = arith.constant 0 : index
2154+
// CHECK-NEXT: %[[V4:.*]] = fir.coordinate_of %[[V3]], %c0 : (!fir.box<!fir.heap<!fir.array<1xi32>>>, index) -> !fir.ref<i32>
2155+
// CHECK-NEXT: fir.store %c0_i32 to %[[V4]] : !fir.ref<i32>
2156+
// CHECK-NEXT: %c0_0 = arith.constant 0 : index
2157+
// CHECK-NEXT: %[[V5:.*]] = fir.convert %arg1 : (!fir.box<none>) -> !fir.box<!fir.array<?xi32>>
2158+
// CHECK-NEXT: %c1_i32 = arith.constant 1 : i32
2159+
// CHECK-NEXT: %c0_i32_1 = arith.constant 0 : i32
2160+
// CHECK-NEXT: fir.store %c0_i32_1 to %[[V0]] : !fir.ref<i32>
2161+
// CHECK-NEXT: %c2147483647_i32 = arith.constant 2147483647 : i32
2162+
// CHECK-NEXT: %c1_2 = arith.constant 1 : index
2163+
// CHECK-NEXT: %c0_3 = arith.constant 0 : index
2164+
// CHECK-NEXT: %[[V6:.*]]:3 = fir.box_dims %[[V5]], %c0_3 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
2165+
// CHECK-NEXT: %[[V7:.*]] = arith.subi %[[V6]]#1, %c1_2 : index
2166+
// CHECK-NEXT: %[[V8:.*]] = fir.do_loop %arg3 = %c0_0 to %[[V7]] step %c1_2 iter_args(%arg4 = %c2147483647_i32) -> (i32) {
2167+
// CHECK-NEXT: %c1_i32_4 = arith.constant 1 : i32
2168+
// CHECK-NEXT: %[[ISFIRST:.*]] = fir.load %[[FLAG_ALLOC]] : !fir.ref<i32>
2169+
// CHECK-NEXT: %[[V12:.*]] = fir.coordinate_of %[[V5]], %arg3 : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
2170+
// CHECK-NEXT: %[[V13:.*]] = fir.load %[[V12]] : !fir.ref<i32>
2171+
// CHECK-NEXT: %[[V14:.*]] = arith.cmpi slt, %[[V13]], %arg4 : i32
2172+
// CHECK-NEXT: %[[ISFIRSTL:.*]] = fir.convert %[[ISFIRST]] : (i32) -> i1
2173+
// CHECK-NEXT: %true = arith.constant true
2174+
// CHECK-NEXT: %[[ISFIRSTNOT:.*]] = arith.xori %[[ISFIRSTL]], %true : i1
2175+
// CHECK-NEXT: %[[ORCOND:.*]] = arith.ori %[[V14]], %[[ISFIRSTNOT]] : i1
2176+
// CHECK-NEXT: %[[V15:.*]] = fir.if %[[ORCOND]] -> (i32) {
2177+
// CHECK-NEXT: fir.store %c1_i32_4 to %[[V0]] : !fir.ref<i32>
2178+
// CHECK-NEXT: %c1_i32_5 = arith.constant 1 : i32
2179+
// CHECK-NEXT: %c0_6 = arith.constant 0 : index
2180+
// CHECK-NEXT: %[[V16:.*]] = fir.coordinate_of %[[V3]], %c0_6 : (!fir.box<!fir.heap<!fir.array<1xi32>>>, index) -> !fir.ref<i32>
2181+
// CHECK-NEXT: %[[V17:.*]] = fir.convert %arg3 : (index) -> i32
2182+
// CHECK-NEXT: %[[V18:.*]] = arith.addi %[[V17]], %c1_i32_5 : i32
2183+
// CHECK-NEXT: fir.store %[[V18]] to %[[V16]] : !fir.ref<i32>
2184+
// CHECK-NEXT: fir.result %[[V13]] : i32
2185+
// CHECK-NEXT: } else {
2186+
// CHECK-NEXT: fir.result %arg4 : i32
2187+
// CHECK-NEXT: }
2188+
// CHECK-NEXT: fir.result %[[V15]] : i32
2189+
// CHECK-NEXT: }
2190+
// CHECK-NEXT: %[[V11:.*]] = fir.convert %arg0 : (!fir.ref<!fir.box<none>>) -> !fir.ref<!fir.box<!fir.heap<i32>>>
2191+
// CHECK-NEXT: %[[V12:.*]] = fir.convert %[[V1]] : (!fir.heap<!fir.array<1xi32>>) -> !fir.heap<i32>
2192+
// CHECK-NEXT: %[[V13:.*]] = fir.embox %[[V12]] : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
2193+
// CHECK-NEXT: fir.store %[[V13]] to %[[V11]] : !fir.ref<!fir.box<!fir.heap<i32>>>
2194+
// CHECK-NEXT: return
2195+
// CHECK-NEXT: }
2196+
2197+
21472198

21482199
// -----
21492200
// Check Minloc is not simplified when dimension of inputArr is unknown

0 commit comments

Comments
 (0)