diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h index 11e5329f43e68..f05f34870764f 100644 --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -31,9 +31,11 @@ class DynamicDialect; class InFlightDiagnostic; class Location; class MLIRContextImpl; +class Operation; class RegisteredOperationName; class StorageUniquer; class IRUnit; +class WeakOpRef; /// MLIRContext is the top-level object for a collection of MLIR operations. It /// holds immortal uniqued objects like types, and the tables used to unique @@ -275,6 +277,9 @@ class MLIRContext { actionFn(); } + WeakOpRef acquireWeakOpRef(Operation *op); + void expireWeakRefs(Operation *op); + private: /// Return true if the given dialect is currently loading. bool isDialectLoading(StringRef dialectNamespace); diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index f0dd7c5178056..319ae5cfc018e 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -22,6 +22,39 @@ #include namespace mlir { +class WeakOpRef; +class WeakOpRefHolder { +private: + mlir::Operation *op; + +public: + WeakOpRefHolder(mlir::Operation *op) : op(op) {} + ~WeakOpRefHolder(); + friend class WeakOpRef; +}; + +class WeakOpRef { +private: + std::shared_ptr holder; + +public: + WeakOpRef(std::shared_ptr const &r); + + WeakOpRef(WeakOpRef const &r); + WeakOpRef(WeakOpRef &&r); + ~WeakOpRef(); + + WeakOpRef &operator=(WeakOpRef const &r); + WeakOpRef &operator=(WeakOpRef &&r); + + void swap(WeakOpRef &r); + bool expired() const; + long use_count() const { return holder ? holder.use_count() : 0; } + + mlir::Operation *operator->() const; + mlir::Operation &operator*() const; +}; + namespace detail { /// This is a "tag" used for mapping the properties storage in /// llvm::TrailingObjects. @@ -210,7 +243,7 @@ class alignas(8) Operation final Operation *cloneWithoutRegions(); /// Returns the operation block that contains this operation. - Block *getBlock() { return block; } + Block *getBlock() { return blockHasWeakRefPair.getPointer(); } /// Return the context this operation is associated with. MLIRContext *getContext() { return location->getContext(); } @@ -227,11 +260,15 @@ class alignas(8) Operation final /// Returns the region to which the instruction belongs. Returns nullptr if /// the instruction is unlinked. - Region *getParentRegion() { return block ? block->getParent() : nullptr; } + Region *getParentRegion() { + return getBlock() ? getBlock()->getParent() : nullptr; + } /// Returns the closest surrounding operation that contains this operation /// or nullptr if this is a top-level operation. - Operation *getParentOp() { return block ? block->getParentOp() : nullptr; } + Operation *getParentOp() { + return getBlock() ? getBlock()->getParentOp() : nullptr; + } /// Return the closest surrounding parent operation that is of type 'OpTy'. template @@ -545,6 +582,7 @@ class alignas(8) Operation final AttrClass getAttrOfType(StringAttr name) { return llvm::dyn_cast_or_null(getAttr(name)); } + template AttrClass getAttrOfType(StringRef name) { return llvm::dyn_cast_or_null(getAttr(name)); @@ -559,6 +597,7 @@ class alignas(8) Operation final } return attrs.contains(name); } + bool hasAttr(StringRef name) { if (getPropertiesStorageSize()) { if (std::optional inherentAttr = getInherentAttr(name)) @@ -566,6 +605,7 @@ class alignas(8) Operation final } return attrs.contains(name); } + template bool hasAttrOfType(NameT &&name) { return static_cast( @@ -585,6 +625,7 @@ class alignas(8) Operation final if (attributes.set(name, value) != value) attrs = attributes.getDictionary(getContext()); } + void setAttr(StringRef name, Attribute value) { setAttr(StringAttr::get(getContext(), name), value); } @@ -605,6 +646,7 @@ class alignas(8) Operation final attrs = attributes.getDictionary(getContext()); return removedAttr; } + Attribute removeAttr(StringRef name) { return removeAttr(StringAttr::get(getContext(), name)); } @@ -626,6 +668,7 @@ class alignas(8) Operation final // Allow access to the constructor. friend Operation; }; + using dialect_attr_range = iterator_range; /// Return a range corresponding to the dialect attributes for this operation. @@ -634,10 +677,12 @@ class alignas(8) Operation final return {dialect_attr_iterator(attrs.begin(), attrs.end()), dialect_attr_iterator(attrs.end(), attrs.end())}; } + dialect_attr_iterator dialect_attr_begin() { auto attrs = getAttrs(); return dialect_attr_iterator(attrs.begin(), attrs.end()); } + dialect_attr_iterator dialect_attr_end() { auto attrs = getAttrs(); return dialect_attr_iterator(attrs.end(), attrs.end()); @@ -705,6 +750,7 @@ class alignas(8) Operation final assert(index < getNumSuccessors()); return getBlockOperands()[index].get(); } + void setSuccessor(Block *block, unsigned index); //===--------------------------------------------------------------------===// @@ -892,12 +938,14 @@ class alignas(8) Operation final int getPropertiesStorageSize() const { return ((int)propertiesStorageSize) * 8; } + /// Returns the properties storage. OpaqueProperties getPropertiesStorage() { if (propertiesStorageSize) return getPropertiesStorageUnsafe(); return {nullptr}; } + OpaqueProperties getPropertiesStorage() const { if (propertiesStorageSize) return {reinterpret_cast(const_cast( @@ -933,6 +981,11 @@ class alignas(8) Operation final /// Compute a hash for the op properties (if any). llvm::hash_code hashProperties(); + bool hasWeakReference() { return blockHasWeakRefPair.getInt(); } + void setHasWeakReference(bool hasWeakRef) { + blockHasWeakRefPair.setInt(hasWeakRef); + } + private: //===--------------------------------------------------------------------===// // Ordering @@ -1016,7 +1069,7 @@ class alignas(8) Operation final /// requires a 'getParent() const' method. Once ilist_node removes this /// constraint, we should drop the const to fit the rest of the MLIR const /// model. - Block *getParent() const { return block; } + Block *getParent() const { return blockHasWeakRefPair.getPointer(); } /// Expose a few methods explicitly for the debugger to call for /// visualization. @@ -1031,8 +1084,9 @@ class alignas(8) Operation final } #endif - /// The operation block that contains this operation. - Block *block = nullptr; + /// The operation block that contains this operation and a bit that signifies + /// if the operation has a weak reference. + llvm::PointerIntPair blockHasWeakRefPair; /// This holds information about the source location the operation was defined /// or derived from. diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 214b354c5347e..bd466a2cde990 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -271,6 +271,9 @@ class MLIRContextImpl { /// destruction. DistinctAttributeAllocator distinctAttributeAllocator; + llvm::sys::SmartRWMutex weakOperationRefsMutex; + DenseMap> weakOperationReferences; + public: MLIRContextImpl(bool threadingIsEnabled) : threadingIsEnabled(threadingIsEnabled) { @@ -393,6 +396,42 @@ void MLIRContext::executeActionInternal(function_ref actionFn, bool MLIRContext::hasActionHandler() { return (bool)getImpl().actionHandler; } +WeakOpRef MLIRContext::acquireWeakOpRef(Operation *op) { + { + llvm::sys::SmartScopedReader contextLock( + impl->weakOperationRefsMutex); + auto it = impl->weakOperationReferences.find(op); + if (it != impl->weakOperationReferences.end()) { + assert(op->hasWeakReference() && + "op should report having weak references"); + return {it->second.lock()}; + } + } + { + ScopedWriterLock contextLock(impl->weakOperationRefsMutex, + isMultithreadingEnabled()); + auto shared = std::make_shared(op); + (void)impl->weakOperationReferences.insert({op, shared}); + op->setHasWeakReference(true); + return {shared}; + } +} + +void MLIRContext::expireWeakRefs(Operation *op) { + if (op && impl) { + ScopedWriterLock lock(impl->weakOperationRefsMutex, + isMultithreadingEnabled()); + if (auto it = impl->weakOperationReferences.find(op); + it != impl->weakOperationReferences.end()) { + if (!it->second.expired()) + it->second.reset(); + assert(it->second.expired() && "should be expired"); + impl->weakOperationReferences.erase(op); + } + op->setHasWeakReference(false); + } +} + //===----------------------------------------------------------------------===// // Diagnostic Handlers //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index b51357198b1ca..dbbd32cffc2bb 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -26,6 +26,42 @@ using namespace mlir; +WeakOpRef::WeakOpRef(const std::shared_ptr &r) : holder(r) {} + +// copy constructor +WeakOpRef::WeakOpRef(const WeakOpRef &r) : holder(r.holder) {} + +// move constructor +WeakOpRef::WeakOpRef(WeakOpRef &&r) : holder(r.holder) { r.holder = nullptr; } + +WeakOpRef::~WeakOpRef() {} + +// copy assignment +WeakOpRef &WeakOpRef::operator=(const WeakOpRef &r) { + WeakOpRef(r).swap(*this); + return *this; +} + +// move assignment +WeakOpRef &WeakOpRef::operator=(WeakOpRef &&r) { + WeakOpRef(std::move(r)).swap(*this); + return *this; +} + +void WeakOpRef::swap(WeakOpRef &r) { std::swap(holder, r.holder); } + +void swap(WeakOpRef &x, WeakOpRef &y) { x.swap(y); } + +Operation *WeakOpRef::operator->() const { return this->holder->op; } + +Operation &WeakOpRef::operator*() const { return *this->holder->op; } + +bool WeakOpRef::expired() const { + return !bool(holder) || holder.use_count() == 0; +} + +WeakOpRefHolder::~WeakOpRefHolder() { op->getContext()->expireWeakRefs(op); } + //===----------------------------------------------------------------------===// // Operation //===----------------------------------------------------------------------===// @@ -177,7 +213,7 @@ Operation::Operation(Location location, OperationName name, unsigned numResults, // Operations are deleted through the destroy() member because they are // allocated via malloc. Operation::~Operation() { - assert(block == nullptr && "operation destroyed but still in a block"); + assert(getBlock() == nullptr && "operation destroyed but still in a block"); #ifndef NDEBUG if (!use_empty()) { { @@ -202,6 +238,9 @@ Operation::~Operation() { region.~Region(); if (propertiesStorageSize) name.destroyOpProperties(getPropertiesStorage()); + + if (hasWeakReference()) + getContext()->expireWeakRefs(this); } /// Destroy this operation or one of its subclasses. @@ -322,8 +361,8 @@ void Operation::setAttrs(DictionaryAttr newAttrs) { } void Operation::setAttrs(ArrayRef newAttrs) { if (getPropertiesStorageSize()) { - // We're spliting the providing array of attributes by removing the inherentAttr - // which will be stored in the properties. + // We're spliting the providing array of attributes by removing the + // inherentAttr which will be stored in the properties. SmallVector discardableAttrs; discardableAttrs.reserve(newAttrs.size()); for (NamedAttribute attr : newAttrs) { @@ -384,13 +423,13 @@ constexpr unsigned Operation::kOrderStride; /// Note: This function has an average complexity of O(1), but worst case may /// take O(N) where N is the number of operations within the parent block. bool Operation::isBeforeInBlock(Operation *other) { - assert(block && "Operations without parent blocks have no order."); - assert(other && other->block == block && + assert(getBlock() && "Operations without parent blocks have no order."); + assert(other && other->getBlock() == getBlock() && "Expected other operation to have the same parent block."); // If the order of the block is already invalid, directly recompute the // parent. - if (!block->isOpOrderValid()) { - block->recomputeOpOrder(); + if (!getBlock()->isOpOrderValid()) { + getBlock()->recomputeOpOrder(); } else { // Update the order either operation if necessary. updateOrderIfNecessary(); @@ -403,13 +442,13 @@ bool Operation::isBeforeInBlock(Operation *other) { /// Update the order index of this operation of this operation if necessary, /// potentially recomputing the order of the parent block. void Operation::updateOrderIfNecessary() { - assert(block && "expected valid parent"); + assert(getBlock() && "expected valid parent"); // If the order is valid for this operation there is nothing to do. if (hasValidOrder()) return; - Operation *blockFront = &block->front(); - Operation *blockBack = &block->back(); + Operation *blockFront = &getBlock()->front(); + Operation *blockBack = &getBlock()->back(); // This method is expected to only be invoked on blocks with more than one // operation. @@ -419,7 +458,7 @@ void Operation::updateOrderIfNecessary() { if (this == blockBack) { Operation *prevNode = getPrevNode(); if (!prevNode->hasValidOrder()) - return block->recomputeOpOrder(); + return getBlock()->recomputeOpOrder(); // Add the stride to the previous operation. orderIndex = prevNode->orderIndex + kOrderStride; @@ -431,10 +470,10 @@ void Operation::updateOrderIfNecessary() { if (this == blockFront) { Operation *nextNode = getNextNode(); if (!nextNode->hasValidOrder()) - return block->recomputeOpOrder(); + return getBlock()->recomputeOpOrder(); // There is no order to give this operation. if (nextNode->orderIndex == 0) - return block->recomputeOpOrder(); + return getBlock()->recomputeOpOrder(); // If we can't use the stride, just take the middle value left. This is safe // because we know there is at least one valid index to assign to. @@ -449,12 +488,12 @@ void Operation::updateOrderIfNecessary() { // the middle of the previous and next if possible. Operation *prevNode = getPrevNode(), *nextNode = getNextNode(); if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder()) - return block->recomputeOpOrder(); + return getBlock()->recomputeOpOrder(); unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex; // Check to see if there is a valid order between the two. if (prevOrder + 1 == nextOrder) - return block->recomputeOpOrder(); + return getBlock()->recomputeOpOrder(); orderIndex = prevOrder + ((nextOrder - prevOrder) / 2); } @@ -502,7 +541,7 @@ Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() { /// keep the block pointer up to date. void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { assert(!op->getBlock() && "already in an operation block!"); - op->block = getContainingBlock(); + op->blockHasWeakRefPair.setPointer(getContainingBlock()); // Invalidate the order on the operation. op->orderIndex = Operation::kInvalidOrderIdx; @@ -511,8 +550,8 @@ void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { /// This is a trait method invoked when an operation is removed from a block. /// We keep the block pointer up to date. void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { - assert(op->block && "not already in an operation block!"); - op->block = nullptr; + assert(op->getBlock() && "not already in an operation block!"); + op->blockHasWeakRefPair.setPointer(nullptr); } /// This is a trait method invoked when an operation is moved from one block @@ -531,7 +570,7 @@ void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( // Update the 'block' member of each operation. for (; first != last; ++first) - first->block = curParent; + first->blockHasWeakRefPair.setPointer(curParent); } /// Remove this operation (and its descendants) from its Block and delete diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp index f94dc78445807..e5b48851cc48e 100644 --- a/mlir/unittests/IR/OperationSupportTest.cpp +++ b/mlir/unittests/IR/OperationSupportTest.cpp @@ -313,4 +313,25 @@ TEST(OperationEquivalenceTest, HashWorksWithFlags) { op2->destroy(); } +TEST(WeakOpRefTest, Test1) { + MLIRContext context; + context.getOrLoadDialect(); + + auto *op1 = createOp(&context); + EXPECT_EQ(op1->hasWeakReference(), false); + { + WeakOpRef weakRef1 = context.acquireWeakOpRef(op1); + EXPECT_EQ(weakRef1.use_count(), 1); + EXPECT_EQ(op1->hasWeakReference(), true); + { + WeakOpRef weakRef2 = context.acquireWeakOpRef(op1); + EXPECT_EQ(weakRef2.use_count(), 2); + EXPECT_EQ(op1->hasWeakReference(), true); + } + EXPECT_EQ(weakRef1.use_count(), 1); + EXPECT_EQ(op1->hasWeakReference(), true); + } + EXPECT_EQ(op1->hasWeakReference(), false); +} + } // namespace