diff --git a/flang/include/flang/Optimizer/Dialect/FIRAttr.td b/flang/include/flang/Optimizer/Dialect/FIRAttr.td index 0c34b640a5c9c..aedb6769186e9 100644 --- a/flang/include/flang/Optimizer/Dialect/FIRAttr.td +++ b/flang/include/flang/Optimizer/Dialect/FIRAttr.td @@ -67,6 +67,36 @@ def fir_BoxFieldAttr : I32EnumAttr< let cppNamespace = "fir"; } +def fir_ReduceOperationEnum : I32BitEnumAttr<"ReduceOperationEnum", + "intrinsic operations and functions supported by DO CONCURRENT REDUCE", + [ + I32BitEnumAttrCaseBit<"Add", 0, "add">, + I32BitEnumAttrCaseBit<"Multiply", 1, "multiply">, + I32BitEnumAttrCaseBit<"AND", 2, "and">, + I32BitEnumAttrCaseBit<"OR", 3, "or">, + I32BitEnumAttrCaseBit<"EQV", 4, "eqv">, + I32BitEnumAttrCaseBit<"NEQV", 5, "neqv">, + I32BitEnumAttrCaseBit<"MAX", 6, "max">, + I32BitEnumAttrCaseBit<"MIN", 7, "min">, + I32BitEnumAttrCaseBit<"IAND", 8, "iand">, + I32BitEnumAttrCaseBit<"IOR", 9, "ior">, + I32BitEnumAttrCaseBit<"EIOR", 10, "eior"> + ]> { + let separator = ", "; + let cppNamespace = "::fir"; + let printBitEnumPrimaryGroups = 1; +} + +def fir_ReduceAttr : fir_Attr<"Reduce"> { + let mnemonic = "reduce_attr"; + + let parameters = (ins + "ReduceOperationEnum":$reduce_operation + ); + + let assemblyFormat = "`<` $reduce_operation `>`"; +} + // mlir::SideEffects::Resource for modelling operations which add debugging information def DebuggingResource : Resource<"::fir::DebuggingResource">; diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index 37fbd1f9692a4..e7da3af5485cc 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -2125,8 +2125,8 @@ class region_Op traits = []> : let hasVerifier = 1; } -def fir_DoLoopOp : region_Op<"do_loop", - [DeclareOpInterfaceMethods]> { let summary = "generalized loop operation"; let description = [{ @@ -2156,9 +2156,11 @@ def fir_DoLoopOp : region_Op<"do_loop", Index:$lowerBound, 
Index:$upperBound, Index:$step, + Variadic:$reduceOperands, Variadic:$initArgs, OptionalAttr:$unordered, - OptionalAttr:$finalValue + OptionalAttr:$finalValue, + OptionalAttr:$reduceAttrs ); let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$region); @@ -2169,6 +2171,8 @@ def fir_DoLoopOp : region_Op<"do_loop", "mlir::Value":$step, CArg<"bool", "false">:$unordered, CArg<"bool", "false">:$finalCountValue, CArg<"mlir::ValueRange", "std::nullopt">:$iterArgs, + CArg<"mlir::ValueRange", "std::nullopt">:$reduceOperands, + CArg<"llvm::ArrayRef", "{}">:$reduceAttrs, CArg<"llvm::ArrayRef", "{}">:$attributes)> ]; @@ -2181,11 +2185,12 @@ def fir_DoLoopOp : region_Op<"do_loop", return getBody()->getArguments().drop_front(); } mlir::Operation::operand_range getIterOperands() { - return getOperands().drop_front(getNumControlOperands()); + return getOperands() + .drop_front(getNumControlOperands() + getNumReduceOperands()); } llvm::MutableArrayRef getInitsMutable() { - return - getOperation()->getOpOperands().drop_front(getNumControlOperands()); + return getOperation()->getOpOperands() + .drop_front(getNumControlOperands() + getNumReduceOperands()); } void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); } @@ -2200,11 +2205,25 @@ def fir_DoLoopOp : region_Op<"do_loop", unsigned getNumControlOperands() { return 3; } /// Does the operation hold operands for loop-carried values bool hasIterOperands() { - return (*this)->getNumOperands() > getNumControlOperands(); + return getNumIterOperands() > 0; + } + /// Does the operation hold operands for reduction variables + bool hasReduceOperands() { + return getNumReduceOperands() > 0; + } + /// Get the number of operands in operand segment `idx` + unsigned getNumOperands(unsigned idx) { + auto segments = (*this)->getAttrOfType( + getOperandSegmentSizeAttr()); + return static_cast(segments[idx]); + } + /// Get the number of reduction operands (operand segment 3) + unsigned getNumReduceOperands() { + return getNumOperands(3); } /// Get Number 
of loop-carried values unsigned getNumIterOperands() { - return (*this)->getNumOperands() - getNumControlOperands(); + return getNumOperands(4); } /// Get the body of the loop diff --git a/flang/lib/Optimizer/Dialect/FIRAttr.cpp b/flang/lib/Optimizer/Dialect/FIRAttr.cpp index 2faba63dfba07..a0202a0159228 100644 --- a/flang/lib/Optimizer/Dialect/FIRAttr.cpp +++ b/flang/lib/Optimizer/Dialect/FIRAttr.cpp @@ -297,6 +297,6 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr, void FIROpsDialect::registerAttributes() { addAttributes(); + LowerBoundAttr, PointIntervalAttr, RealAttr, ReduceAttr, + SubclassAttr, UpperBoundAttr>(); } diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index b530a9dc1bcc4..75ca738211abe 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -2456,9 +2456,16 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value lb, mlir::Value ub, mlir::Value step, bool unordered, bool finalCountValue, mlir::ValueRange iterArgs, + mlir::ValueRange reduceOperands, + llvm::ArrayRef reduceAttrs, llvm::ArrayRef attributes) { result.addOperands({lb, ub, step}); + result.addOperands(reduceOperands); result.addOperands(iterArgs); + result.addAttribute(getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr( + {1, 1, 1, static_cast(reduceOperands.size()), + static_cast(iterArgs.size())})); if (finalCountValue) { result.addTypes(builder.getIndexType()); result.addAttribute(getFinalValueAttrName(result.name), @@ -2477,6 +2484,9 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder, if (unordered) result.addAttribute(getUnorderedAttrName(result.name), builder.getUnitAttr()); + if (!reduceAttrs.empty()) + result.addAttribute(getReduceAttrsAttrName(result.name), + builder.getArrayAttr(reduceAttrs)); result.addAttributes(attributes); } @@ -2502,24 +2512,51 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser 
&parser, if (mlir::succeeded(parser.parseOptionalKeyword("unordered"))) result.addAttribute("unordered", builder.getUnitAttr()); + // Parse the optional `reduce(...)` clause. + llvm::SmallVector reduceOperands; + llvm::SmallVector reduceArgTypes; + if (succeeded(parser.parseOptionalKeyword("reduce"))) { + // Parse a parenthesized list of `attr -> operand : type` entries. + llvm::SmallVector attributes; + if (failed(parser.parseCommaSeparatedList( + mlir::AsmParser::Delimiter::Paren, [&]() { + if (parser.parseAttribute(attributes.emplace_back()) || + parser.parseArrow() || + parser.parseOperand(reduceOperands.emplace_back()) || + parser.parseColonType(reduceArgTypes.emplace_back())) + return mlir::failure(); + return mlir::success(); + }))) + return mlir::failure(); + // Resolve the reduction operands against their parsed types. + for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes)) + if (parser.resolveOperand(std::get<0>(operand_type), + std::get<1>(operand_type), result.operands)) + return mlir::failure(); + llvm::SmallVector arrayAttr(attributes.begin(), + attributes.end()); + result.addAttribute(getReduceAttrsAttrName(result.name), + builder.getArrayAttr(arrayAttr)); + } + // Parse the optional initial iteration arguments. llvm::SmallVector regionArgs; - llvm::SmallVector operands; + llvm::SmallVector iterOperands; llvm::SmallVector argTypes; bool prependCount = false; regionArgs.push_back(inductionVariable); if (succeeded(parser.parseOptionalKeyword("iter_args"))) { // Parse assignment list and results type list. - if (parser.parseAssignmentList(regionArgs, operands) || + if (parser.parseAssignmentList(regionArgs, iterOperands) || parser.parseArrowTypeList(result.types)) return mlir::failure(); - if (result.types.size() == operands.size() + 1) + if (result.types.size() == iterOperands.size() + 1) prependCount = true; // Resolve input operands. llvm::ArrayRef resTypes = result.types; - for (auto operand_type : - llvm::zip(operands, prependCount ? 
resTypes.drop_front() : resTypes)) + for (auto operand_type : llvm::zip( + iterOperands, prependCount ? resTypes.drop_front() : resTypes)) if (parser.resolveOperand(std::get<0>(operand_type), std::get<1>(operand_type), result.operands)) return mlir::failure(); @@ -2530,6 +2567,12 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser, prependCount = true; } + // Set operandSegmentSizes: {1, 1, 1, #reduce operands, #iter operands}. + result.addAttribute(getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr( + {1, 1, 1, static_cast(reduceOperands.size()), + static_cast(iterOperands.size())})); + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) return mlir::failure(); @@ -2606,6 +2649,10 @@ mlir::LogicalResult fir::DoLoopOp::verify() { i++; } + auto reduceAttrs = getReduceAttrsAttr(); + if (getNumReduceOperands() != (reduceAttrs ? reduceAttrs.size() : 0)) + return emitOpError( + "mismatch in number of reduction variables and reduction attributes"); return mlir::success(); } @@ -2615,6 +2662,17 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) { << getUpperBound() << " step " << getStep(); if (getUnordered()) p << " unordered"; + if (hasReduceOperands()) { + p << " reduce("; + auto attrs = getReduceAttrsAttr(); + auto operands = getReduceOperands(); + llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) { + p << std::get<0>(it) << " -> " << std::get<1>(it) << " : " + << std::get<1>(it).getType(); + }); + p << ')'; + printBlockTerminators = true; + } if (hasIterOperands()) { p << " iter_args("; auto regionArgs = getRegionIterArgs(); @@ -2628,8 +2686,9 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) { p << " -> " << getResultTypes(); printBlockTerminators = true; } - p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), - {"unordered", "finalValue"}); + p.printOptionalAttrDictWithKeyword( + (*this)->getAttrs(), + {"unordered", "finalValue", "reduceAttrs", "operandSegmentSizes"}); p << ' '; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 
printBlockTerminators); diff --git a/flang/test/Fir/loop03.fir b/flang/test/Fir/loop03.fir new file mode 100644 index 0000000000000..b88dcaf8639be --- /dev/null +++ b/flang/test/Fir/loop03.fir @@ -0,0 +1,17 @@ +// Test the reduction semantics of fir.do_loop +// RUN: fir-opt %s | FileCheck %s + +func.func @reduction() { + %bound = arith.constant 10 : index + %step = arith.constant 1 : index + %sum = fir.alloca i32 +// CHECK: %[[VAL_0:.*]] = fir.alloca i32 +// CHECK: fir.do_loop %[[VAL_1:.*]] = %[[VAL_2:.*]] to %[[VAL_3:.*]] step %[[VAL_4:.*]] unordered reduce(#fir.reduce_attr -> %[[VAL_0]] : !fir.ref) { + fir.do_loop %iv = %step to %bound step %step unordered reduce(#fir.reduce_attr -> %sum : !fir.ref) { + %index = fir.convert %iv : (index) -> i32 + %1 = fir.load %sum : !fir.ref + %2 = arith.addi %index, %1 : i32 + fir.store %2 to %sum : !fir.ref + } + return +}