diff --git a/mlir/docs/Dialects/OpenACCDialect.md b/mlir/docs/Dialects/OpenACCDialect.md index 2f1bb194a167d..9e85d66ba340f 100755 --- a/mlir/docs/Dialects/OpenACCDialect.md +++ b/mlir/docs/Dialects/OpenACCDialect.md @@ -274,28 +274,96 @@ reference counters are zero, a delete action is performed. ### Types -There are a few acc dialect type categories to describe: -* type of acc data clause operation input `varPtr` - - The type of `varPtr` must be pointer-like. This is done by - attaching the `PointerLikeType` interface to the appropriate MLIR - type. Although memory/storage concept is a lower level abstraction, - it is useful because the OpenACC model distinguishes between host - and device memory explicitly - and the mapping between the two is - done through pointers. Thus, by explicitly requiring it in the - dialect, the appropriate language frontend must create storage or - use type that satisfies the mapping constraint. +Since the `acc dialect` is meant to be used alongside other dialects which +represent the source language, appropriate use of types and type interfaces is +key to ensuring compatibility. This section describes those considerations. + +#### Data Clause Operation Types + +Data clause operations (eg. `acc.copyin`) rely on the following type +considerations: +* type of acc data clause operation input `var` + - The type of `var` must be one with `PointerLikeType` or `MappableType` + interfaces attached. The first, `PointerLikeType`, is useful because + the OpenACC memory model distinguishes between host and device memory + explicitly - and the mapping between the two is done through pointers. Thus, + by explicitly requiring it in the dialect, the appropriate language + frontend must create storage or use type that satisfies the mapping + constraint. The second possibility, `MappableType` was added because + memory/storage concept is a lower level abstraction and not all dialects + choose to use a pointer abstraction especially in the case where semantics + are more complex (such as `fir.box` which represents Fortran descriptors + and is defined in the `fir` dialect used from `flang`). * type of result of acc data clause operations - The type of the acc data clause operation is exactly the same as - `varPtr`. This was done intentionally instead of introducing an - `acc.ref/ptr` type so that IR compatibility and the dialect's + `var`. This was done intentionally instead of introducing specific `acc` + output types so that so that IR compatibility and the dialect's existing strong type checking can be maintained. This is needed since the `acc` dialect must live within another dialect whose type - system is unknown to it. The only constraint is that the appropriate - dialect type must use the `PointerLikeType` interface. + system is unknown to it. +* variable type captured in `varType` + - When `var`'s type is `PointerLikeType`, the actual type of the target + may be lost. More specifically, dialects like `llvm` which use opaque + pointers, do not record the target variable's type. The use of this field + bridges this gap. * type of decomposed clauses - Decomposed clauses, such as `acc.bounds` and `acc.declare_enter` produce types to allow their results to be used only in specific - operations. + operations. These are synthetic types solely used for proper IR + construction. + +#### Pointer-Like Requirement + +The need to have pointer-type requirement in the acc dialect stems from +a few different aspects: +- Existing dialects like `hlfir`, `fir`, `cir`, `llvm` use a pointer +representation for variables. +- Reference counters (for data clauses) are described in terms of +memory. In OpenACC spec 3.3 in section 2.6.7. It says: "A structured reference +counter is incremented when entering each data or compute region that contain an +explicit data clause or implicitly-determined data attributes for that section +of memory". This implies addressability of memory. +- Attach semantics (2.6.8 attachment counter) are specified using +"address" terminology: "The attachment counter for a pointer is set to +one whenever the pointer is attached to new target address, and +incremented whenever an attach action for that pointer is performed for +the same target address. + +#### Type Interfaces + +The `acc` dialect describes two different type interfaces which must be +implemented and attached to the source dialect's types in order to allow use +of data clause operations (eg. `acc.copyin`). They are as follows: +* `PointerLikeType` + - The idea behind this interface is that variables end up being represented + as pointers in many dialects. More specifically, `fir`, `cir`, `llvm` + represent user declared local variables with some dialect specific form of + `alloca` operation which produce pointers. Globals, similarly, are referred by + their address through some form of `address_of` operation. Additionally, an + implementation for OpenACC runtime needs to distinguish between device and + host memory - also typically done by talking about pointers. So this type + interface requirement fits in naturally with OpenACC specification. Data + mapping operation semantics can often be simply described by a pointer and + size of the data it points to. +* `MappableType` + - This interface was introduced because the `PointerLikeType` requirement + cannot represent cases when the source dialect does not use pointers. Also, + some cases, such as Fortran descriptor-backed arrays and Fortran optional + arguments, require decomposition into multiple steps. For example, in the + descriptor case, mapping of descriptor is needed, mapping of the data, and + implicit attach into device descriptor. In order to allow capturing all of + this complexity with a single data clause operation, the `MappableType` + interface was introduced. This is consistent with the dialect's goals + including being "able to regenerate the semantic equivalent of the user + pragmas". + +The intent is that a dialect's type system implements one of these two +interfaces. And to be precise, a type should only implement one or the other +(and not both) - since keeping them separate avoids ambiguity on what actually +needs mapped. When `var` is `PointerLikeType`, the assumption is that the data +pointed-to will be mapped. If the pointer-like type also implemented +`MappableType` interface, it becomes ambiguous whether the data pointed to or +the pointer itself is being mapped. ### Recipes diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h index cda07d6a91364..748cb7f28fc8c 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h @@ -25,6 +25,7 @@ #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.h.inc" #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.h.inc" #include "mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h" +#include "mlir/IR/Value.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" @@ -83,16 +84,31 @@ namespace acc { /// combined and the final mapping value would be 5 (4 | 1). enum OpenACCExecMapping { NONE = 0, VECTOR = 1, WORKER = 2, GANG = 4 }; -/// Used to obtain the `varPtr` from a data clause operation. +/// Used to obtain the `var` from a data clause operation. /// Returns empty value if not a data clause operation or is a data exit -/// operation with no `varPtr`. -mlir::Value getVarPtr(mlir::Operation *accDataClauseOp); - -/// Used to obtain the `accPtr` from a data clause operation. -/// When a data entry operation, it obtains its result `accPtr` value. -/// If a data exit operation, it obtains its operand `accPtr` value. +/// operation with no `var`. +mlir::Value getVar(mlir::Operation *accDataClauseOp); + +/// Used to obtain the `var` from a data clause operation if it implements +/// `PointerLikeType`. +mlir::TypedValue +getVarPtr(mlir::Operation *accDataClauseOp); + +/// Used to obtains the `varType` from a data clause operation which records +/// the type of variable. When `var` is `PointerLikeType`, this returns +/// the type of the pointer target. +mlir::Type getVarType(mlir::Operation *accDataClauseOp); + +/// Used to obtain the `accVar` from a data clause operation. +/// When a data entry operation, it obtains its result `accVar` value. +/// If a data exit operation, it obtains its operand `accVar` value. /// Returns empty value if not a data clause operation. -mlir::Value getAccPtr(mlir::Operation *accDataClauseOp); +mlir::Value getAccVar(mlir::Operation *accDataClauseOp); + +/// Used to obtain the `accVar` from a data clause operation if it implements +/// `PointerLikeType`. +mlir::TypedValue +getAccPtr(mlir::Operation *accDataClauseOp); /// Used to obtain the `varPtrPtr` from a data clause operation. /// Returns empty value if not a data clause operation. @@ -136,6 +152,18 @@ mlir::ValueRange getDataOperands(mlir::Operation *accOp); /// Used to get a mutable range iterating over the data operands. mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp); +/// Used to check whether the provided `type` implements the `PointerLikeType` +/// interface. +inline bool isPointerLikeType(mlir::Type type) { + return mlir::isa(type); +} + +/// Used to check whether the provided `type` implements the `MappableType` +/// interface. +inline bool isMappableType(mlir::Type type) { + return mlir::isa(type); +} + /// Used to obtain the attribute name for declare. static constexpr StringLiteral getDeclareAttrName() { return StringLiteral("acc.declare"); diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 3ac265ac68756..a47f70b168066 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -70,7 +70,14 @@ def IntOrIndex : AnyTypeOf<[AnyInteger, Index]>; // Simple alias to pointer-like interface to reduce verbosity. def OpenACC_PointerLikeType : TypeAlias; + "pointer-like type">; +def OpenACC_MappableType : TypeAlias; + +def OpenACC_AnyPointerOrMappableLike : TypeConstraint, "any pointer or mappable">; +def OpenACC_AnyPointerOrMappableType : Type; // Define the OpenACC data clauses. There are a few cases where a modifier // is used, like create(zero), copyin(readonly), and copyout(zero). Since in @@ -353,7 +360,8 @@ def OpenACC_DataBoundsOp : OpenACC_Op<"bounds", build($_builder, $_state, ::mlir::acc::DataBoundsType::get($_builder.getContext()), /*lowerbound=*/{}, /*upperbound=*/{}, extent, - /*stride=*/{}, /*strideInBytes=*/nullptr, /*startIdx=*/{}); + /*stride=*/{}, /*strideInBytes=*/$_builder.getBoolAttr(false), + /*startIdx=*/{}); }] >, OpBuilder<(ins "::mlir::Value":$lowerbound, @@ -361,7 +369,8 @@ def OpenACC_DataBoundsOp : OpenACC_Op<"bounds", build($_builder, $_state, ::mlir::acc::DataBoundsType::get($_builder.getContext()), lowerbound, upperbound, /*extent=*/{}, - /*stride=*/{}, /*strideInBytes=*/nullptr, /*startIdx=*/{}); + /*stride=*/{}, /*strideInBytes=*/$_builder.getBoolAttr(false), + /*startIdx=*/{}); }] > ]; @@ -396,10 +405,15 @@ class OpenACC_DataEntryOp getVarPtr() { + return mlir::dyn_cast>(getVar()); + } + mlir::TypedValue getAccPtr() { + return mlir::dyn_cast>(getAccVar()); + } }]; let assemblyFormat = [{ - `varPtr` `(` $varPtr `:` custom(type($varPtr), $varType) + custom($var) `:` custom(type($var), $varType) oilist( `varPtrPtr` `(` $varPtrPtr `:` type($varPtrPtr) `)` | `bounds` `(` $bounds `)` | `async` `(` custom($asyncOperands, type($asyncOperands), $asyncOperandsDeviceType) `)` - ) `->` type($accPtr) attr-dict + ) `->` type($accVar) attr-dict }]; let hasVerifier = 1; - let builders = [OpBuilder<(ins "::mlir::Value":$varPtr, "bool":$structured, - "bool":$implicit, - CArg<"::mlir::ValueRange", "{}">:$bounds), - [{ + let builders = [ + OpBuilder<(ins "::mlir::TypedValue<::mlir::acc::PointerLikeType>":$varPtr, + "bool":$structured, "bool":$implicit, + CArg<"::mlir::ValueRange", "{}">:$bounds), + [{ build($_builder, $_state, varPtr.getType(), varPtr, /*varType=*/::mlir::TypeAttr::get( - ::mlir::cast<::mlir::acc::PointerLikeType>( - varPtr.getType()).getElementType()), + varPtr.getType().getElementType()), /*varPtrPtr=*/{}, bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr, /*asyncOnly=*/nullptr, /*dataClause=*/nullptr, /*structured=*/$_builder.getBoolAttr(structured), /*implicit=*/$_builder.getBoolAttr(implicit), /*name=*/nullptr); }]>, - OpBuilder<(ins "::mlir::Value":$varPtr, "bool":$structured, - "bool":$implicit, "const ::llvm::Twine &":$name, - CArg<"::mlir::ValueRange", "{}">:$bounds), - [{ + OpBuilder<(ins "::mlir::TypedValue<::mlir::acc::PointerLikeType>":$varPtr, + "bool":$structured, "bool":$implicit, + "const ::llvm::Twine &":$name, + CArg<"::mlir::ValueRange", "{}">:$bounds), + [{ build($_builder, $_state, varPtr.getType(), varPtr, /*varType=*/::mlir::TypeAttr::get( - ::mlir::cast<::mlir::acc::PointerLikeType>( - varPtr.getType()).getElementType()), + varPtr.getType().getElementType()), + /*varPtrPtr=*/{}, bounds, /*asyncOperands=*/{}, + /*asyncOperandsDeviceType=*/nullptr, + /*asyncOnly=*/nullptr, /*dataClause=*/nullptr, + /*structured=*/$_builder.getBoolAttr(structured), + /*implicit=*/$_builder.getBoolAttr(implicit), + /*name=*/$_builder.getStringAttr(name)); + }]>, + OpBuilder<(ins "::mlir::TypedValue<::mlir::acc::MappableType>":$var, + "bool":$structured, "bool":$implicit, + CArg<"::mlir::ValueRange", "{}">:$bounds), + [{ + build($_builder, $_state, var.getType(), var, + /*varType=*/::mlir::TypeAttr::get(var.getType()), + /*varPtrPtr=*/{}, bounds, /*asyncOperands=*/{}, + /*asyncOperandsDeviceType=*/nullptr, + /*asyncOnly=*/nullptr, /*dataClause=*/nullptr, + /*structured=*/$_builder.getBoolAttr(structured), + /*implicit=*/$_builder.getBoolAttr(implicit), /*name=*/nullptr); + }]>, + OpBuilder<(ins "::mlir::TypedValue<::mlir::acc::MappableType>":$var, + "bool":$structured, "bool":$implicit, + "const ::llvm::Twine &":$name, + CArg<"::mlir::ValueRange", "{}">:$bounds), + [{ + build($_builder, $_state, var.getType(), var, + /*varType=*/::mlir::TypeAttr::get(var.getType()), /*varPtrPtr=*/{}, bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr, /*asyncOnly=*/nullptr, /*dataClause=*/nullptr, @@ -506,10 +552,10 @@ class OpenACC_DataEntryOp { + (ins OpenACC_AnyPointerOrMappableType:$var)> { let summary = "Represents private semantics for acc private clause."; - let results = (outs Arg:$accPtr); + let results = (outs Arg:$accVar); let extraClassDeclaration = extraClassDeclarationBase; } @@ -518,11 +564,11 @@ def OpenACC_PrivateOp : OpenACC_DataEntryOp<"private", //===----------------------------------------------------------------------===// def OpenACC_FirstprivateOp : OpenACC_DataEntryOp<"firstprivate", "mlir::acc::DataClause::acc_firstprivate", "", [], - (ins Arg:$varPtr)> { + (ins Arg:$var)> { let summary = "Represents firstprivate semantic for the acc firstprivate " "clause."; - let results = (outs Arg:$accPtr); + let results = (outs Arg:$accVar); let extraClassDeclaration = extraClassDeclarationBase; } @@ -531,10 +577,10 @@ def OpenACC_FirstprivateOp : OpenACC_DataEntryOp<"firstprivate", //===----------------------------------------------------------------------===// def OpenACC_ReductionOp : OpenACC_DataEntryOp<"reduction", "mlir::acc::DataClause::acc_reduction", "", [], - (ins Arg:$varPtr)> { + (ins Arg:$var)> { let summary = "Represents reduction semantics for acc reduction clause."; - let results = (outs Arg:$accPtr); + let results = (outs Arg:$accVar); let extraClassDeclaration = extraClassDeclarationBase; } @@ -544,9 +590,9 @@ def OpenACC_ReductionOp : OpenACC_DataEntryOp<"reduction", def OpenACC_DevicePtrOp : OpenACC_DataEntryOp<"deviceptr", "mlir::acc::DataClause::acc_deviceptr", "", [MemoryEffects<[MemRead]>], - (ins OpenACC_PointerLikeTypeInterface:$varPtr)> { + (ins OpenACC_AnyPointerOrMappableType:$var)> { let summary = "Specifies that the variable pointer is a device pointer."; - let results = (outs OpenACC_PointerLikeTypeInterface:$accPtr); + let results = (outs OpenACC_AnyPointerOrMappableType:$accVar); let extraClassDeclaration = extraClassDeclarationBase; } @@ -557,9 +603,9 @@ def OpenACC_PresentOp : OpenACC_DataEntryOp<"present", "mlir::acc::DataClause::acc_present", "", [MemoryEffects<[MemRead, MemWrite]>], - (ins OpenACC_PointerLikeTypeInterface:$varPtr)> { + (ins OpenACC_AnyPointerOrMappableType:$var)> { let summary = "Specifies that the variable is already present on device."; - let results = (outs OpenACC_PointerLikeTypeInterface:$accPtr); + let results = (outs OpenACC_AnyPointerOrMappableType:$accVar); let extraClassDeclaration = extraClassDeclarationBase; } @@ -570,11 +616,11 @@ def OpenACC_CopyinOp : OpenACC_DataEntryOp<"copyin", "mlir::acc::DataClause::acc_copyin", "", [MemoryEffects<[MemRead, MemWrite]>], - (ins Arg:$varPtr)> { + (ins Arg:$var)> { let summary = "Represents copyin semantics for acc data clauses like acc " "copyin and acc copy."; - let results = (outs Arg:$accPtr); + let results = (outs Arg:$accVar); let extraClassDeclaration = extraClassDeclarationBase # [{ /// Check if this is a copyin with readonly modifier. @@ -589,11 +635,11 @@ def OpenACC_CreateOp : OpenACC_DataEntryOp<"create", "mlir::acc::DataClause::acc_create", "", [MemoryEffects<[MemRead, MemWrite]>], - (ins OpenACC_PointerLikeTypeInterface:$varPtr)> { + (ins OpenACC_AnyPointerOrMappableType:$var)> { let summary = "Represents create semantics for acc data clauses like acc " "create and acc copyout."; - let results = (outs Arg:$accPtr); + let results = (outs Arg:$accVar); let extraClassDeclaration = extraClassDeclarationBase # [{ /// Check if this is a create with zero modifier. @@ -608,9 +654,9 @@ def OpenACC_NoCreateOp : OpenACC_DataEntryOp<"nocreate", "mlir::acc::DataClause::acc_no_create", "", [MemoryEffects<[MemRead, MemWrite]>], - (ins OpenACC_PointerLikeTypeInterface:$varPtr)> { + (ins OpenACC_AnyPointerOrMappableType:$var)> { let summary = "Represents acc no_create semantics."; - let results = (outs OpenACC_PointerLikeTypeInterface:$accPtr); + let results = (outs OpenACC_AnyPointerOrMappableType:$accVar); let extraClassDeclaration = extraClassDeclarationBase; } @@ -621,11 +667,11 @@ def OpenACC_AttachOp : OpenACC_DataEntryOp<"attach", "mlir::acc::DataClause::acc_attach", "", [MemoryEffects<[MemRead, MemWrite]>], - (ins Arg:$varPtr)> { + (ins Arg:$var)> { let summary = "Represents acc attach semantics which updates a pointer in " "device memory with the corresponding device address of the " "pointee."; - let results = (outs OpenACC_PointerLikeTypeInterface:$accPtr); + let results = (outs OpenACC_AnyPointerOrMappableType:$accVar); let extraClassDeclaration = extraClassDeclarationBase; } @@ -645,9 +691,9 @@ def OpenACC_GetDevicePtrOp : OpenACC_DataEntryOp<"getdeviceptr", that is any of the valid `mlir::acc::DataClause` entries. \ }], [MemoryEffects<[MemRead]>], - (ins OpenACC_PointerLikeTypeInterface:$varPtr)> { + (ins OpenACC_AnyPointerOrMappableType:$var)> { let summary = "Gets device address if variable exists on device."; - let results = (outs OpenACC_PointerLikeTypeInterface:$accPtr); + let results = (outs OpenACC_AnyPointerOrMappableType:$accVar); let hasVerifier = 0; let extraClassDeclaration = extraClassDeclarationBase; } @@ -657,10 +703,10 @@ def OpenACC_GetDevicePtrOp : OpenACC_DataEntryOp<"getdeviceptr", //===----------------------------------------------------------------------===// def OpenACC_UpdateDeviceOp : OpenACC_DataEntryOp<"update_device", "mlir::acc::DataClause::acc_update_device", "", [], - (ins Arg:$varPtr)> { + (ins Arg:$var)> { let summary = "Represents acc update device semantics."; - let results = (outs Arg:$accPtr); + let results = (outs Arg:$accVar); let extraClassDeclaration = extraClassDeclarationBase; } @@ -670,9 +716,9 @@ def OpenACC_UpdateDeviceOp : OpenACC_DataEntryOp<"update_device", def OpenACC_UseDeviceOp : OpenACC_DataEntryOp<"use_device", "mlir::acc::DataClause::acc_use_device", "", [MemoryEffects<[MemRead]>], - (ins OpenACC_PointerLikeTypeInterface:$varPtr)> { + (ins OpenACC_AnyPointerOrMappableType:$var)> { let summary = "Represents acc use_device semantics."; - let results = (outs OpenACC_PointerLikeTypeInterface:$accPtr); + let results = (outs OpenACC_AnyPointerOrMappableType:$accVar); let extraClassDeclaration = extraClassDeclarationBase; } @@ -682,9 +728,9 @@ def OpenACC_UseDeviceOp : OpenACC_DataEntryOp<"use_device", def OpenACC_DeclareDeviceResidentOp : OpenACC_DataEntryOp<"declare_device_resident", "mlir::acc::DataClause::acc_declare_device_resident", "", [MemoryEffects<[MemWrite]>], - (ins Arg:$varPtr)> { + (ins Arg:$var)> { let summary = "Represents acc declare device_resident semantics."; - let results = (outs OpenACC_PointerLikeTypeInterface:$accPtr); + let results = (outs OpenACC_AnyPointerOrMappableType:$accVar); let extraClassDeclaration = extraClassDeclarationBase; } @@ -694,9 +740,9 @@ def OpenACC_DeclareDeviceResidentOp : OpenACC_DataEntryOp<"declare_device_reside def OpenACC_DeclareLinkOp : OpenACC_DataEntryOp<"declare_link", "mlir::acc::DataClause::acc_declare_link", "", [MemoryEffects<[MemWrite]>], - (ins Arg:$varPtr)> { + (ins Arg:$var)> { let summary = "Represents acc declare link semantics."; - let results = (outs OpenACC_PointerLikeTypeInterface:$accPtr); + let results = (outs OpenACC_AnyPointerOrMappableType:$accVar); let extraClassDeclaration = extraClassDeclarationBase; } @@ -705,10 +751,10 @@ def OpenACC_DeclareLinkOp : OpenACC_DataEntryOp<"declare_link", //===----------------------------------------------------------------------===// def OpenACC_CacheOp : OpenACC_DataEntryOp<"cache", "mlir::acc::DataClause::acc_cache", "", [NoMemoryEffect], - (ins OpenACC_PointerLikeTypeInterface:$varPtr)> { + (ins OpenACC_AnyPointerOrMappableType:$var)> { let summary = "Represents the cache directive that is associated with a " "loop."; - let results = (outs OpenACC_PointerLikeTypeInterface:$accPtr); + let results = (outs OpenACC_AnyPointerOrMappableType:$accVar); let extraClassDeclaration = extraClassDeclarationBase # [{ /// Check if this is a cache with readonly modifier. @@ -738,7 +784,7 @@ class OpenACC_DataExitOp:$name)); let description = !strconcat(extraDescription, [{ - - `accPtr`: The acc address of variable. This is the link from the data-entry + - `accVar`: The acc variable. This is the link from the data-entry operation used. - `bounds`: Used when copying just slice of array or array's bounds are not encoded in type. They are in rank order where rank 0 is inner-most dimension. @@ -808,57 +854,67 @@ class OpenACC_DataExitOpWithVarPtr [MemoryEffects<[MemRead, MemWrite]>], (ins Arg:$accPtr, + "Accelerator mapped variable", [MemRead]>:$accVar, Arg:$varPtr, + "Host variable", [MemWrite]>:$var, TypeAttr:$varType)> { let assemblyFormat = [{ - `accPtr` `(` $accPtr `:` type($accPtr) `)` + custom($accVar, type($accVar)) (`bounds` `(` $bounds^ `)` )? (`async` `(` custom($asyncOperands, type($asyncOperands), $asyncOperandsDeviceType)^ `)`)? - `to` `varPtr` `(` $varPtr `:` custom(type($varPtr), $varType) + `to` custom($var) `:` custom(type($var), $varType) attr-dict }]; - let builders = [OpBuilder<(ins "::mlir::Value":$accPtr, - "::mlir::Value":$varPtr, "bool":$structured, - "bool":$implicit, - CArg<"::mlir::ValueRange", "{}">:$bounds), - [{ + let builders = [ + OpBuilder<(ins "::mlir::TypedValue<::mlir::acc::PointerLikeType>":$accPtr, + "::mlir::TypedValue<::mlir::acc::PointerLikeType>":$varPtr, + "bool":$structured, "bool":$implicit, + CArg<"::mlir::ValueRange", "{}">:$bounds), + [{ build($_builder, $_state, accPtr, varPtr, /*varType=*/::mlir::TypeAttr::get( - ::mlir::cast<::mlir::acc::PointerLikeType>( - varPtr.getType()).getElementType()), + varPtr.getType().getElementType()), bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr, /*asyncOnly=*/nullptr, /*dataClause=*/nullptr, /*structured=*/$_builder.getBoolAttr(structured), /*implicit=*/$_builder.getBoolAttr(implicit), /*name=*/nullptr); }]>, - OpBuilder<(ins "::mlir::Value":$accPtr, - "::mlir::Value":$varPtr, "bool":$structured, - "bool":$implicit, "const ::llvm::Twine &":$name, - CArg<"::mlir::ValueRange", "{}">:$bounds), - [{ + OpBuilder<(ins "::mlir::TypedValue<::mlir::acc::PointerLikeType>":$accPtr, + "::mlir::TypedValue<::mlir::acc::PointerLikeType>":$varPtr, + "bool":$structured, "bool":$implicit, + "const ::llvm::Twine &":$name, + CArg<"::mlir::ValueRange", "{}">:$bounds), + [{ build($_builder, $_state, accPtr, varPtr, /*varType=*/::mlir::TypeAttr::get( - ::mlir::cast<::mlir::acc::PointerLikeType>( - varPtr.getType()).getElementType()), + varPtr.getType().getElementType()), bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr, /*asyncOnly=*/nullptr, /*dataClause=*/nullptr, /*structured=*/$_builder.getBoolAttr(structured), /*implicit=*/$_builder.getBoolAttr(implicit), /*name=*/$_builder.getStringAttr(name)); }]>]; + + code extraClassDeclarationDataExit = [{ + mlir::TypedValue getVarPtr() { + return mlir::dyn_cast>(getVar()); + } + mlir::TypedValue getAccPtr() { + return mlir::dyn_cast>(getAccVar()); + } + }]; } class OpenACC_DataExitOpNoVarPtr : OpenACC_DataExitOp, MemWrite]>], - (ins Arg:$accPtr)> { + (ins Arg:$accVar)> { let assemblyFormat = [{ - `accPtr` `(` $accPtr `:` type($accPtr) `)` + custom($accVar, type($accVar)) (`bounds` `(` $bounds^ `)` )? (`async` `(` custom($asyncOperands, type($asyncOperands), $asyncOperandsDeviceType)^ `)`)? @@ -866,31 +922,35 @@ class OpenACC_DataExitOpNoVarPtr : }]; let builders = [ - OpBuilder<(ins "::mlir::Value":$accPtr, - "bool":$structured, - "bool":$implicit, - CArg<"::mlir::ValueRange", "{}">:$bounds), [{ + OpBuilder<(ins "::mlir::TypedValue<::mlir::acc::PointerLikeType>":$accPtr, + "bool":$structured, "bool":$implicit, + CArg<"::mlir::ValueRange", "{}">:$bounds), + [{ build($_builder, $_state, accPtr, bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr, /*asyncOnly=*/nullptr, /*dataClause=*/nullptr, /*structured=*/$_builder.getBoolAttr(structured), /*implicit=*/$_builder.getBoolAttr(implicit), /*name=*/nullptr); - }] - >, - OpBuilder<(ins "::mlir::Value":$accPtr, - "bool":$structured, - "bool":$implicit, - "const ::llvm::Twine &":$name, - CArg<"::mlir::ValueRange", "{}">:$bounds), [{ + }]>, + OpBuilder<(ins "::mlir::TypedValue<::mlir::acc::PointerLikeType>":$accPtr, + "bool":$structured, "bool":$implicit, + "const ::llvm::Twine &":$name, + CArg<"::mlir::ValueRange", "{}">:$bounds), + [{ build($_builder, $_state, accPtr, bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr, /*asyncOnly=*/nullptr, /*dataClause=*/nullptr, /*structured=*/$_builder.getBoolAttr(structured), /*implicit=*/$_builder.getBoolAttr(implicit), /*name=*/$_builder.getStringAttr(name)); - }] - > + }]> ]; + + code extraClassDeclarationDataExit = [{ + mlir::TypedValue getAccPtr() { + return mlir::dyn_cast>(getAccVar()); + } + }]; } //===----------------------------------------------------------------------===// @@ -900,7 +960,7 @@ def OpenACC_CopyoutOp : OpenACC_DataExitOpWithVarPtr<"copyout", "mlir::acc::DataClause::acc_copyout"> { let summary = "Represents acc copyout semantics - reverse of copyin."; - let extraClassDeclaration = extraClassDeclarationBase # [{ + let extraClassDeclaration = extraClassDeclarationBase # extraClassDeclarationDataExit # [{ /// Check if this is a copyout with zero modifier. bool isCopyoutZero(); }]; @@ -912,7 +972,7 @@ def OpenACC_CopyoutOp : OpenACC_DataExitOpWithVarPtr<"copyout", def OpenACC_DeleteOp : OpenACC_DataExitOpNoVarPtr<"delete", "mlir::acc::DataClause::acc_delete"> { let summary = "Represents acc delete semantics - reverse of create."; - let extraClassDeclaration = extraClassDeclarationBase; + let extraClassDeclaration = extraClassDeclarationBase # extraClassDeclarationDataExit; } //===----------------------------------------------------------------------===// @@ -921,7 +981,7 @@ def OpenACC_DeleteOp : OpenACC_DataExitOpNoVarPtr<"delete", def OpenACC_DetachOp : OpenACC_DataExitOpNoVarPtr<"detach", "mlir::acc::DataClause::acc_detach"> { let summary = "Represents acc detach semantics - reverse of attach."; - let extraClassDeclaration = extraClassDeclarationBase; + let extraClassDeclaration = extraClassDeclarationBase # extraClassDeclarationDataExit; } //===----------------------------------------------------------------------===// @@ -930,7 +990,7 @@ def OpenACC_DetachOp : OpenACC_DataExitOpNoVarPtr<"detach", def OpenACC_UpdateHostOp : OpenACC_DataExitOpWithVarPtr<"update_host", "mlir::acc::DataClause::acc_update_host"> { let summary = "Represents acc update host semantics."; - let extraClassDeclaration = extraClassDeclarationBase # [{ + let extraClassDeclaration = extraClassDeclarationBase # extraClassDeclarationDataExit # [{ /// Check if this is an acc update self. bool isSelf() { return getDataClause() == acc::DataClause::acc_update_self; @@ -1193,11 +1253,11 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel", UnitAttr:$selfAttr, Variadic:$reductionOperands, OptionalAttr:$reductionRecipes, - Variadic:$privateOperands, + Variadic:$privateOperands, OptionalAttr:$privatizations, - Variadic:$firstprivateOperands, + Variadic:$firstprivateOperands, OptionalAttr:$firstprivatizations, - Variadic:$dataClauseOperands, + Variadic:$dataClauseOperands, OptionalAttr:$defaultAttr, UnitAttr:$combined); @@ -1353,11 +1413,11 @@ def OpenACC_SerialOp : OpenACC_Op<"serial", UnitAttr:$selfAttr, Variadic:$reductionOperands, OptionalAttr:$reductionRecipes, - Variadic:$privateOperands, + Variadic:$privateOperands, OptionalAttr:$privatizations, - Variadic:$firstprivateOperands, + Variadic:$firstprivateOperands, OptionalAttr:$firstprivatizations, - Variadic:$dataClauseOperands, + Variadic:$dataClauseOperands, OptionalAttr:$defaultAttr, UnitAttr:$combined); @@ -1482,7 +1542,7 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", Optional:$ifCond, Optional:$selfCond, UnitAttr:$selfAttr, - Variadic:$dataClauseOperands, + Variadic:$dataClauseOperands, OptionalAttr:$defaultAttr, UnitAttr:$combined); @@ -1614,7 +1674,7 @@ def OpenACC_DataOp : OpenACC_Op<"data", OptionalAttr:$waitOperandsDeviceType, OptionalAttr:$hasWaitDevnum, OptionalAttr:$waitOnly, - Variadic:$dataClauseOperands, + Variadic:$dataClauseOperands, OptionalAttr:$defaultAttr); let regions = (region AnyRegion:$region); @@ -1709,7 +1769,7 @@ def OpenACC_EnterDataOp : OpenACC_Op<"enter_data", Optional:$waitDevnum, Variadic:$waitOperands, UnitAttr:$wait, - Variadic:$dataClauseOperands); + Variadic:$dataClauseOperands); let extraClassDeclaration = [{ /// The number of data operands. @@ -1760,7 +1820,7 @@ def OpenACC_ExitDataOp : OpenACC_Op<"exit_data", Optional:$waitDevnum, Variadic:$waitOperands, UnitAttr:$wait, - Variadic:$dataClauseOperands, + Variadic:$dataClauseOperands, UnitAttr:$finalize); let extraClassDeclaration = [{ @@ -1810,7 +1870,7 @@ def OpenACC_HostDataOp : OpenACC_Op<"host_data", }]; let arguments = (ins Optional:$ifCond, - Variadic:$dataClauseOperands, + Variadic:$dataClauseOperands, UnitAttr:$ifPresent); let regions = (region AnyRegion:$region); @@ -1887,8 +1947,8 @@ def OpenACC_LoopOp : OpenACC_Op<"loop", Variadic:$tileOperands, OptionalAttr:$tileOperandsSegments, OptionalAttr:$tileOperandsDeviceType, - Variadic:$cacheOperands, - Variadic:$privateOperands, + Variadic:$cacheOperands, + Variadic:$privateOperands, OptionalAttr:$privatizations, Variadic:$reductionOperands, OptionalAttr:$reductionRecipes, @@ -2201,7 +2261,7 @@ def OpenACC_DeclareEnterOp : OpenACC_Op<"declare_enter", ``` }]; - let arguments = (ins Variadic:$dataClauseOperands); + let arguments = (ins Variadic:$dataClauseOperands); let results = (outs OpenACC_DeclareTokenType:$token); let assemblyFormat = [{ @@ -2236,7 +2296,7 @@ def OpenACC_DeclareExitOp : OpenACC_Op<"declare_exit", let arguments = (ins Optional:$token, - Variadic:$dataClauseOperands); + Variadic:$dataClauseOperands); let assemblyFormat = [{ oilist( @@ -2338,7 +2398,7 @@ def OpenACC_DeclareOp : OpenACC_Op<"declare", }]; let arguments = (ins - Variadic:$dataClauseOperands); + Variadic:$dataClauseOperands); let regions = (region AnyRegion:$region); @@ -2589,7 +2649,7 @@ def OpenACC_UpdateOp : OpenACC_Op<"update", OptionalAttr:$waitOperandsDeviceType, OptionalAttr:$hasWaitDevnum, OptionalAttr:$waitOnly, - Variadic:$dataClauseOperands, + Variadic:$dataClauseOperands, UnitAttr:$ifPresent); let extraClassDeclaration = [{ diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td index 0a3edd5637704..bec46be89f058 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td @@ -31,4 +31,97 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> { ]; } +def OpenACC_MappableTypeInterface : TypeInterface<"MappableType"> { + let cppNamespace = "::mlir::acc"; + + let description = [{ + An interface to capture type-based semantics for mapping in a manner that + makes it convertible to size-based semantics. + }]; + + let methods = [ + InterfaceMethod< + /*description=*/[{ + Returns the pointer to the `var` if recoverable (such as in cases + where the current operation is a load from a memory slot). + }], + /*retTy=*/"::mlir::TypedValue<::mlir::acc::PointerLikeType>", + /*methodName=*/"getVarPtr", + /*args=*/(ins "::mlir::Value":$var), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + if (auto ptr = mlir::dyn_cast>( + var)) + return ptr; + return {}; + }] + >, + InterfaceMethod< + /*description=*/[{ + Returns the size in bytes when computable. If this is an array-like + type, avoiding passing `accBounds` ensures a computation of the size + of whole type. + }], + /*retTy=*/"::std::optional<::llvm::TypeSize>", + /*methodName=*/"getSizeInBytes", + /*args=*/(ins "::mlir::Value":$var, + "::mlir::ValueRange":$accBounds, + "const ::mlir::DataLayout &":$dataLayout), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + // Bounds operations are typically created for array types. In the + // generic implementation, it is not straightforward to distinguish + // between array types and ensure the size and offset take into account + // just the slice requested. Thus return not-computable for now. + if (!accBounds.empty()) + return {}; + return {dataLayout.getTypeSize($_type)}; + }] + >, + InterfaceMethod< + /*description=*/[{ + Returns the offset in bytes when computable. + }], + /*retTy=*/"::std::optional<::int64_t>", + /*methodName=*/"getOffsetInBytes", + /*args=*/(ins "::mlir::Value":$var, + "::mlir::ValueRange":$accBounds, + "const ::mlir::DataLayout &":$dataLayout), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + // Bounds operations are typically created for array types. In the + // generic implementation, it is not straightforward to distinguish + // between array types and ensure the size and offset take into account + // just the slice requested. Thus return not-computable for now. + if (!accBounds.empty()) + return {}; + + // If the type size is computable, it means it is trivial. Assume + // offset of 0. + if (::mlir::cast<::mlir::acc::MappableType>($_type).getSizeInBytes( + var, accBounds, dataLayout).has_value()) { + return {0}; + } + + return {}; + }] + >, + InterfaceMethod< + /*description=*/[{ + Returns explicit `acc.bounds` operations that envelop the whole + data structure. These operations are inserted using the provided builder + at the location set before calling this API. + }], + /*retTy=*/"::llvm::SmallVector<::mlir::Value>", + /*methodName=*/"generateAccBounds", + /*args=*/(ins "::mlir::Value":$var, + "::mlir::OpBuilder &":$builder), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return {}; + }] + >, + ]; +} + #endif // OPENACC_TYPE_INTERFACES diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index d490376db0e7f..735dd4fd7b8bb 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/TypeSwitch.h" @@ -192,6 +193,97 @@ static LogicalResult checkWaitAndAsyncConflict(Op op) { return success(); } +template +static LogicalResult checkVarAndVarType(Op op) { + if (!op.getVar()) + return op.emitError("must have var operand"); + + if (mlir::isa(op.getVar().getType()) && + mlir::isa(op.getVar().getType())) { + // TODO: If a type implements both interfaces (mappable and pointer-like), + // it is unclear which semantics to apply without additional info which + // would need captured in the data operation. For now restrict this case + // unless a compelling reason to support disambiguating between the two. + return op.emitError("var must be mappable or pointer-like (not both)"); + } + + if (!mlir::isa(op.getVar().getType()) && + !mlir::isa(op.getVar().getType())) + return op.emitError("var must be mappable or pointer-like"); + + if (mlir::isa(op.getVar().getType()) && + op.getVarType() != op.getVar().getType()) + return op.emitError("varType must match when var is mappable"); + + return success(); +} + +template +static LogicalResult checkVarAndAccVar(Op op) { + if (op.getVar().getType() != op.getAccVar().getType()) + return op.emitError("input and output types must match"); + + return success(); +} + +static ParseResult parseVar(mlir::OpAsmParser &parser, + OpAsmParser::UnresolvedOperand &var) { + // Either `var` or `varPtr` keyword is required. + if (failed(parser.parseOptionalKeyword("varPtr"))) { + if (failed(parser.parseKeyword("var"))) + return failure(); + } + if (failed(parser.parseLParen())) + return failure(); + if (failed(parser.parseOperand(var))) + return failure(); + + return success(); +} + +static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, + mlir::Value var) { + if (mlir::isa(var.getType())) + p << "varPtr("; + else + p << "var("; + p.printOperand(var); +} + +static ParseResult parseAccVar(mlir::OpAsmParser &parser, + OpAsmParser::UnresolvedOperand &var, + mlir::Type &accVarType) { + // Either `accVar` or `accPtr` keyword is required. + if (failed(parser.parseOptionalKeyword("accPtr"))) { + if (failed(parser.parseKeyword("accVar"))) + return failure(); + } + if (failed(parser.parseLParen())) + return failure(); + if (failed(parser.parseOperand(var))) + return failure(); + if (failed(parser.parseColon())) + return failure(); + if (failed(parser.parseType(accVarType))) + return failure(); + if (failed(parser.parseRParen())) + return failure(); + + return success(); +} + +static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, + mlir::Value accVar, mlir::Type accVarType) { + if (mlir::isa(accVar.getType())) + p << "accPtr("; + else + p << "accVar("; + p.printOperand(accVar); + p << " : "; + p.printType(accVarType); + p << ")"; +} + static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr) { @@ -211,8 +303,11 @@ static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, return failure(); } else { // Set `varType` from the element type of the type of `varPtr`. - varTypeAttr = mlir::TypeAttr::get( - mlir::cast(varPtrType).getElementType()); + if (mlir::isa(varPtrType)) + varTypeAttr = mlir::TypeAttr::get( + mlir::cast(varPtrType).getElementType()); + else + varTypeAttr = mlir::TypeAttr::get(varPtrType); } return success(); @@ -226,8 +321,11 @@ static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, // Print the `varType` only if it differs from the element type of // `varPtr`'s type. mlir::Type varType = varTypeAttr.getValue(); - if (mlir::cast(varPtrType).getElementType() != - varType) { + mlir::Type typeToCheckAgainst = + mlir::isa(varPtrType) + ? mlir::cast(varPtrType).getElementType() + : varPtrType; + if (typeToCheckAgainst != varType) { p << " varType("; p.printType(varType); p << ")"; @@ -252,6 +350,8 @@ LogicalResult acc::PrivateOp::verify() { if (getDataClause() != acc::DataClause::acc_private) return emitError( "data clause associated with private operation must match its intent"); + if (failed(checkVarAndVarType(*this))) + return failure(); return success(); } @@ -262,6 +362,8 @@ LogicalResult acc::FirstprivateOp::verify() { if (getDataClause() != acc::DataClause::acc_firstprivate) return emitError("data clause associated with firstprivate operation must " "match its intent"); + if (failed(checkVarAndVarType(*this))) + return failure(); return success(); } @@ -272,6 +374,8 @@ LogicalResult acc::ReductionOp::verify() { if (getDataClause() != acc::DataClause::acc_reduction) return emitError("data clause associated with reduction operation must " "match its intent"); + if (failed(checkVarAndVarType(*this))) + return failure(); return success(); } @@ -282,6 +386,10 @@ LogicalResult acc::DevicePtrOp::verify() { if (getDataClause() != acc::DataClause::acc_deviceptr) return emitError("data clause associated with deviceptr operation must " "match its intent"); + if (failed(checkVarAndVarType(*this))) + return failure(); + if (failed(checkVarAndAccVar(*this))) + return failure(); return success(); } @@ -292,6 +400,10 @@ LogicalResult acc::PresentOp::verify() { if (getDataClause() != acc::DataClause::acc_present) return emitError( "data clause associated with present operation must match its intent"); + if (failed(checkVarAndVarType(*this))) + return failure(); + if (failed(checkVarAndAccVar(*this))) + return failure(); return success(); } @@ -307,6 +419,10 @@ LogicalResult acc::CopyinOp::verify() { return emitError( "data clause associated with copyin operation must match its intent" " or specify original clause this operation was decomposed from"); + if (failed(checkVarAndVarType(*this))) + return failure(); + if (failed(checkVarAndAccVar(*this))) + return failure(); return success(); } @@ -326,6 +442,10 @@ LogicalResult acc::CreateOp::verify() { return emitError( "data clause associated with create operation must match its intent" " or specify original clause this operation was decomposed from"); + if (failed(checkVarAndVarType(*this))) + return failure(); + if (failed(checkVarAndAccVar(*this))) + return failure(); return success(); } @@ -342,6 +462,10 @@ LogicalResult acc::NoCreateOp::verify() { if (getDataClause() != acc::DataClause::acc_no_create) return emitError("data clause associated with no_create operation must " "match its intent"); + if (failed(checkVarAndVarType(*this))) + return failure(); + if (failed(checkVarAndAccVar(*this))) + return failure(); return success(); } @@ -352,6 +476,10 @@ LogicalResult acc::AttachOp::verify() { if (getDataClause() != acc::DataClause::acc_attach) return emitError( "data clause associated with attach operation must match its intent"); + if (failed(checkVarAndVarType(*this))) + return failure(); + if (failed(checkVarAndAccVar(*this))) + return failure(); return success(); } @@ -363,6 +491,10 @@ LogicalResult acc::DeclareDeviceResidentOp::verify() { if (getDataClause() != acc::DataClause::acc_declare_device_resident) return emitError("data clause associated with device_resident operation " "must match its intent"); + if (failed(checkVarAndVarType(*this))) + return failure(); + if (failed(checkVarAndAccVar(*this))) + return failure(); return success(); } @@ -374,6 +506,10 @@ LogicalResult acc::DeclareLinkOp::verify() { if (getDataClause() != acc::DataClause::acc_declare_link) return emitError( "data clause associated with link operation must match its intent"); + if (failed(checkVarAndVarType(*this))) + return failure(); + if (failed(checkVarAndAccVar(*this))) + return failure(); return success(); } @@ -389,8 +525,12 @@ LogicalResult acc::CopyoutOp::verify() { return emitError( "data clause associated with copyout operation must match its intent" " or specify original clause this operation was decomposed from"); - if (!getVarPtr() || !getAccPtr()) + if (!getVar() || !getAccVar()) return emitError("must have both host and device pointers"); + if (failed(checkVarAndVarType(*this))) + return failure(); + if (failed(checkVarAndAccVar(*this))) + return failure(); return success(); } @@ -414,7 +554,7 @@ LogicalResult acc::DeleteOp::verify() { return emitError( "data clause associated with delete operation must match its intent" " or specify original clause this operation was decomposed from"); - if (!getAccPtr()) + if (!getAccVar()) return emitError("must have device pointer"); return success(); } @@ -429,7 +569,7 @@ LogicalResult acc::DetachOp::verify() { return emitError( "data clause associated with detach operation must match its intent" " or specify original clause this operation was decomposed from"); - if (!getAccPtr()) + if (!getAccVar()) return emitError("must have device pointer"); return success(); } @@ -444,8 +584,12 @@ LogicalResult acc::UpdateHostOp::verify() { return emitError( "data clause associated with host operation must match its intent" " or specify original clause this operation was decomposed from"); - if (!getVarPtr() || !getAccPtr()) + if (!getVar() || !getAccVar()) return emitError("must have both host and device pointers"); + if (failed(checkVarAndVarType(*this))) + return failure(); + if (failed(checkVarAndAccVar(*this))) + return failure(); return success(); } @@ -458,6 +602,10 @@ LogicalResult acc::UpdateDeviceOp::verify() { return emitError( "data clause associated with device operation must match its intent" " or specify original clause this operation was decomposed from"); + if (failed(checkVarAndVarType(*this))) + return failure(); + if (failed(checkVarAndAccVar(*this))) + return failure(); return success(); } @@ -470,6 +618,10 @@ LogicalResult acc::UseDeviceOp::verify() { return emitError( "data clause associated with use_device operation must match its intent" " or specify original clause this operation was decomposed from"); + if (failed(checkVarAndVarType(*this))) + return failure(); + if (failed(checkVarAndAccVar(*this))) + return failure(); return success(); } @@ -483,6 +635,10 @@ LogicalResult acc::CacheOp::verify() { return emitError( "data clause associated with cache operation must match its intent" " or specify original clause this operation was decomposed from"); + if (failed(checkVarAndVarType(*this))) + return failure(); + if (failed(checkVarAndAccVar(*this))) + return failure(); return success(); } @@ -502,7 +658,7 @@ static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, } static bool isComputeOperation(Operation *op) { - return isa(op); + return isa(op); } namespace { @@ -2917,20 +3073,56 @@ LogicalResult acc::WaitOp::verify() { // acc dialect utilities //===----------------------------------------------------------------------===// -mlir::Value mlir::acc::getVarPtr(mlir::Operation *accDataClauseOp) { - auto varPtr{llvm::TypeSwitch(accDataClauseOp) +mlir::TypedValue +mlir::acc::getVarPtr(mlir::Operation *accDataClauseOp) { + auto varPtr{llvm::TypeSwitch>( + accDataClauseOp) .Case( [&](auto entry) { return entry.getVarPtr(); }) .Case( [&](auto exit) { return exit.getVarPtr(); }) - .Default([&](mlir::Operation *) { return mlir::Value(); })}; + .Default([&](mlir::Operation *) { + return mlir::TypedValue(); + })}; return varPtr; } -mlir::Value mlir::acc::getAccPtr(mlir::Operation *accDataClauseOp) { - auto accPtr{llvm::TypeSwitch(accDataClauseOp) +mlir::Value mlir::acc::getVar(mlir::Operation *accDataClauseOp) { + auto varPtr{ + llvm::TypeSwitch(accDataClauseOp) + .Case([&](auto entry) { return entry.getVar(); }) + .Default([&](mlir::Operation *) { return mlir::Value(); })}; + return varPtr; +} + +mlir::Type mlir::acc::getVarType(mlir::Operation *accDataClauseOp) { + auto varType{llvm::TypeSwitch(accDataClauseOp) + .Case( + [&](auto entry) { return entry.getVarType(); }) + .Case( + [&](auto exit) { return exit.getVarType(); }) + .Default([&](mlir::Operation *) { return mlir::Type(); })}; + return varType; +} + +mlir::TypedValue +mlir::acc::getAccPtr(mlir::Operation *accDataClauseOp) { + auto accPtr{llvm::TypeSwitch>( + accDataClauseOp) .Case( [&](auto dataClause) { return dataClause.getAccPtr(); }) + .Default([&](mlir::Operation *) { + return mlir::TypedValue(); + })}; + return accPtr; +} + +mlir::Value mlir::acc::getAccVar(mlir::Operation *accDataClauseOp) { + auto accPtr{llvm::TypeSwitch(accDataClauseOp) + .Case( + [&](auto dataClause) { return dataClause.getAccVar(); }) .Default([&](mlir::Operation *) { return mlir::Value(); })}; return accPtr; } diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp index fbdada9309d32..cfb8aa767b6f8 100644 --- a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp +++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp @@ -9,9 +9,11 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/Value.h" #include "gtest/gtest.h" using namespace mlir; @@ -446,10 +448,12 @@ void testShortDataEntryOpBuilders(OpBuilder &b, MLIRContext &context, OwningOpRef varPtrOp = b.create(loc, memrefTy); - OwningOpRef op = b.create(loc, varPtrOp->getResult(), + TypedValue varPtr = + cast>(varPtrOp->getResult()); + OwningOpRef op = b.create(loc, varPtr, /*structured=*/true, /*implicit=*/true); - EXPECT_EQ(op->getVarPtr(), varPtrOp->getResult()); + EXPECT_EQ(op->getVarPtr(), varPtr); EXPECT_EQ(op->getType(), memrefTy); EXPECT_EQ(op->getDataClause(), dataClause); EXPECT_TRUE(op->getImplicit()); @@ -457,7 +461,7 @@ void testShortDataEntryOpBuilders(OpBuilder &b, MLIRContext &context, EXPECT_TRUE(op->getBounds().empty()); EXPECT_FALSE(op->getVarPtrPtr()); - OwningOpRef op2 = b.create(loc, varPtrOp->getResult(), + OwningOpRef op2 = b.create(loc, varPtr, /*structured=*/false, /*implicit=*/false); EXPECT_FALSE(op2->getImplicit()); EXPECT_FALSE(op2->getStructured()); @@ -467,13 +471,13 @@ void testShortDataEntryOpBuilders(OpBuilder &b, MLIRContext &context, OwningOpRef bounds = b.create(loc, extent->getResult()); OwningOpRef opWithBounds = - b.create(loc, varPtrOp->getResult(), + b.create(loc, varPtr, /*structured=*/true, /*implicit=*/true, bounds->getResult()); EXPECT_FALSE(opWithBounds->getBounds().empty()); EXPECT_EQ(opWithBounds->getBounds().back(), bounds->getResult()); OwningOpRef opWithName = - b.create(loc, varPtrOp->getResult(), + b.create(loc, varPtr, /*structured=*/true, /*implicit=*/true, "varName"); EXPECT_EQ(opWithName->getNameAttr().str(), "varName"); } @@ -516,23 +520,26 @@ void testShortDataExitOpBuilders(OpBuilder &b, MLIRContext &context, auto memrefTy = MemRefType::get({}, b.getI32Type()); OwningOpRef varPtrOp = b.create(loc, memrefTy); + TypedValue varPtr = + cast>(varPtrOp->getResult()); + OwningOpRef accPtrOp = b.create( - loc, varPtrOp->getResult(), /*structured=*/true, /*implicit=*/true); + loc, varPtr, /*structured=*/true, /*implicit=*/true); + TypedValue accPtr = + cast>(accPtrOp->getResult()); - OwningOpRef op = - b.create(loc, accPtrOp->getResult(), varPtrOp->getResult(), - /*structured=*/true, /*implicit=*/true); + OwningOpRef op = b.create(loc, accPtr, varPtr, + /*structured=*/true, /*implicit=*/true); - EXPECT_EQ(op->getVarPtr(), varPtrOp->getResult()); - EXPECT_EQ(op->getAccPtr(), accPtrOp->getResult()); + EXPECT_EQ(op->getVarPtr(), varPtr); + EXPECT_EQ(op->getAccPtr(), accPtr); EXPECT_EQ(op->getDataClause(), dataClause); EXPECT_TRUE(op->getImplicit()); EXPECT_TRUE(op->getStructured()); EXPECT_TRUE(op->getBounds().empty()); - OwningOpRef op2 = - b.create(loc, accPtrOp->getResult(), varPtrOp->getResult(), - /*structured=*/false, /*implicit=*/false); + OwningOpRef op2 = b.create(loc, accPtr, varPtr, + /*structured=*/false, /*implicit=*/false); EXPECT_FALSE(op2->getImplicit()); EXPECT_FALSE(op2->getStructured()); @@ -541,13 +548,13 @@ void testShortDataExitOpBuilders(OpBuilder &b, MLIRContext &context, OwningOpRef bounds = b.create(loc, extent->getResult()); OwningOpRef opWithBounds = - b.create(loc, accPtrOp->getResult(), varPtrOp->getResult(), + b.create(loc, accPtr, varPtr, /*structured=*/true, /*implicit=*/true, bounds->getResult()); EXPECT_FALSE(opWithBounds->getBounds().empty()); EXPECT_EQ(opWithBounds->getBounds().back(), bounds->getResult()); OwningOpRef opWithName = - b.create(loc, accPtrOp->getResult(), varPtrOp->getResult(), + b.create(loc, accPtr, varPtr, /*structured=*/true, /*implicit=*/true, "varName"); EXPECT_EQ(opWithName->getNameAttr().str(), "varName"); } @@ -565,19 +572,24 @@ void testShortDataExitNoVarPtrOpBuilders(OpBuilder &b, MLIRContext &context, auto memrefTy = MemRefType::get({}, b.getI32Type()); OwningOpRef varPtrOp = b.create(loc, memrefTy); + TypedValue varPtr = + cast>(varPtrOp->getResult()); + OwningOpRef accPtrOp = b.create( - loc, varPtrOp->getResult(), /*structured=*/true, /*implicit=*/true); + loc, varPtr, /*structured=*/true, /*implicit=*/true); + TypedValue accPtr = + cast>(accPtrOp->getResult()); - OwningOpRef op = b.create(loc, accPtrOp->getResult(), + OwningOpRef op = b.create(loc, accPtr, /*structured=*/true, /*implicit=*/true); - EXPECT_EQ(op->getAccPtr(), accPtrOp->getResult()); + EXPECT_EQ(op->getAccPtr(), accPtr); EXPECT_EQ(op->getDataClause(), dataClause); EXPECT_TRUE(op->getImplicit()); EXPECT_TRUE(op->getStructured()); EXPECT_TRUE(op->getBounds().empty()); - OwningOpRef op2 = b.create(loc, accPtrOp->getResult(), + OwningOpRef op2 = b.create(loc, accPtr, /*structured=*/false, /*implicit=*/false); EXPECT_FALSE(op2->getImplicit()); EXPECT_FALSE(op2->getStructured()); @@ -587,13 +599,13 @@ void testShortDataExitNoVarPtrOpBuilders(OpBuilder &b, MLIRContext &context, OwningOpRef bounds = b.create(loc, extent->getResult()); OwningOpRef opWithBounds = - b.create(loc, accPtrOp->getResult(), + b.create(loc, accPtr, /*structured=*/true, /*implicit=*/true, bounds->getResult()); EXPECT_FALSE(opWithBounds->getBounds().empty()); EXPECT_EQ(opWithBounds->getBounds().back(), bounds->getResult()); OwningOpRef opWithName = - b.create(loc, accPtrOp->getResult(), + b.create(loc, accPtr, /*structured=*/true, /*implicit=*/true, "varName"); EXPECT_EQ(opWithName->getNameAttr().str(), "varName"); } @@ -604,3 +616,75 @@ TEST_F(OpenACCOpsTest, shortDataExitOpNoVarPtrBuilder) { testShortDataExitNoVarPtrOpBuilders(b, context, loc, DataClause::acc_detach); } + +template +void testShortDataEntryOpBuildersMappableVar(OpBuilder &b, MLIRContext &context, + Location loc, + DataClause dataClause) { + auto int64Ty = b.getI64Type(); + auto memrefTy = MemRefType::get({}, int64Ty); + OwningOpRef varPtrOp = + b.create(loc, memrefTy); + SmallVector indices; + OwningOpRef loadVarOp = + b.create(loc, int64Ty, varPtrOp->getResult(), indices); + + EXPECT_TRUE(isMappableType(loadVarOp->getResult().getType())); + TypedValue var = + cast>(loadVarOp->getResult()); + OwningOpRef op = b.create(loc, var, + /*structured=*/true, /*implicit=*/true); + + EXPECT_EQ(op->getVar(), var); + EXPECT_EQ(op->getVarPtr(), nullptr); + EXPECT_EQ(op->getType(), int64Ty); + EXPECT_EQ(op->getVarType(), int64Ty); + EXPECT_EQ(op->getDataClause(), dataClause); + EXPECT_TRUE(op->getImplicit()); + EXPECT_TRUE(op->getStructured()); + EXPECT_TRUE(op->getBounds().empty()); + EXPECT_FALSE(op->getVarPtrPtr()); +} + +struct IntegerOpenACCMappableModel + : public mlir::acc::MappableType::ExternalModel {}; + +TEST_F(OpenACCOpsTest, mappableTypeBuilderDataEntry) { + // First, set up the test by attaching MappableInterface to IntegerType. + IntegerType i64ty = IntegerType::get(&context, 8); + ASSERT_FALSE(isMappableType(i64ty)); + IntegerType::attachInterface(context); + ASSERT_TRUE(isMappableType(i64ty)); + + testShortDataEntryOpBuildersMappableVar(b, context, loc, + DataClause::acc_private); + testShortDataEntryOpBuildersMappableVar( + b, context, loc, DataClause::acc_firstprivate); + testShortDataEntryOpBuildersMappableVar( + b, context, loc, DataClause::acc_reduction); + testShortDataEntryOpBuildersMappableVar( + b, context, loc, DataClause::acc_deviceptr); + testShortDataEntryOpBuildersMappableVar(b, context, loc, + DataClause::acc_present); + testShortDataEntryOpBuildersMappableVar(b, context, loc, + DataClause::acc_copyin); + testShortDataEntryOpBuildersMappableVar(b, context, loc, + DataClause::acc_create); + testShortDataEntryOpBuildersMappableVar( + b, context, loc, DataClause::acc_no_create); + testShortDataEntryOpBuildersMappableVar(b, context, loc, + DataClause::acc_attach); + testShortDataEntryOpBuildersMappableVar( + b, context, loc, DataClause::acc_getdeviceptr); + testShortDataEntryOpBuildersMappableVar( + b, context, loc, DataClause::acc_update_device); + testShortDataEntryOpBuildersMappableVar( + b, context, loc, DataClause::acc_use_device); + testShortDataEntryOpBuildersMappableVar( + b, context, loc, DataClause::acc_declare_device_resident); + testShortDataEntryOpBuildersMappableVar( + b, context, loc, DataClause::acc_declare_link); + testShortDataEntryOpBuildersMappableVar(b, context, loc, + DataClause::acc_cache); +}