Skip to content

[MLIR][LLVM] Import dereferenceable metadata from LLVM IR #130974

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -1267,4 +1267,28 @@ def WorkgroupAttributionAttr
let assemblyFormat = "`<` $num_elements `,` $element_type `>`";
}

//===----------------------------------------------------------------------===//
// DereferenceableAttr
//===----------------------------------------------------------------------===//

def LLVM_DereferenceableAttr : LLVM_Attr<"Dereferenceable", "dereferenceable"> {
let summary = "LLVM dereferenceable attribute";
let description = [{
Defines `dereferenceable` or `dereferenceable_or_null` metadata that can
be set via the `DereferenceableOpInterface` on an `inttoptr` operation or
on a `load` operation which loads a pointer. The attribute is used to
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we would add a check to the LoadOp verifier that checks the load returns a pointer if the attribute is set.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I've added an interface verifier which performs the check.

denote that the result of these operations is dereferenceable up to a
certain number of bytes, represented by `$bytes`. The optional `$mayBeNull`
parameter is set to true if the attribute defines `dereferenceable_or_null`
metadata.

See the following links for more details:
https://llvm.org/docs/LangRef.html#dereferenceable-metadata
https://llvm.org/docs/LangRef.html#dereferenceable-or-null-metadata
}];
let parameters = (ins "uint64_t":$bytes,
DefaultValuedParameter<"bool", "false">:$mayBeNull);
let assemblyFormat = "`<` struct(params) `>`";
}

#endif // LLVMIR_ATTRDEFS
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ LogicalResult verifyAccessGroupOpInterface(Operation *op);
/// the alias analysis interface.
LogicalResult verifyAliasAnalysisOpInterface(Operation *op);

/// Verifies that the operation implementing the dereferenceable interface has
/// exactly one result of LLVM pointer type.
LogicalResult verifyDereferenceableOpInterface(Operation *op);

} // namespace detail
} // namespace LLVM
} // namespace mlir
Expand Down
37 changes: 37 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,43 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
];
}

def DereferenceableOpInterface : OpInterface<"DereferenceableOpInterface"> {
let description = [{
An interface for memory operations that can carry dereferenceable metadata.
It provides setters and getters for the operation's dereferenceable
attributes. The default implementations of the interface methods expect
the operation to have an attribute of type DereferenceableAttr.
}];

let cppNamespace = "::mlir::LLVM";
let verify = [{ return detail::verifyDereferenceableOpInterface($_op); }];

let methods = [
InterfaceMethod<
/*desc=*/ "Returns the dereferenceable attribute or nullptr",
/*returnType=*/ "::mlir::LLVM::DereferenceableAttr",
/*methodName=*/ "getDereferenceableOrNull",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious - why is this called getDereferencableOrNull when it could be just getDereferenceable? IT's not symmetric with the setter

Copy link
Contributor Author

@mihailo-stojanovic mihailo-stojanovic Mar 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dereferenceable attribute is optional on both load and inttoptr ops, so querying the interface may return nullptr. The naming is in line with other "metadata interfaces".

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I vaguely remember this naming convention was chosen to avoid a name clash with a tablegen generated function. I could be wrong though.

/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getDereferenceableAttr();
}]
>,
InterfaceMethod<
/*desc=*/ "Sets the dereferenceable attribute",
/*returnType=*/ "void",
/*methodName=*/ "setDereferenceable",
/*args=*/ (ins "::mlir::LLVM::DereferenceableAttr":$attr),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
op.setDereferenceableAttr(attr);
}]
>
];
}

def FPExceptionBehaviorOpInterface : OpInterface<"FPExceptionBehaviorOpInterface"> {
let description = [{
An interface for operations receiving an exception behavior attribute
Expand Down
34 changes: 31 additions & 3 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
[DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>,
DeclareOpInterfaceMethods<DereferenceableOpInterface>]> {
dag args = (ins LLVM_AnyPointer:$addr,
OptionalAttr<I64Attr>:$alignment,
UnitAttr:$volatile_,
Expand All @@ -373,7 +374,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
UnitAttr:$invariantGroup,
DefaultValuedAttr<
AtomicOrdering, "AtomicOrdering::not_atomic">:$ordering,
OptionalAttr<StrAttr>:$syncscope);
OptionalAttr<StrAttr>:$syncscope,
OptionalAttr<LLVM_DereferenceableAttr>:$dereferenceable);
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
let arguments = !con(args, aliasAttrs);
let results = (outs LLVM_LoadableType:$res);
Expand Down Expand Up @@ -407,6 +409,7 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
(`atomic` (`syncscope` `(` $syncscope^ `)`)? $ordering^)?
(`invariant` $invariant^)?
(`invariant_group` $invariantGroup^)?
(`dereferenceable` `` $dereferenceable^)?
attr-dict `:` qualified(type($addr)) `->` type($res)
}];
string llvmBuilder = [{
Expand All @@ -416,6 +419,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
llvm::MDNode *metadata = llvm::MDNode::get(inst->getContext(), std::nullopt);
inst->setMetadata(llvm::LLVMContext::MD_invariant_load, metadata);
}
if ($dereferenceable)
moduleTranslation.setDereferenceableMetadata(op, inst);
}] # setOrderingCode
# setSyncScopeCode
# setAlignmentCode
Expand Down Expand Up @@ -571,6 +576,29 @@ class LLVM_CastOpWithOverflowFlag<string mnemonic, string instName, Type type,
}];
}

class LLVM_DereferenceableCastOp<string mnemonic, string instName, Type type,
Type resultType, list<Trait> traits = []> :
LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<DereferenceableOpInterface>], traits)> {
let arguments = (ins type:$arg, OptionalAttr<LLVM_DereferenceableAttr>:$dereferenceable);
let results = (outs resultType:$res);
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = "$arg (`dereferenceable` `` $dereferenceable^)? attr-dict `:` type($arg) `to` type($res)";
string llvmInstName = instName;
string llvmBuilder = [{
auto *val = builder.Create}] # instName # [{($arg, $_resultType);
$res = val;
if ($dereferenceable) {
llvm::Instruction *inst = dyn_cast<llvm::Instruction>(val);
moduleTranslation.setDereferenceableMetadata(op, inst);
}
}];
string mlirBuilder = [{
auto op = $_builder.create<$_qualCppClassName>(
$_location, $_resultType, $arg);
$res = op;
}];
}

def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "BitCast", LLVM_AnyNonAggregate,
LLVM_AnyNonAggregate, [DeclareOpInterfaceMethods<PromotableOpInterface>]> {
let hasFolder = 1;
Expand All @@ -583,7 +611,7 @@ def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "AddrSpaceCast",
DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
let hasFolder = 1;
}
def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "IntToPtr",
def LLVM_IntToPtrOp : LLVM_DereferenceableCastOp<"inttoptr", "IntToPtr",
LLVM_ScalarOrVectorOf<AnySignlessInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>>;
def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "PtrToInt",
Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,13 @@ class ModuleImport {
LoopAnnotationAttr translateLoopAnnotationAttr(const llvm::MDNode *node,
Location loc) const;

/// Returns the dereferenceable attribute that corresponds to the given LLVM
/// dereferenceable or dereferenceable_or_null metadata `node`. `kindID`
/// specifies the kind of the metadata node (dereferenceable or
/// dereferenceable_or_null).
FailureOr<DereferenceableAttr>
translateDereferenceableAttr(const llvm::MDNode *node, unsigned kindID);

/// Returns the alias scope attributes that map to the alias scope nodes
/// starting from the metadata `node`. Returns failure, if any of the
/// attributes cannot be found.
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ class ModuleTranslation {
/// Sets LLVM TBAA metadata for memory operations that have TBAA attributes.
void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst);

/// Sets LLVM dereferenceable metadata for operations that have
/// dereferenceable attributes.
void setDereferenceableMetadata(DereferenceableOpInterface op,
llvm::Instruction *inst);

/// Sets LLVM profiling metadata for operations that have branch weights.
void setBranchWeightsMetadata(BranchWeightOpInterface op);

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,7 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
isNonTemporal, isInvariant, isInvariantGroup, ordering,
syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
/*dereferenceable=*/nullptr,
/*access_groups=*/nullptr,
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
/*tbaa=*/nullptr);
Expand Down
17 changes: 17 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,23 @@ mlir::LLVM::detail::verifyAliasAnalysisOpInterface(Operation *op) {
return isArrayOf<TBAATagAttr>(op, tags);
}

//===----------------------------------------------------------------------===//
// DereferenceableOpInterface
//===----------------------------------------------------------------------===//

LogicalResult
mlir::LLVM::detail::verifyDereferenceableOpInterface(Operation *op) {
auto iface = cast<DereferenceableOpInterface>(op);

if (auto derefAttr = iface.getDereferenceableOrNull())
if (op->getNumResults() != 1 ||
!mlir::isa<LLVMPointerType>(op->getResult(0).getType()))
return op->emitOpError(
"expected op to return a single LLVM pointer type");

return success();
}

SmallVector<Value> mlir::LLVM::AtomicCmpXchgOp::getAccessedOperands() {
return {getPtr()};
}
Expand Down
28 changes: 28 additions & 0 deletions mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) {
llvm::LLVMContext::MD_loop,
llvm::LLVMContext::MD_noalias,
llvm::LLVMContext::MD_alias_scope,
llvm::LLVMContext::MD_dereferenceable,
llvm::LLVMContext::MD_dereferenceable_or_null,
context.getMDKindID(vecTypeHintMDName),
context.getMDKindID(workGroupSizeHintMDName),
context.getMDKindID(reqdWorkGroupSizeMDName),
Expand Down Expand Up @@ -188,6 +190,25 @@ static LogicalResult setAccessGroupsAttr(const llvm::MDNode *node,
return success();
}

/// Converts the given dereferenceable metadata node to a dereferenceable
/// attribute, and attaches it to the imported operation if the translation
/// succeeds. Returns failure if the LLVM IR metadata node is ill-formed.
static LogicalResult setDereferenceableAttr(const llvm::MDNode *node,
unsigned kindID, Operation *op,
LLVM::ModuleImport &moduleImport) {
auto dereferenceable =
moduleImport.translateDereferenceableAttr(node, kindID);
if (failed(dereferenceable))
return failure();

auto iface = dyn_cast<DereferenceableOpInterface>(op);
if (!iface)
return failure();

iface.setDereferenceable(*dereferenceable);
return success();
}

/// Converts the given loop metadata node to an MLIR loop annotation attribute
/// and attaches it to the imported operation if the translation succeeds.
/// Returns failure otherwise.
Expand Down Expand Up @@ -401,6 +422,13 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
return setAliasScopesAttr(node, op, moduleImport);
if (kind == llvm::LLVMContext::MD_noalias)
return setNoaliasScopesAttr(node, op, moduleImport);
if (kind == llvm::LLVMContext::MD_dereferenceable)
return setDereferenceableAttr(node, llvm::LLVMContext::MD_dereferenceable,
op, moduleImport);
if (kind == llvm::LLVMContext::MD_dereferenceable_or_null)
return setDereferenceableAttr(
node, llvm::LLVMContext::MD_dereferenceable_or_null, op,
moduleImport);

llvm::LLVMContext &context = node->getContext();
if (kind == context.getMDKindID(vecTypeHintMDName))
Expand Down
25 changes: 25 additions & 0 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2527,6 +2527,31 @@ ModuleImport::translateLoopAnnotationAttr(const llvm::MDNode *node,
return loopAnnotationImporter->translateLoopAnnotation(node, loc);
}

FailureOr<DereferenceableAttr>
ModuleImport::translateDereferenceableAttr(const llvm::MDNode *node,
unsigned kindID) {
Location loc = mlirModule.getLoc();

// The only operand should be a constant integer representing the number of
// dereferenceable bytes.
if (node->getNumOperands() != 1)
return emitError(loc) << "dereferenceable metadata must have one operand: "
<< diagMD(node, llvmModule.get());

auto *numBytesMD = dyn_cast<llvm::ConstantAsMetadata>(node->getOperand(0));
auto *numBytesCst = dyn_cast<llvm::ConstantInt>(numBytesMD->getValue());
if (!numBytesCst || !numBytesCst->getValue().isNonNegative())
return emitError(loc) << "dereferenceable metadata operand must be a "
"non-negative constant integer: "
<< diagMD(node, llvmModule.get());

bool mayBeNull = kindID == llvm::LLVMContext::MD_dereferenceable_or_null;
auto derefAttr = builder.getAttr<DereferenceableAttr>(
numBytesCst->getZExtValue(), mayBeNull);

return derefAttr;
}

OwningOpRef<ModuleOp>
mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
MLIRContext *context, bool emitExpensiveWarnings,
Expand Down
16 changes: 16 additions & 0 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1925,6 +1925,22 @@ void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op,
inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
}

void ModuleTranslation::setDereferenceableMetadata(
DereferenceableOpInterface op, llvm::Instruction *inst) {
DereferenceableAttr derefAttr = op.getDereferenceableOrNull();
if (!derefAttr)
return;

llvm::MDNode *derefSizeNode = llvm::MDNode::get(
getLLVMContext(),
llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
llvm::IntegerType::get(getLLVMContext(), 64), derefAttr.getBytes())));
unsigned kindId = derefAttr.getMayBeNull()
? llvm::LLVMContext::MD_dereferenceable_or_null
: llvm::LLVMContext::MD_dereferenceable;
inst->setMetadata(kindId, derefSizeNode);
}

void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull();
if (!weightsAttr)
Expand Down
7 changes: 7 additions & 0 deletions mlir/test/Dialect/LLVMIR/dereferenceable-invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// RUN: mlir-opt --allow-unregistered-dialect -split-input-file -verify-diagnostics %s

llvm.func @deref(%arg0: !llvm.ptr) {
// expected-error @below {{op expected op to return a single LLVM pointer type}}
%0 = llvm.load %arg0 dereferenceable<bytes = 8> {alignment = 8 : i64} : !llvm.ptr -> i64
llvm.return
}
11 changes: 11 additions & 0 deletions mlir/test/Target/LLVMIR/Import/import-failure.ll
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,17 @@ declare void @llvm.experimental.noalias.scope.decl(metadata)

; // -----

; CHECK: import-failure.ll
; CHECK-SAME: dereferenceable metadata operand must be a non-negative constant integer
define void @deref(i64 %0) {
%2 = inttoptr i64 %0 to ptr, !dereferenceable !0
ret void
}

!0 = !{i64 -4}

; // -----

; CHECK: import-failure.ll
; CHECK-SAME: warning: unhandled data layout token: ni:42
target datalayout = "e-ni:42-i64:64"
Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Target/LLVMIR/Import/metadata-dereferenceable.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s

define void @deref(i64 %0, ptr %1) {
; CHECK: llvm.inttoptr
; CHECK-SAME: dereferenceable<bytes = 4>
%3 = inttoptr i64 %0 to ptr, !dereferenceable !0
; CHECK: llvm.load
; CHECK-SAME: dereferenceable<bytes = 8>
%4 = load ptr, ptr %1, align 8, !dereferenceable !1
ret void
}

define void @deref_or_null(i64 %0, ptr %1) {
; CHECK: llvm.inttoptr
; CHECK-SAME: dereferenceable<bytes = 4, mayBeNull = true>
%3 = inttoptr i64 %0 to ptr, !dereferenceable_or_null !0
; CHECK: llvm.load
; CHECK-SAME: dereferenceable<bytes = 8, mayBeNull = true>
%4 = load ptr, ptr %1, align 8, !dereferenceable_or_null !1
ret void
}

!0 = !{i64 4}
!1 = !{i64 8}
24 changes: 24 additions & 0 deletions mlir/test/Target/LLVMIR/attribute-dereferenceable.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s

llvm.func @deref(%arg0: i64, %arg1: !llvm.ptr) {
// CHECK: inttoptr {{.*}} !dereferenceable [[D0:![0-9]+]]
%0 = llvm.inttoptr %arg0 dereferenceable<bytes = 4> : i64 to !llvm.ptr
%1 = llvm.load %0 {alignment = 4 : i64} : !llvm.ptr -> i32
// CHECK: load {{.*}} !dereferenceable [[D1:![0-9]+]]
%2 = llvm.load %arg1 dereferenceable<bytes = 8> {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
llvm.store %1, %2 {alignment = 4 : i64} : i32, !llvm.ptr
llvm.return
}

llvm.func @deref_or_null(%arg0: i64, %arg1: !llvm.ptr) {
// CHECK: inttoptr {{.*}} !dereferenceable_or_null [[D0]]
%0 = llvm.inttoptr %arg0 dereferenceable<bytes = 4, mayBeNull = true> : i64 to !llvm.ptr
%1 = llvm.load %0 {alignment = 4 : i64} : !llvm.ptr -> i32
// CHECK: load {{.*}} !dereferenceable_or_null [[D1]]
%2 = llvm.load %arg1 dereferenceable<bytes = 8, mayBeNull = true> {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
llvm.store %1, %2 {alignment = 4 : i64} : i32, !llvm.ptr
llvm.return
}

// CHECK: [[D0]] = !{i64 4}
// CHECK: [[D1]] = !{i64 8}