-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[flang] Add reduction semantics to fir.do_loop #93934
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2456,9 +2456,16 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder, | |
mlir::OperationState &result, mlir::Value lb, | ||
mlir::Value ub, mlir::Value step, bool unordered, | ||
bool finalCountValue, mlir::ValueRange iterArgs, | ||
mlir::ValueRange reduceOperands, | ||
llvm::ArrayRef<mlir::Attribute> reduceAttrs, | ||
llvm::ArrayRef<mlir::NamedAttribute> attributes) { | ||
result.addOperands({lb, ub, step}); | ||
result.addOperands(reduceOperands); | ||
result.addOperands(iterArgs); | ||
result.addAttribute(getOperandSegmentSizeAttr(), | ||
builder.getDenseI32ArrayAttr( | ||
{1, 1, 1, static_cast<int32_t>(reduceOperands.size()), | ||
static_cast<int32_t>(iterArgs.size())})); | ||
if (finalCountValue) { | ||
result.addTypes(builder.getIndexType()); | ||
result.addAttribute(getFinalValueAttrName(result.name), | ||
|
@@ -2477,6 +2484,9 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder, | |
if (unordered) | ||
result.addAttribute(getUnorderedAttrName(result.name), | ||
builder.getUnitAttr()); | ||
if (!reduceAttrs.empty()) | ||
result.addAttribute(getReduceAttrsAttrName(result.name), | ||
builder.getArrayAttr(reduceAttrs)); | ||
result.addAttributes(attributes); | ||
} | ||
|
||
|
@@ -2502,24 +2512,51 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser, | |
if (mlir::succeeded(parser.parseOptionalKeyword("unordered"))) | ||
result.addAttribute("unordered", builder.getUnitAttr()); | ||
|
||
// Parse the reduction arguments. | ||
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands; | ||
llvm::SmallVector<mlir::Type> reduceArgTypes; | ||
if (succeeded(parser.parseOptionalKeyword("reduce"))) { | ||
// Parse reduction attributes and variables. | ||
llvm::SmallVector<ReduceAttr> attributes; | ||
if (failed(parser.parseCommaSeparatedList( | ||
mlir::AsmParser::Delimiter::Paren, [&]() { | ||
if (parser.parseAttribute(attributes.emplace_back()) || | ||
parser.parseArrow() || | ||
parser.parseOperand(reduceOperands.emplace_back()) || | ||
parser.parseColonType(reduceArgTypes.emplace_back())) | ||
return mlir::failure(); | ||
return mlir::success(); | ||
}))) | ||
return mlir::failure(); | ||
// Resolve input operands. | ||
for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes)) | ||
if (parser.resolveOperand(std::get<0>(operand_type), | ||
std::get<1>(operand_type), result.operands)) | ||
return mlir::failure(); | ||
llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), | ||
attributes.end()); | ||
result.addAttribute(getReduceAttrsAttrName(result.name), | ||
builder.getArrayAttr(arrayAttr)); | ||
} | ||
|
||
// Parse the optional initial iteration arguments. | ||
llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs; | ||
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands; | ||
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands; | ||
llvm::SmallVector<mlir::Type> argTypes; | ||
bool prependCount = false; | ||
regionArgs.push_back(inductionVariable); | ||
|
||
if (succeeded(parser.parseOptionalKeyword("iter_args"))) { | ||
// Parse assignment list and results type list. | ||
if (parser.parseAssignmentList(regionArgs, operands) || | ||
if (parser.parseAssignmentList(regionArgs, iterOperands) || | ||
parser.parseArrowTypeList(result.types)) | ||
return mlir::failure(); | ||
if (result.types.size() == operands.size() + 1) | ||
if (result.types.size() == iterOperands.size() + 1) | ||
prependCount = true; | ||
// Resolve input operands. | ||
llvm::ArrayRef<mlir::Type> resTypes = result.types; | ||
for (auto operand_type : | ||
llvm::zip(operands, prependCount ? resTypes.drop_front() : resTypes)) | ||
for (auto operand_type : llvm::zip( | ||
iterOperands, prependCount ? resTypes.drop_front() : resTypes)) | ||
if (parser.resolveOperand(std::get<0>(operand_type), | ||
std::get<1>(operand_type), result.operands)) | ||
return mlir::failure(); | ||
|
@@ -2530,6 +2567,12 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser, | |
prependCount = true; | ||
} | ||
|
||
// Set the operandSegmentSizes attribute | ||
result.addAttribute(getOperandSegmentSizeAttr(), | ||
builder.getDenseI32ArrayAttr( | ||
{1, 1, 1, static_cast<int32_t>(reduceOperands.size()), | ||
static_cast<int32_t>(iterOperands.size())})); | ||
|
||
if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) | ||
return mlir::failure(); | ||
|
||
|
@@ -2606,6 +2649,10 @@ mlir::LogicalResult fir::DoLoopOp::verify() { | |
|
||
i++; | ||
} | ||
auto reduceAttrs = getReduceAttrsAttr(); | ||
if (getNumReduceOperands() != (reduceAttrs ? reduceAttrs.size() : 0)) | ||
return emitOpError( | ||
"mismatch in number of reduction variables and reduction attributes"); | ||
return mlir::success(); | ||
} | ||
|
||
|
@@ -2615,6 +2662,17 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) { | |
<< getUpperBound() << " step " << getStep(); | ||
if (getUnordered()) | ||
p << " unordered"; | ||
if (hasReduceOperands()) { | ||
p << " reduce("; | ||
auto attrs = getReduceAttrsAttr(); | ||
auto operands = getReduceOperands(); | ||
llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) { | ||
p << std::get<0>(it) << " -> " << std::get<1>(it) << " : " | ||
<< std::get<1>(it).getType(); | ||
}); | ||
p << ')'; | ||
printBlockTerminators = true; | ||
} | ||
Comment on lines
+2665
to
+2675
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Switch position with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah that's a good point. Let's keep it like this. |
||
if (hasIterOperands()) { | ||
p << " iter_args("; | ||
auto regionArgs = getRegionIterArgs(); | ||
|
@@ -2628,8 +2686,9 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) { | |
p << " -> " << getResultTypes(); | ||
printBlockTerminators = true; | ||
} | ||
p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), | ||
{"unordered", "finalValue"}); | ||
p.printOptionalAttrDictWithKeyword( | ||
(*this)->getAttrs(), | ||
{"unordered", "finalValue", "reduceAttrs", "operandSegmentSizes"}); | ||
p << ' '; | ||
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, | ||
printBlockTerminators); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
// Test the reduction semantics of fir.do_loop | ||
// RUN: fir-opt %s | FileCheck %s | ||
|
||
func.func @reduction() { | ||
%bound = arith.constant 10 : index | ||
%step = arith.constant 1 : index | ||
%sum = fir.alloca i32 | ||
// CHECK: %[[VAL_0:.*]] = fir.alloca i32 | ||
// CHECK: fir.do_loop %[[VAL_1:.*]] = %[[VAL_2:.*]] to %[[VAL_3:.*]] step %[[VAL_4:.*]] unordered reduce(#fir.reduce_attr<add> -> %[[VAL_0]] : !fir.ref<i32>) { | ||
fir.do_loop %iv = %step to %bound step %step unordered reduce(#fir.reduce_attr<add> -> %sum : !fir.ref<i32>) { | ||
%index = fir.convert %iv : (index) -> i32 | ||
%1 = fir.load %sum : !fir.ref<i32> | ||
%2 = arith.addi %index, %1 : i32 | ||
fir.store %2 to %sum : !fir.ref<i32> | ||
} | ||
return | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
initArgs is more related to the control operands so maybe it's better to keep them together.