Skip to content

[flang][hlfir] Add hlfir.maxval intrinsic #65705

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/HLFIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,33 @@ def hlfir_CountOp : hlfir_Op<"count", [AttrSizedOperandSegments, DeclareOpInterf
let hasVerifier = 1;
}

def hlfir_MaxvalOp : hlfir_Op<"maxval", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "MAXVAL transformational intrinsic";
let description = [{
Maximum value(s) of an array.
If DIM is absent, the result is a scalar.
If DIM is present, the result is an array of rank n-1, where n is the rank of ARRAY.
}];

let arguments = (ins
AnyFortranArrayObject:$array,
Optional<AnyIntegerType>:$dim,
Optional<AnyFortranLogicalOrI1ArrayObject>:$mask,
DefaultValuedAttr<Arith_FastMathAttr,
"::mlir::arith::FastMathFlags::none">:$fastmath
);

let results = (outs AnyFortranValue);

let assemblyFormat = [{
$array (`dim` $dim^)? (`mask` $mask^)? attr-dict `:` functional-type(operands, results)
}];

let hasVerifier = 1;
}

def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
Expand Down
8 changes: 8 additions & 0 deletions flang/lib/Lower/HlfirIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class HlfirReductionIntrinsic : public HlfirTransformationalIntrinsic {
};
using HlfirSumLowering = HlfirReductionIntrinsic<hlfir::SumOp, true>;
using HlfirProductLowering = HlfirReductionIntrinsic<hlfir::ProductOp, true>;
using HlfirMaxvalLowering = HlfirReductionIntrinsic<hlfir::MaxvalOp, true>;
using HlfirAnyLowering = HlfirReductionIntrinsic<hlfir::AnyOp, false>;
using HlfirAllLowering = HlfirReductionIntrinsic<hlfir::AllOp, false>;

Expand Down Expand Up @@ -227,6 +228,10 @@ HlfirTransformationalIntrinsic::computeResultType(mlir::Value argArray,
mlir::Type elementType = array.getEleTy();
return hlfir::ExprType::get(builder.getContext(), resultShape, elementType,
/*polymorphic=*/false);
} else if (auto resCharType =
mlir::dyn_cast<fir::CharacterType>(stmtResultType)) {
normalisedResult = hlfir::ExprType::get(
builder.getContext(), hlfir::ExprType::Shape{}, resCharType, false);
}
return normalisedResult;
}
Expand Down Expand Up @@ -348,6 +353,9 @@ std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic(
if (name == "count")
return HlfirCountLowering{builder, loc}.lower(loweredActuals, argLowering,
stmtResultType);
if (name == "maxval")
return HlfirMaxvalLowering{builder, loc}.lower(loweredActuals, argLowering,
stmtResultType);
if (mlir::isa<fir::CharacterType>(stmtResultType)) {
if (name == "min")
return HlfirCharExtremumLowering{builder, loc,
Expand Down
96 changes: 96 additions & 0 deletions flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,102 @@ void hlfir::ProductOp::getEffects(
getIntrinsicEffects(getOperation(), effects);
}

//===----------------------------------------------------------------------===//
// CharacterReductionOp
//===----------------------------------------------------------------------===//

template <typename CharacterReductionOp>
static mlir::LogicalResult
verifyCharacterReductionOp(CharacterReductionOp reductionOp) {
mlir::Operation *op = reductionOp->getOperation();

auto results = op->getResultTypes();
assert(results.size() == 1);

mlir::Value array = reductionOp->getArray();
mlir::Value dim = reductionOp->getDim();
mlir::Value mask = reductionOp->getMask();

fir::SequenceType arrayTy =
hlfir::getFortranElementOrSequenceType(array.getType())
.cast<fir::SequenceType>();
mlir::Type numTy = arrayTy.getEleTy();
llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();

if (mask) {
fir::SequenceType maskSeq =
hlfir::getFortranElementOrSequenceType(mask.getType())
.dyn_cast<fir::SequenceType>();
llvm::ArrayRef<int64_t> maskShape;

if (maskSeq)
maskShape = maskSeq.getShape();

if (!maskShape.empty()) {
if (maskShape.size() != arrayShape.size())
return reductionOp->emitWarning("MASK must be conformable to ARRAY");
static_assert(fir::SequenceType::getUnknownExtent() ==
hlfir::ExprType::getUnknownExtent());
constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
for (std::size_t i = 0; i < arrayShape.size(); ++i) {
int64_t arrayExtent = arrayShape[i];
int64_t maskExtent = maskShape[i];
if ((arrayExtent != maskExtent) && (arrayExtent != unknownExtent) &&
(maskExtent != unknownExtent))
return reductionOp->emitWarning("MASK must be conformable to ARRAY");
}
}
}

auto resultExpr = results[0].cast<hlfir::ExprType>();
mlir::Type resultType = resultExpr.getEleTy();
assert(mlir::isa<fir::CharacterType>(resultType) &&
"result must be character");

// Result is of the same type as ARRAY
if (resultType != numTy)
return reductionOp->emitOpError(
"result must have the same element type as ARRAY argument");

if (arrayShape.size() > 1 && dim != nullptr) {
if (!resultExpr.isArray())
return reductionOp->emitOpError("result must be an array");
llvm::ArrayRef<int64_t> resultShape = resultExpr.getShape();
// Result has rank n-1
if (resultShape.size() != (arrayShape.size() - 1))
return reductionOp->emitOpError(
"result rank must be one less than ARRAY");
} else if (!resultExpr.isScalar()) {
return reductionOp->emitOpError("result must be scalar character");
}
return mlir::success();
}

//===----------------------------------------------------------------------===//
// MaxvalOp
//===----------------------------------------------------------------------===//

mlir::LogicalResult hlfir::MaxvalOp::verify() {
mlir::Operation *op = getOperation();

auto results = op->getResultTypes();
assert(results.size() == 1);

auto resultExpr = mlir::dyn_cast<hlfir::ExprType>(results[0]);
if (resultExpr && mlir::isa<fir::CharacterType>(resultExpr.getEleTy())) {
return verifyCharacterReductionOp<hlfir::MaxvalOp *>(this);
} else {
return verifyNumericalReductionOp<hlfir::MaxvalOp *>(this);
}
}

void hlfir::MaxvalOp::getEffects(
llvm::SmallVectorImpl<
mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
&effects) {
getIntrinsicEffects(getOperation(), effects);
}

//===----------------------------------------------------------------------===//
// SetLengthOp
//===----------------------------------------------------------------------===//
Expand Down
19 changes: 13 additions & 6 deletions flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
opName = "sum";
} else if constexpr (std::is_same_v<OP, hlfir::ProductOp>) {
opName = "product";
} else if constexpr (std::is_same_v<OP, hlfir::MaxvalOp>) {
opName = "maxval";
} else if constexpr (std::is_same_v<OP, hlfir::AnyOp>) {
opName = "any";
} else if constexpr (std::is_same_v<OP, hlfir::AllOp>) {
Expand All @@ -238,7 +240,8 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
llvm::SmallVector<fir::ExtendedValue, 0> args;

if constexpr (std::is_same_v<OP, hlfir::SumOp> ||
std::is_same_v<OP, hlfir::ProductOp>) {
std::is_same_v<OP, hlfir::ProductOp> ||
std::is_same_v<OP, hlfir::MaxvalOp>) {
args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName);
} else {
args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName);
Expand All @@ -259,6 +262,8 @@ using SumOpConversion = HlfirReductionIntrinsicConversion<hlfir::SumOp>;

using ProductOpConversion = HlfirReductionIntrinsicConversion<hlfir::ProductOp>;

using MaxvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxvalOp>;

using AnyOpConversion = HlfirReductionIntrinsicConversion<hlfir::AnyOp>;

using AllOpConversion = HlfirReductionIntrinsicConversion<hlfir::AllOp>;
Expand Down Expand Up @@ -431,17 +436,19 @@ class LowerHLFIRIntrinsics
mlir::ModuleOp module = this->getOperation();
mlir::MLIRContext *context = &getContext();
mlir::RewritePatternSet patterns(context);
patterns.insert<MatmulOpConversion, MatmulTransposeOpConversion,
AllOpConversion, AnyOpConversion, SumOpConversion,
ProductOpConversion, TransposeOpConversion,
CountOpConversion, DotProductOpConversion>(context);
patterns
.insert<MatmulOpConversion, MatmulTransposeOpConversion,
AllOpConversion, AnyOpConversion, SumOpConversion,
ProductOpConversion, TransposeOpConversion, CountOpConversion,
DotProductOpConversion, MaxvalOpConversion>(context);
mlir::ConversionTarget target(*context);
target.addLegalDialect<mlir::BuiltinDialect, mlir::arith::ArithDialect,
mlir::func::FuncDialect, fir::FIROpsDialect,
hlfir::hlfirDialect>();
target.addIllegalOp<hlfir::MatmulOp, hlfir::MatmulTransposeOp, hlfir::SumOp,
hlfir::ProductOp, hlfir::TransposeOp, hlfir::AnyOp,
hlfir::AllOp, hlfir::DotProductOp, hlfir::CountOp>();
hlfir::AllOp, hlfir::DotProductOp, hlfir::CountOp,
hlfir::MaxvalOp>();
target.markUnknownOpDynamicallyLegal(
[](mlir::Operation *) { return true; });
if (mlir::failed(
Expand Down
78 changes: 78 additions & 0 deletions flang/test/HLFIR/invalid.fir
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,84 @@ func.func @bad_count4(%arg0: !hlfir.expr<?x!fir.logical<4>>, %arg1: i32) {
%0 = hlfir.count %arg0 dim %arg1 : (!hlfir.expr<?x!fir.logical<4>>, i32) -> !fir.logical<4>
}

// -----
func.func @bad_maxval1(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
// expected-error@+1 {{'hlfir.maxval' op result must have the same element type as ARRAY argument}}
%0 = hlfir.maxval %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?xi32>, i32, !fir.box<!fir.logical<4>>) -> f32
}

// -----
func.func @bad_maxval2(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.array<?x?x?x?x?x!fir.logical<4>>>) {
// expected-warning@+1 {{MASK must be conformable to ARRAY}}
%0 = hlfir.maxval %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?xi32>, i32, !fir.box<!fir.array<?x?x?x?x?x!fir.logical<4>>>) -> !hlfir.expr<i32>
}

// -----
func.func @bad_maxval3(%arg0: !hlfir.expr<?x5x?xi32>, %arg1: i32, %arg2: !fir.box<!fir.array<2x6x?x!fir.logical<4>>>) {
// expected-warning@+1 {{MASK must be conformable to ARRAY}}
%0 = hlfir.maxval %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?x5x?xi32>, i32, !fir.box<!fir.array<2x6x?x!fir.logical<4>>>) -> !hlfir.expr<i32>
}

// -----
func.func @bad_maxval4(%arg0: !hlfir.expr<?x?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
// expected-error@+1 {{'hlfir.maxval' op result rank must be one less than ARRAY}}
%0 = hlfir.maxval %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?x?xi32>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<?x?xi32>
}

// -----
func.func @bad_maxval5(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
// expected-error@+1 {{'hlfir.maxval' op result must be of numerical scalar type}}
%0 = hlfir.maxval %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?xi32>, i32, !fir.box<!fir.logical<4>>) -> !fir.logical<4>
}

// -----
func.func @bad_maxval6(%arg0: !hlfir.expr<?x?xi32>, %arg1: i32){
// expected-error@+1 {{'hlfir.maxval' op result must be an array}}
%0 = hlfir.maxval %arg0 dim %arg1 : (!hlfir.expr<?x?xi32>, i32) -> !hlfir.expr<i32>
}

// -----
func.func @bad_maxval7(%arg0: !hlfir.expr<?xi32>){
// expected-error@+1 {{'hlfir.maxval' op result must be of numerical scalar type}}
%0 = hlfir.maxval %arg0 : (!hlfir.expr<?xi32>) -> !hlfir.expr<i32>
}

// -----
func.func @bad_maxval8(%arg0: !hlfir.expr<?x!fir.char<1,?>>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
// expected-error@+1 {{'hlfir.maxval' op result must have the same element type as ARRAY argument}}
%0 = hlfir.maxval %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?x!fir.char<1,?>>, i32, !fir.box<!fir.logical<4>>) -> i32
}

// -----
func.func @bad_maxval9(%arg0: !hlfir.expr<?x!fir.char<1,?>>, %arg1: i32, %arg2: !fir.box<!fir.array<?x?x?x?x?x!fir.logical<4>>>) {
// expected-warning@+1 {{MASK must be conformable to ARRAY}}
%0 = hlfir.maxval %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?x!fir.char<1,?>>, i32, !fir.box<!fir.array<?x?x?x?x?x!fir.logical<4>>>) -> !hlfir.expr<!fir.char<1,?>>
}

// -----
func.func @bad_maxval10(%arg0: !hlfir.expr<?x5x?x!fir.char<1,?>>, %arg1: i32, %arg2: !fir.box<!fir.array<2x6x?x!fir.logical<4>>>) {
// expected-warning@+1 {{MASK must be conformable to ARRAY}}
%0 = hlfir.maxval %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?x5x?x!fir.char<1,?>>, i32, !fir.box<!fir.array<2x6x?x!fir.logical<4>>>) -> !hlfir.expr<!fir.char<1,?>>
}

// -----
func.func @bad_maxval11(%arg0: !hlfir.expr<?x?x!fir.char<1,?>>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
// expected-error@+1 {{'hlfir.maxval' op result rank must be one less than ARRAY}}
%0 = hlfir.maxval %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?x?x!fir.char<1,?>>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<?x?x!fir.char<1,?>>
}

// -----
func.func @bad_maxval12(%arg0: !hlfir.expr<?x!fir.char<1,?>>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
// expected-error@+1 {{'hlfir.maxval' op result must be scalar character}}
%0 = hlfir.maxval %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?x!fir.char<1,?>>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<?x!fir.char<1,?>>
}

// -----
func.func @bad_maxval13(%arg0: !hlfir.expr<?x?x!fir.char<1,?>>, %arg1: i32){
// expected-error@+1 {{'hlfir.maxval' op result must be an array}}
%0 = hlfir.maxval %arg0 dim %arg1 : (!hlfir.expr<?x?x!fir.char<1,?>>, i32) -> !hlfir.expr<!fir.char<1,?>>
}

// -----
func.func @bad_product1(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
// expected-error@+1 {{'hlfir.product' op result must have the same element type as ARRAY argument}}
Expand Down
Loading