@@ -656,7 +656,7 @@ static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
656
656
unsigned rank, int maskRank,
657
657
mlir::Type elementType,
658
658
mlir::Type maskElemType,
659
- mlir::Type resultElemTy) {
659
+ mlir::Type resultElemTy, bool isDim ) {
660
660
auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
661
661
mlir::Type elementType) {
662
662
if (auto ty = elementType.dyn_cast <mlir::FloatType>()) {
@@ -858,16 +858,27 @@ static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
858
858
maskElemType, resultArr, maskRank == 0 );
859
859
860
860
// Store newly created output array to the reference passed in
861
- fir::SequenceType::Shape resultShape (1 , rank);
862
- mlir::Type outputArrTy = fir::SequenceType::get (resultShape, resultElemTy);
863
- mlir::Type outputHeapTy = fir::HeapType::get (outputArrTy);
864
- mlir::Type outputBoxTy = fir::BoxType::get (outputHeapTy);
865
- mlir::Type outputRefTy = builder.getRefType (outputBoxTy);
866
- mlir::Value outputArr = builder.create <fir::ConvertOp>(
867
- loc, outputRefTy, funcOp.front ().getArgument (0 ));
868
-
869
- // Store nearly created array to output array
870
- builder.create <fir::StoreOp>(loc, resultArr, outputArr);
861
+ if (isDim) {
862
+ mlir::Type resultBoxTy =
863
+ fir::BoxType::get (fir::HeapType::get (resultElemTy));
864
+ mlir::Value outputArr = builder.create <fir::ConvertOp>(
865
+ loc, builder.getRefType (resultBoxTy), funcOp.front ().getArgument (0 ));
866
+ mlir::Value resultArrScalar = builder.create <fir::ConvertOp>(
867
+ loc, fir::HeapType::get (resultElemTy), resultArrInit);
868
+ mlir::Value resultBox =
869
+ builder.create <fir::EmboxOp>(loc, resultBoxTy, resultArrScalar);
870
+ builder.create <fir::StoreOp>(loc, resultBox, outputArr);
871
+ } else {
872
+ fir::SequenceType::Shape resultShape (1 , rank);
873
+ mlir::Type outputArrTy = fir::SequenceType::get (resultShape, resultElemTy);
874
+ mlir::Type outputHeapTy = fir::HeapType::get (outputArrTy);
875
+ mlir::Type outputBoxTy = fir::BoxType::get (outputHeapTy);
876
+ mlir::Type outputRefTy = builder.getRefType (outputBoxTy);
877
+ mlir::Value outputArr = builder.create <fir::ConvertOp>(
878
+ loc, outputRefTy, funcOp.front ().getArgument (0 ));
879
+ builder.create <fir::StoreOp>(loc, resultArr, outputArr);
880
+ }
881
+
871
882
builder.create <mlir::func::ReturnOp>(loc);
872
883
}
873
884
@@ -1146,11 +1157,14 @@ void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
1146
1157
1147
1158
mlir::Operation::operand_range args = call.getArgs ();
1148
1159
1149
- mlir::Value back = args[6 ];
1160
+ mlir::SymbolRefAttr callee = call.getCalleeAttr ();
1161
+ mlir::StringRef funcNameBase = callee.getLeafReference ().getValue ();
1162
+ bool isDim = funcNameBase.ends_with (" Dim" );
1163
+ mlir::Value back = args[isDim ? 7 : 6 ];
1150
1164
if (isTrueOrNotConstant (back))
1151
1165
return ;
1152
1166
1153
- mlir::Value mask = args[5 ];
1167
+ mlir::Value mask = args[isDim ? 6 : 5 ];
1154
1168
mlir::Value maskDef = findMaskDef (mask);
1155
1169
1156
1170
// maskDef is set to NULL when the defining op is not one we accept.
@@ -1159,10 +1173,8 @@ void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
1159
1173
if (maskDef == NULL )
1160
1174
return ;
1161
1175
1162
- mlir::SymbolRefAttr callee = call.getCalleeAttr ();
1163
- mlir::StringRef funcNameBase = callee.getLeafReference ().getValue ();
1164
1176
unsigned rank = getDimCount (args[1 ]);
1165
- if (funcNameBase. ends_with ( " Dim " ) || !(rank > 0 ))
1177
+ if ((isDim && rank != 1 ) || !(rank > 0 ))
1166
1178
return ;
1167
1179
1168
1180
fir::FirOpBuilder builder{getSimplificationBuilder (call, kindMap)};
@@ -1203,22 +1215,24 @@ void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
1203
1215
1204
1216
llvm::raw_string_ostream nameOS (funcName);
1205
1217
outType.print (nameOS);
1218
+ if (isDim)
1219
+ nameOS << ' _' << inputType;
1206
1220
nameOS << ' _' << fmfString;
1207
1221
1208
1222
auto typeGenerator = [rank](fir::FirOpBuilder &builder) {
1209
1223
return genRuntimeMinlocType (builder, rank);
1210
1224
};
1211
1225
auto bodyGenerator = [rank, maskRank, inputType, logicalElemType, outType,
1212
- isMax](fir::FirOpBuilder &builder,
1213
- mlir::func::FuncOp &funcOp) {
1226
+ isMax, isDim ](fir::FirOpBuilder &builder,
1227
+ mlir::func::FuncOp &funcOp) {
1214
1228
genRuntimeMinMaxlocBody (builder, funcOp, isMax, rank, maskRank, inputType,
1215
- logicalElemType, outType);
1229
+ logicalElemType, outType, isDim );
1216
1230
};
1217
1231
1218
1232
mlir::func::FuncOp newFunc =
1219
1233
getOrCreateFunction (builder, funcName, typeGenerator, bodyGenerator);
1220
1234
builder.create <fir::CallOp>(loc, newFunc,
1221
- mlir::ValueRange{args[0 ], args[1 ], args[ 5 ] });
1235
+ mlir::ValueRange{args[0 ], args[1 ], mask });
1222
1236
call->dropAllReferences ();
1223
1237
call->erase ();
1224
1238
}
0 commit comments