Skip to content

Commit 2318bc8

Browse files
authored
[flang][hlfir] Add hlfir.maxval intrinsic (#65705)
Adds a new HLFIR operation for the MAXVAL intrinsic according to the design set out in flang/docs/HighLevelFIR.md.
1 parent c8387a3 commit 2318bc8

File tree

9 files changed

+966
-6
lines changed

9 files changed

+966
-6
lines changed

flang/include/flang/Optimizer/HLFIR/HLFIROps.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,33 @@ def hlfir_CountOp : hlfir_Op<"count", [AttrSizedOperandSegments, DeclareOpInterf
404404
let hasVerifier = 1;
405405
}
406406

407+
def hlfir_MaxvalOp : hlfir_Op<"maxval", [AttrSizedOperandSegments,
408+
DeclareOpInterfaceMethods<ArithFastMathInterface>,
409+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
410+
let summary = "MAXVAL transformational intrinsic";
411+
let description = [{
412+
Maximum value(s) of an array.
413+
If DIM is absent, the result is a scalar.
414+
If DIM is present, the result is an array of rank n-1, where n is the rank of ARRAY.
415+
}];
416+
417+
let arguments = (ins
418+
AnyFortranArrayObject:$array,
419+
Optional<AnyIntegerType>:$dim,
420+
Optional<AnyFortranLogicalOrI1ArrayObject>:$mask,
421+
DefaultValuedAttr<Arith_FastMathAttr,
422+
"::mlir::arith::FastMathFlags::none">:$fastmath
423+
);
424+
425+
let results = (outs AnyFortranValue);
426+
427+
let assemblyFormat = [{
428+
$array (`dim` $dim^)? (`mask` $mask^)? attr-dict `:` functional-type(operands, results)
429+
}];
430+
431+
let hasVerifier = 1;
432+
}
433+
407434
def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
408435
DeclareOpInterfaceMethods<ArithFastMathInterface>,
409436
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {

flang/lib/Lower/HlfirIntrinsics.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ class HlfirReductionIntrinsic : public HlfirTransformationalIntrinsic {
8888
};
8989
using HlfirSumLowering = HlfirReductionIntrinsic<hlfir::SumOp, true>;
9090
using HlfirProductLowering = HlfirReductionIntrinsic<hlfir::ProductOp, true>;
91+
using HlfirMaxvalLowering = HlfirReductionIntrinsic<hlfir::MaxvalOp, true>;
9192
using HlfirAnyLowering = HlfirReductionIntrinsic<hlfir::AnyOp, false>;
9293
using HlfirAllLowering = HlfirReductionIntrinsic<hlfir::AllOp, false>;
9394

@@ -227,6 +228,10 @@ HlfirTransformationalIntrinsic::computeResultType(mlir::Value argArray,
227228
mlir::Type elementType = array.getEleTy();
228229
return hlfir::ExprType::get(builder.getContext(), resultShape, elementType,
229230
/*polymorphic=*/false);
231+
} else if (auto resCharType =
232+
mlir::dyn_cast<fir::CharacterType>(stmtResultType)) {
233+
normalisedResult = hlfir::ExprType::get(
234+
builder.getContext(), hlfir::ExprType::Shape{}, resCharType, false);
230235
}
231236
return normalisedResult;
232237
}
@@ -348,6 +353,9 @@ std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic(
348353
if (name == "count")
349354
return HlfirCountLowering{builder, loc}.lower(loweredActuals, argLowering,
350355
stmtResultType);
356+
if (name == "maxval")
357+
return HlfirMaxvalLowering{builder, loc}.lower(loweredActuals, argLowering,
358+
stmtResultType);
351359
if (mlir::isa<fir::CharacterType>(stmtResultType)) {
352360
if (name == "min")
353361
return HlfirCharExtremumLowering{builder, loc,

flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,102 @@ void hlfir::ProductOp::getEffects(
749749
getIntrinsicEffects(getOperation(), effects);
750750
}
751751

752+
//===----------------------------------------------------------------------===//
753+
// CharacterReductionOp
754+
//===----------------------------------------------------------------------===//
755+
756+
template <typename CharacterReductionOp>
757+
static mlir::LogicalResult
758+
verifyCharacterReductionOp(CharacterReductionOp reductionOp) {
759+
mlir::Operation *op = reductionOp->getOperation();
760+
761+
auto results = op->getResultTypes();
762+
assert(results.size() == 1);
763+
764+
mlir::Value array = reductionOp->getArray();
765+
mlir::Value dim = reductionOp->getDim();
766+
mlir::Value mask = reductionOp->getMask();
767+
768+
fir::SequenceType arrayTy =
769+
hlfir::getFortranElementOrSequenceType(array.getType())
770+
.cast<fir::SequenceType>();
771+
mlir::Type numTy = arrayTy.getEleTy();
772+
llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();
773+
774+
if (mask) {
775+
fir::SequenceType maskSeq =
776+
hlfir::getFortranElementOrSequenceType(mask.getType())
777+
.dyn_cast<fir::SequenceType>();
778+
llvm::ArrayRef<int64_t> maskShape;
779+
780+
if (maskSeq)
781+
maskShape = maskSeq.getShape();
782+
783+
if (!maskShape.empty()) {
784+
if (maskShape.size() != arrayShape.size())
785+
return reductionOp->emitWarning("MASK must be conformable to ARRAY");
786+
static_assert(fir::SequenceType::getUnknownExtent() ==
787+
hlfir::ExprType::getUnknownExtent());
788+
constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
789+
for (std::size_t i = 0; i < arrayShape.size(); ++i) {
790+
int64_t arrayExtent = arrayShape[i];
791+
int64_t maskExtent = maskShape[i];
792+
if ((arrayExtent != maskExtent) && (arrayExtent != unknownExtent) &&
793+
(maskExtent != unknownExtent))
794+
return reductionOp->emitWarning("MASK must be conformable to ARRAY");
795+
}
796+
}
797+
}
798+
799+
auto resultExpr = results[0].cast<hlfir::ExprType>();
800+
mlir::Type resultType = resultExpr.getEleTy();
801+
assert(mlir::isa<fir::CharacterType>(resultType) &&
802+
"result must be character");
803+
804+
// Result is of the same type as ARRAY
805+
if (resultType != numTy)
806+
return reductionOp->emitOpError(
807+
"result must have the same element type as ARRAY argument");
808+
809+
if (arrayShape.size() > 1 && dim != nullptr) {
810+
if (!resultExpr.isArray())
811+
return reductionOp->emitOpError("result must be an array");
812+
llvm::ArrayRef<int64_t> resultShape = resultExpr.getShape();
813+
// Result has rank n-1
814+
if (resultShape.size() != (arrayShape.size() - 1))
815+
return reductionOp->emitOpError(
816+
"result rank must be one less than ARRAY");
817+
} else if (!resultExpr.isScalar()) {
818+
return reductionOp->emitOpError("result must be scalar character");
819+
}
820+
return mlir::success();
821+
}
822+
823+
//===----------------------------------------------------------------------===//
824+
// MaxvalOp
825+
//===----------------------------------------------------------------------===//
826+
827+
mlir::LogicalResult hlfir::MaxvalOp::verify() {
828+
mlir::Operation *op = getOperation();
829+
830+
auto results = op->getResultTypes();
831+
assert(results.size() == 1);
832+
833+
auto resultExpr = mlir::dyn_cast<hlfir::ExprType>(results[0]);
834+
if (resultExpr && mlir::isa<fir::CharacterType>(resultExpr.getEleTy())) {
835+
return verifyCharacterReductionOp<hlfir::MaxvalOp *>(this);
836+
} else {
837+
return verifyNumericalReductionOp<hlfir::MaxvalOp *>(this);
838+
}
839+
}
840+
841+
void hlfir::MaxvalOp::getEffects(
842+
llvm::SmallVectorImpl<
843+
mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
844+
&effects) {
845+
getIntrinsicEffects(getOperation(), effects);
846+
}
847+
752848
//===----------------------------------------------------------------------===//
753849
// SetLengthOp
754850
//===----------------------------------------------------------------------===//

flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,8 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
220220
opName = "sum";
221221
} else if constexpr (std::is_same_v<OP, hlfir::ProductOp>) {
222222
opName = "product";
223+
} else if constexpr (std::is_same_v<OP, hlfir::MaxvalOp>) {
224+
opName = "maxval";
223225
} else if constexpr (std::is_same_v<OP, hlfir::AnyOp>) {
224226
opName = "any";
225227
} else if constexpr (std::is_same_v<OP, hlfir::AllOp>) {
@@ -238,7 +240,8 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
238240
llvm::SmallVector<fir::ExtendedValue, 0> args;
239241

240242
if constexpr (std::is_same_v<OP, hlfir::SumOp> ||
241-
std::is_same_v<OP, hlfir::ProductOp>) {
243+
std::is_same_v<OP, hlfir::ProductOp> ||
244+
std::is_same_v<OP, hlfir::MaxvalOp>) {
242245
args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName);
243246
} else {
244247
args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName);
@@ -259,6 +262,8 @@ using SumOpConversion = HlfirReductionIntrinsicConversion<hlfir::SumOp>;
259262

260263
using ProductOpConversion = HlfirReductionIntrinsicConversion<hlfir::ProductOp>;
261264

265+
using MaxvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxvalOp>;
266+
262267
using AnyOpConversion = HlfirReductionIntrinsicConversion<hlfir::AnyOp>;
263268

264269
using AllOpConversion = HlfirReductionIntrinsicConversion<hlfir::AllOp>;
@@ -431,17 +436,19 @@ class LowerHLFIRIntrinsics
431436
mlir::ModuleOp module = this->getOperation();
432437
mlir::MLIRContext *context = &getContext();
433438
mlir::RewritePatternSet patterns(context);
434-
patterns.insert<MatmulOpConversion, MatmulTransposeOpConversion,
435-
AllOpConversion, AnyOpConversion, SumOpConversion,
436-
ProductOpConversion, TransposeOpConversion,
437-
CountOpConversion, DotProductOpConversion>(context);
439+
patterns
440+
.insert<MatmulOpConversion, MatmulTransposeOpConversion,
441+
AllOpConversion, AnyOpConversion, SumOpConversion,
442+
ProductOpConversion, TransposeOpConversion, CountOpConversion,
443+
DotProductOpConversion, MaxvalOpConversion>(context);
438444
mlir::ConversionTarget target(*context);
439445
target.addLegalDialect<mlir::BuiltinDialect, mlir::arith::ArithDialect,
440446
mlir::func::FuncDialect, fir::FIROpsDialect,
441447
hlfir::hlfirDialect>();
442448
target.addIllegalOp<hlfir::MatmulOp, hlfir::MatmulTransposeOp, hlfir::SumOp,
443449
hlfir::ProductOp, hlfir::TransposeOp, hlfir::AnyOp,
444-
hlfir::AllOp, hlfir::DotProductOp, hlfir::CountOp>();
450+
hlfir::AllOp, hlfir::DotProductOp, hlfir::CountOp,
451+
hlfir::MaxvalOp>();
445452
target.markUnknownOpDynamicallyLegal(
446453
[](mlir::Operation *) { return true; });
447454
if (mlir::failed(

flang/test/HLFIR/invalid.fir

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,84 @@ func.func @bad_count4(%arg0: !hlfir.expr<?x!fir.logical<4>>, %arg1: i32) {
392392
%0 = hlfir.count %arg0 dim %arg1 : (!hlfir.expr<?x!fir.logical<4>>, i32) -> !fir.logical<4>
393393
}
394394

395+
// -----
396+
func.func @bad_maxval1(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
397+
// expected-error@+1 {{'hlfir.maxval' op result must have the same element type as ARRAY argument}}
398+
%0 = hlfir.maxval %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?xi32>, i32, !fir.box<!fir.logical<4>>) -> f32
399+
}
400+
401+
// -----
402+
func.func @bad_maxval2(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.array<?x?x?x?x?x!fir.logical<4>>>) {
403+
// expected-warning@+1 {{MASK must be conformable to ARRAY}}
404+
%0 = hlfir.maxval %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?xi32>, i32, !fir.box<!fir.array<?x?x?x?x?x!fir.logical<4>>>) -> !hlfir.expr<i32>
405+
}
406+
407+
// -----
408+
func.func @bad_maxval3(%arg0: !hlfir.expr<?x5x?xi32>, %arg1: i32, %arg2: !fir.box<!fir.array<2x6x?x!fir.logical<4>>>) {
409+
// expected-warning@+1 {{MASK must be conformable to ARRAY}}
410+
%0 = hlfir.maxval %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?x5x?xi32>, i32, !fir.box<!fir.array<2x6x?x!fir.logical<4>>>) -> !hlfir.expr<i32>
411+
}
412+
413+
// -----
414+
func.func @bad_maxval4(%arg0: !hlfir.expr<?x?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
415+
// expected-error@+1 {{'hlfir.maxval' op result rank must be one less than ARRAY}}
416+
%0 = hlfir.maxval %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?x?xi32>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<?x?xi32>
417+
}
418+
419+
// -----
420+
func.func @bad_maxval5(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
421+
// expected-error@+1 {{'hlfir.maxval' op result must be of numerical scalar type}}
422+
%0 = hlfir.maxval %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?xi32>, i32, !fir.box<!fir.logical<4>>) -> !fir.logical<4>
423+
}
424+
425+
// -----
426+
func.func @bad_maxval6(%arg0: !hlfir.expr<?x?xi32>, %arg1: i32){
427+
// expected-error@+1 {{'hlfir.maxval' op result must be an array}}
428+
%0 = hlfir.maxval %arg0 dim %arg1 : (!hlfir.expr<?x?xi32>, i32) -> !hlfir.expr<i32>
429+
}
430+
431+
// -----
432+
func.func @bad_maxval7(%arg0: !hlfir.expr<?xi32>){
433+
// expected-error@+1 {{'hlfir.maxval' op result must be of numerical scalar type}}
434+
%0 = hlfir.maxval %arg0 : (!hlfir.expr<?xi32>) -> !hlfir.expr<i32>
435+
}
436+
437+
// -----
438+
func.func @bad_maxval8(%arg0: !hlfir.expr<?x!fir.char<1,?>>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
439+
// expected-error@+1 {{'hlfir.maxval' op result must have the same element type as ARRAY argument}}
440+
%0 = hlfir.maxval %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?x!fir.char<1,?>>, i32, !fir.box<!fir.logical<4>>) -> i32
441+
}
442+
443+
// -----
444+
func.func @bad_maxval9(%arg0: !hlfir.expr<?x!fir.char<1,?>>, %arg1: i32, %arg2: !fir.box<!fir.array<?x?x?x?x?x!fir.logical<4>>>) {
445+
// expected-warning@+1 {{MASK must be conformable to ARRAY}}
446+
%0 = hlfir.maxval %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?x!fir.char<1,?>>, i32, !fir.box<!fir.array<?x?x?x?x?x!fir.logical<4>>>) -> !hlfir.expr<!fir.char<1,?>>
447+
}
448+
449+
// -----
450+
func.func @bad_maxval10(%arg0: !hlfir.expr<?x5x?x!fir.char<1,?>>, %arg1: i32, %arg2: !fir.box<!fir.array<2x6x?x!fir.logical<4>>>) {
451+
// expected-warning@+1 {{MASK must be conformable to ARRAY}}
452+
%0 = hlfir.maxval %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?x5x?x!fir.char<1,?>>, i32, !fir.box<!fir.array<2x6x?x!fir.logical<4>>>) -> !hlfir.expr<!fir.char<1,?>>
453+
}
454+
455+
// -----
456+
func.func @bad_maxval11(%arg0: !hlfir.expr<?x?x!fir.char<1,?>>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
457+
// expected-error@+1 {{'hlfir.maxval' op result rank must be one less than ARRAY}}
458+
%0 = hlfir.maxval %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?x?x!fir.char<1,?>>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<?x?x!fir.char<1,?>>
459+
}
460+
461+
// -----
462+
func.func @bad_maxval12(%arg0: !hlfir.expr<?x!fir.char<1,?>>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
463+
// expected-error@+1 {{'hlfir.maxval' op result must be scalar character}}
464+
%0 = hlfir.maxval %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?x!fir.char<1,?>>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<?x!fir.char<1,?>>
465+
}
466+
467+
// -----
468+
func.func @bad_maxval13(%arg0: !hlfir.expr<?x?x!fir.char<1,?>>, %arg1: i32){
469+
// expected-error@+1 {{'hlfir.maxval' op result must be an array}}
470+
%0 = hlfir.maxval %arg0 dim %arg1 : (!hlfir.expr<?x?x!fir.char<1,?>>, i32) -> !hlfir.expr<!fir.char<1,?>>
471+
}
472+
395473
// -----
396474
func.func @bad_product1(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
397475
// expected-error@+1 {{'hlfir.product' op result must have the same element type as ARRAY argument}}

0 commit comments

Comments
 (0)