Skip to content

Cherrypicked atomic operation based changes from llvm main #1570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -317,14 +317,12 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,

def YieldOp : OpenMP_Op<"yield",
[NoSideEffect, ReturnLike, Terminator,
ParentOneOf<["WsLoopOp", "ReductionDeclareOp"]>]> {
ParentOneOf<["WsLoopOp", "ReductionDeclareOp", "AtomicUpdateOp"]>]> {
let summary = "loop yield and termination operation";
let description = [{
"omp.yield" yields SSA values from the OpenMP dialect op region and
terminates the region. The semantics of how the values are yielded is
defined by the parent operation.
If "omp.yield" has any operands, the operands must match the parent
operation's results.
}];

let arguments = (ins Variadic<AnyType>:$results);
Expand Down Expand Up @@ -559,11 +557,11 @@ def TaskwaitOp : OpenMP_Op<"taskwait"> {
// value of the clause) here decomposes handling of this construct into a
// two-step process.

def AtomicReadOp : OpenMP_Op<"atomic.read"> {
let arguments = (ins OpenMP_PointerLikeType:$address,
def AtomicReadOp : OpenMP_Op<"atomic.read", [AllTypesMatch<["x", "v"]>]> {
let arguments = (ins OpenMP_PointerLikeType:$x,
OpenMP_PointerLikeType:$v,
DefaultValuedAttr<I64Attr, "0">:$hint,
OptionalAttr<MemoryOrderKind>:$memory_order);
let results = (outs AnyType);
let parser = [{ return parseAtomicReadOp(parser, result); }];
let printer = [{ return printAtomicReadOp(p, *this); }];
let verifier = [{ return verifyAtomicReadOp(*this); }];
Expand Down Expand Up @@ -606,18 +604,25 @@ def AtomicBinOpKindAttr : I64EnumAttr<
let symbolToStringFnName = "AtomicBinOpKindToString";
}

def AtomicUpdateOp : OpenMP_Op<"atomic.update"> {
def AtomicUpdateOp : OpenMP_Op<"atomic.update", [SingleBlockImplicitTerminator<"YieldOp">]> {
let arguments = (ins OpenMP_PointerLikeType:$x,
AnyType:$expr,
UnitAttr:$isXBinopExpr,
AtomicBinOpKindAttr:$binop,
DefaultValuedAttr<I64Attr, "0">:$hint,
OptionalAttr<MemoryOrderKind>:$memory_order);
let regions = (region SizedRegion<1>:$region);
let parser = [{ return parseAtomicUpdateOp(parser, result); }];
let printer = [{ return printAtomicUpdateOp(p, *this); }];
let verifier = [{ return verifyAtomicUpdateOp(*this); }];
}

def AtomicCaptureOp : OpenMP_Op<"atomic.capture", [SingleBlockImplicitTerminator<"TerminatorOp">]>{
let arguments = (ins DefaultValuedAttr<I64Attr, "0">:$hint,
OptionalAttr<MemoryOrderKind>:$memory_order);
let regions = (region SizedRegion<1>:$region);
let parser = [{ return parseAtomicCaptureOp(parser, result); }];
let printer = [{ return printAtomicCaptureOp(p, *this); }];
let verifier = [{ return verifyAtomicCaptureOp(*this); }];
}

//===----------------------------------------------------------------------===//
// 2.19.5.7 declare reduction Directive
//===----------------------------------------------------------------------===//
Expand Down
158 changes: 101 additions & 57 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1227,32 +1227,28 @@ static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) {
/// address ::= operand `:` type
static ParseResult parseAtomicReadOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType address;
OpAsmParser::OperandType x, v;
Type addressType;
SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
SmallVector<int> segments;

if (parser.parseOperand(address) ||
if (parser.parseOperand(v) || parser.parseEqual() || parser.parseOperand(x) ||
parseClauses(parser, result, clauses, segments) ||
parser.parseColonType(addressType) ||
parser.resolveOperand(address, addressType, result.operands))
parser.resolveOperand(x, addressType, result.operands) ||
parser.resolveOperand(v, addressType, result.operands))
return failure();

SmallVector<Type> resultType;
if (parser.parseArrowTypeList(resultType))
return failure();
result.addTypes(resultType);
return success();
}

/// Printer for AtomicReadOp
static void printAtomicReadOp(OpAsmPrinter &p, AtomicReadOp op) {
p << " " << op.address() << " ";
p << " " << op.v() << " = " << op.x() << " ";
if (op.memory_order())
p << "memory_order(" << op.memory_order().getValue() << ") ";
if (op.hintAttr())
printSynchronizationHint(p << " ", op, op.hintAttr());
p << ": " << op.address().getType() << " -> " << op.getType();
p << ": " << op.x().getType();
return;
}

Expand All @@ -1264,6 +1260,9 @@ static LogicalResult verifyAtomicReadOp(AtomicReadOp op) {
return op.emitError(
"memory-order must not be acq_rel or release for atomic reads");
}
if (op.x() == op.v())
return op.emitError(
"read and write must not be to the same location for atomic reads");
return verifySynchronizationHint(op, op.hint());
}

Expand All @@ -1284,7 +1283,7 @@ static ParseResult parseAtomicWriteOp(OpAsmParser &parser,
SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
SmallVector<int> segments;

if (parser.parseOperand(address) || parser.parseComma() ||
if (parser.parseOperand(address) || parser.parseEqual() ||
parser.parseOperand(value) ||
parseClauses(parser, result, clauses, segments) ||
parser.parseColonType(addrType) || parser.parseComma() ||
Expand All @@ -1297,7 +1296,7 @@ static ParseResult parseAtomicWriteOp(OpAsmParser &parser,

/// Printer for AtomicWriteOp
static void printAtomicWriteOp(OpAsmPrinter &p, AtomicWriteOp op) {
p << " " << op.address() << ", " << op.value() << " ";
p << " " << op.address() << " = " << op.value() << " ";
if (op.memory_order())
p << "memory_order(" << op.memory_order() << ") ";
if (op.hintAttr())
Expand Down Expand Up @@ -1328,61 +1327,28 @@ static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
SmallVector<int> segments;
OpAsmParser::OperandType x, y, z;
Type xType, exprType;
StringRef binOp;

// x = y `op` z : xtype, exprtype
if (parser.parseOperand(x) || parser.parseEqual() || parser.parseOperand(y) ||
parser.parseKeyword(&binOp) || parser.parseOperand(z) ||
parseClauses(parser, result, clauses, segments) || parser.parseColon() ||
parser.parseType(xType) || parser.parseComma() ||
parser.parseType(exprType) ||
parser.resolveOperand(x, xType, result.operands)) {
OpAsmParser::OperandType x, expr;
Type xType;

if (parseClauses(parser, result, clauses, segments) ||
parser.parseOperand(x) || parser.parseColon() ||
parser.parseType(xType) ||
parser.resolveOperand(x, xType, result.operands) ||
parser.parseRegion(*result.addRegion())) {
return failure();
}

auto binOpEnum = AtomicBinOpKindToEnum(binOp.upper());
if (!binOpEnum)
return parser.emitError(parser.getNameLoc())
<< "invalid atomic bin op in atomic update\n";
auto attr =
parser.getBuilder().getI64IntegerAttr((int64_t)binOpEnum.getValue());
result.addAttribute("binop", attr);

OpAsmParser::OperandType expr;
if (x.name == y.name && x.number == y.number) {
expr = z;
result.addAttribute("isXBinopExpr", parser.getBuilder().getUnitAttr());
} else if (x.name == z.name && x.number == z.number) {
expr = y;
} else {
return parser.emitError(parser.getNameLoc())
<< "atomic update variable " << x.name
<< " not found in the RHS of the assignment statement in an"
" atomic.update operation";
}
return parser.resolveOperand(expr, exprType, result.operands);
return success();
}

/// Printer for AtomicUpdateOp
static void printAtomicUpdateOp(OpAsmPrinter &p, AtomicUpdateOp op) {
p << " " << op.x() << " = ";
Value y, z;
if (op.isXBinopExpr()) {
y = op.x();
z = op.expr();
} else {
y = op.expr();
z = op.x();
}
p << y << " " << AtomicBinOpKindToString(op.binop()).lower() << " " << z
<< " ";
p << " ";
if (op.memory_order())
p << "memory_order(" << op.memory_order() << ") ";
if (op.hintAttr())
printSynchronizationHint(p, op, op.hintAttr());
p << ": " << op.x().getType() << ", " << op.expr().getType();
p << op.x() << " : " << op.x().getType();
p.printRegion(op.region());
}

/// Verifier for AtomicUpdateOp
Expand All @@ -1393,6 +1359,84 @@ static LogicalResult verifyAtomicUpdateOp(AtomicUpdateOp op) {
return op.emitError(
"memory-order must not be acq_rel or acquire for atomic updates");
}
if (op.region().getNumArguments() != 1)
return op.emitError("the region must accept exactly one argument");

if (op.x().getType().cast<PointerLikeType>().getElementType() !=
op.region().getArgument(0).getType()) {
return op.emitError(
"the type of the operand must be a pointer type whose "
"element type is the same as that of the region argument");
}

YieldOp yieldOp = *op.region().getOps<YieldOp>().begin();
if (yieldOp.results().size() != 1)
return op.emitError("only updated value must be returned");
if (yieldOp.results().front().getType() !=
op.region().getArgument(0).getType())
return op.emitError("input and yielded value must have the same type");
return success();
}

//===----------------------------------------------------------------------===//
// AtomicCaptureOp
//===----------------------------------------------------------------------===//

/// Parser for AtomicCaptureOp
static LogicalResult parseAtomicCaptureOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
SmallVector<int> segments;
if (parseClauses(parser, result, clauses, segments) ||
parser.parseRegion(*result.addRegion()))
return failure();
return success();
}

/// Printer for AtomicCaptureOp
static void printAtomicCaptureOp(OpAsmPrinter &p, AtomicCaptureOp op) {
if (op.memory_order())
p << "memory_order(" << op.memory_order() << ") ";
if (op.hintAttr())
printSynchronizationHint(p, op, op.hintAttr());
p.printRegion(op.region());
}

/// Verifier for AtomicCaptureOp
static LogicalResult verifyAtomicCaptureOp(AtomicCaptureOp op) {
Block::OpListType &ops = op.region().front().getOperations();
if (ops.size() != 3)
return emitError(op.getLoc())
<< "expected three operations in omp.atomic.capture region (one"
" terminator, and two atomic ops";
auto &firstOp = ops.front();
auto &secondOp = *ops.getNextNode(firstOp);
auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp);
auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp);
auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp);
auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp);
auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp);

if (!((firstUpdateStmt && secondReadStmt) ||
(firstReadStmt && secondUpdateStmt) ||
(firstReadStmt && secondWriteStmt)))
return emitError(ops.front().getLoc())
<< "invalid sequence of operations in the capture region";
if (firstUpdateStmt && secondReadStmt &&
firstUpdateStmt.x() != secondReadStmt.x())
return emitError(firstUpdateStmt.getLoc())
<< "updated variable in omp.atomic.update must be captured in "
"second operation";
if (firstReadStmt && secondUpdateStmt &&
firstReadStmt.x() != secondUpdateStmt.x())
return emitError(firstReadStmt.getLoc())
<< "captured variable in omp.atomic.read must be updated in "
"second operation";
if (firstReadStmt && secondWriteStmt &&
firstReadStmt.x() != secondWriteStmt.address())
return emitError(firstReadStmt.getLoc())
<< "captured variable in omp.atomic.read must be updated in "
"second operation";
return success();
}

Expand Down
Loading