Skip to content

[mlir][operation] weak refs #97340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/include/mlir/IR/MLIRContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ class DynamicDialect;
class InFlightDiagnostic;
class Location;
class MLIRContextImpl;
class Operation;
class RegisteredOperationName;
class StorageUniquer;
class IRUnit;
class WeakOpRef;

/// MLIRContext is the top-level object for a collection of MLIR operations. It
/// holds immortal uniqued objects like types, and the tables used to unique
Expand Down Expand Up @@ -275,6 +277,9 @@ class MLIRContext {
actionFn();
}

WeakOpRef acquireWeakOpRef(Operation *op);
void expireWeakRefs(Operation *op);

private:
/// Return true if the given dialect is currently loading.
bool isDialectLoading(StringRef dialectNamespace);
Expand Down
66 changes: 60 additions & 6 deletions mlir/include/mlir/IR/Operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,39 @@
#include <optional>

namespace mlir {
class WeakOpRef;
class WeakOpRefHolder {
private:
mlir::Operation *op;

public:
WeakOpRefHolder(mlir::Operation *op) : op(op) {}
~WeakOpRefHolder();
friend class WeakOpRef;
};

class WeakOpRef {
private:
std::shared_ptr<WeakOpRefHolder> holder;

public:
WeakOpRef(std::shared_ptr<WeakOpRefHolder> const &r);

WeakOpRef(WeakOpRef const &r);
WeakOpRef(WeakOpRef &&r);
~WeakOpRef();

WeakOpRef &operator=(WeakOpRef const &r);
WeakOpRef &operator=(WeakOpRef &&r);

void swap(WeakOpRef &r);
bool expired() const;
long use_count() const { return holder ? holder.use_count() : 0; }

mlir::Operation *operator->() const;
mlir::Operation &operator*() const;
};

namespace detail {
/// This is a "tag" used for mapping the properties storage in
/// llvm::TrailingObjects.
Expand Down Expand Up @@ -210,7 +243,7 @@ class alignas(8) Operation final
Operation *cloneWithoutRegions();

/// Returns the operation block that contains this operation.
Block *getBlock() { return block; }
Block *getBlock() { return blockHasWeakRefPair.getPointer(); }

/// Return the context this operation is associated with.
MLIRContext *getContext() { return location->getContext(); }
Expand All @@ -227,11 +260,15 @@ class alignas(8) Operation final

/// Returns the region to which the instruction belongs. Returns nullptr if
/// the instruction is unlinked.
Region *getParentRegion() { return block ? block->getParent() : nullptr; }
Region *getParentRegion() {
return getBlock() ? getBlock()->getParent() : nullptr;
}

/// Returns the closest surrounding operation that contains this operation
/// or nullptr if this is a top-level operation.
Operation *getParentOp() { return block ? block->getParentOp() : nullptr; }
Operation *getParentOp() {
return getBlock() ? getBlock()->getParentOp() : nullptr;
}

/// Return the closest surrounding parent operation that is of type 'OpTy'.
template <typename OpTy>
Expand Down Expand Up @@ -545,6 +582,7 @@ class alignas(8) Operation final
AttrClass getAttrOfType(StringAttr name) {
return llvm::dyn_cast_or_null<AttrClass>(getAttr(name));
}

template <typename AttrClass>
AttrClass getAttrOfType(StringRef name) {
return llvm::dyn_cast_or_null<AttrClass>(getAttr(name));
Expand All @@ -559,13 +597,15 @@ class alignas(8) Operation final
}
return attrs.contains(name);
}

bool hasAttr(StringRef name) {
if (getPropertiesStorageSize()) {
if (std::optional<Attribute> inherentAttr = getInherentAttr(name))
return (bool)*inherentAttr;
}
return attrs.contains(name);
}

template <typename AttrClass, typename NameT>
bool hasAttrOfType(NameT &&name) {
return static_cast<bool>(
Expand All @@ -585,6 +625,7 @@ class alignas(8) Operation final
if (attributes.set(name, value) != value)
attrs = attributes.getDictionary(getContext());
}

void setAttr(StringRef name, Attribute value) {
setAttr(StringAttr::get(getContext(), name), value);
}
Expand All @@ -605,6 +646,7 @@ class alignas(8) Operation final
attrs = attributes.getDictionary(getContext());
return removedAttr;
}

Attribute removeAttr(StringRef name) {
return removeAttr(StringAttr::get(getContext(), name));
}
Expand All @@ -626,6 +668,7 @@ class alignas(8) Operation final
// Allow access to the constructor.
friend Operation;
};

using dialect_attr_range = iterator_range<dialect_attr_iterator>;

/// Return a range corresponding to the dialect attributes for this operation.
Expand All @@ -634,10 +677,12 @@ class alignas(8) Operation final
return {dialect_attr_iterator(attrs.begin(), attrs.end()),
dialect_attr_iterator(attrs.end(), attrs.end())};
}

dialect_attr_iterator dialect_attr_begin() {
auto attrs = getAttrs();
return dialect_attr_iterator(attrs.begin(), attrs.end());
}

dialect_attr_iterator dialect_attr_end() {
auto attrs = getAttrs();
return dialect_attr_iterator(attrs.end(), attrs.end());
Expand Down Expand Up @@ -705,6 +750,7 @@ class alignas(8) Operation final
assert(index < getNumSuccessors());
return getBlockOperands()[index].get();
}

void setSuccessor(Block *block, unsigned index);

//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -892,12 +938,14 @@ class alignas(8) Operation final
int getPropertiesStorageSize() const {
return ((int)propertiesStorageSize) * 8;
}

/// Returns the properties storage.
OpaqueProperties getPropertiesStorage() {
if (propertiesStorageSize)
return getPropertiesStorageUnsafe();
return {nullptr};
}

OpaqueProperties getPropertiesStorage() const {
if (propertiesStorageSize)
return {reinterpret_cast<void *>(const_cast<detail::OpProperties *>(
Expand Down Expand Up @@ -933,6 +981,11 @@ class alignas(8) Operation final
/// Compute a hash for the op properties (if any).
llvm::hash_code hashProperties();

bool hasWeakReference() { return blockHasWeakRefPair.getInt(); }
void setHasWeakReference(bool hasWeakRef) {
blockHasWeakRefPair.setInt(hasWeakRef);
}

private:
//===--------------------------------------------------------------------===//
// Ordering
Expand Down Expand Up @@ -1016,7 +1069,7 @@ class alignas(8) Operation final
/// requires a 'getParent() const' method. Once ilist_node removes this
/// constraint, we should drop the const to fit the rest of the MLIR const
/// model.
Block *getParent() const { return block; }
Block *getParent() const { return blockHasWeakRefPair.getPointer(); }

/// Expose a few methods explicitly for the debugger to call for
/// visualization.
Expand All @@ -1031,8 +1084,9 @@ class alignas(8) Operation final
}
#endif

/// The operation block that contains this operation.
Block *block = nullptr;
/// The operation block that contains this operation and a bit that signifies
/// if the operation has a weak reference.
llvm::PointerIntPair<Block *, /*IntBits=*/1, bool> blockHasWeakRefPair;

/// This holds information about the source location the operation was defined
/// or derived from.
Expand Down
39 changes: 39 additions & 0 deletions mlir/lib/IR/MLIRContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ class MLIRContextImpl {
/// destruction.
DistinctAttributeAllocator distinctAttributeAllocator;

llvm::sys::SmartRWMutex<true> weakOperationRefsMutex;
DenseMap<Operation *, std::weak_ptr<WeakOpRefHolder>> weakOperationReferences;

public:
MLIRContextImpl(bool threadingIsEnabled)
: threadingIsEnabled(threadingIsEnabled) {
Expand Down Expand Up @@ -393,6 +396,42 @@ void MLIRContext::executeActionInternal(function_ref<void()> actionFn,

bool MLIRContext::hasActionHandler() { return (bool)getImpl().actionHandler; }

WeakOpRef MLIRContext::acquireWeakOpRef(Operation *op) {
{
llvm::sys::SmartScopedReader<true> contextLock(
impl->weakOperationRefsMutex);
auto it = impl->weakOperationReferences.find(op);
if (it != impl->weakOperationReferences.end()) {
assert(op->hasWeakReference() &&
"op should report having weak references");
return {it->second.lock()};
}
}
{
ScopedWriterLock contextLock(impl->weakOperationRefsMutex,
isMultithreadingEnabled());
auto shared = std::make_shared<WeakOpRefHolder>(op);
(void)impl->weakOperationReferences.insert({op, shared});
op->setHasWeakReference(true);
return {shared};
}
}

void MLIRContext::expireWeakRefs(Operation *op) {
if (op && impl) {
ScopedWriterLock lock(impl->weakOperationRefsMutex,
isMultithreadingEnabled());
if (auto it = impl->weakOperationReferences.find(op);
it != impl->weakOperationReferences.end()) {
if (!it->second.expired())
it->second.reset();
assert(it->second.expired() && "should be expired");
impl->weakOperationReferences.erase(op);
}
op->setHasWeakReference(false);
}
}

//===----------------------------------------------------------------------===//
// Diagnostic Handlers
//===----------------------------------------------------------------------===//
Expand Down
Loading