Skip to content

Commit 53692ef

Browse files
committed
[mlir][operation] weak refs
1 parent 56a636f commit 53692ef

File tree

5 files changed

+183
-25
lines changed

5 files changed

+183
-25
lines changed

mlir/include/mlir/IR/MLIRContext.h

+5
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@ class DynamicDialect;
3131
class InFlightDiagnostic;
3232
class Location;
3333
class MLIRContextImpl;
34+
class Operation;
3435
class RegisteredOperationName;
3536
class StorageUniquer;
3637
class IRUnit;
38+
class WeakOpRef;
3739

3840
/// MLIRContext is the top-level object for a collection of MLIR operations. It
3941
/// holds immortal uniqued objects like types, and the tables used to unique
@@ -275,6 +277,9 @@ class MLIRContext {
275277
actionFn();
276278
}
277279

280+
WeakOpRef acquireWeakOpRef(Operation *op);
281+
void expireWeakRefs(Operation *op);
282+
278283
private:
279284
/// Return true if the given dialect is currently loading.
280285
bool isDialectLoading(StringRef dialectNamespace);

mlir/include/mlir/IR/Operation.h

+60-6
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,39 @@
2222
#include <optional>
2323

2424
namespace mlir {
25+
class WeakOpRef;
26+
class WeakOpRefHolder {
27+
private:
28+
mlir::Operation *op;
29+
30+
public:
31+
WeakOpRefHolder(mlir::Operation *op) : op(op) {}
32+
~WeakOpRefHolder();
33+
friend class WeakOpRef;
34+
};
35+
36+
class WeakOpRef {
37+
private:
38+
std::shared_ptr<WeakOpRefHolder> holder;
39+
40+
public:
41+
WeakOpRef(std::shared_ptr<WeakOpRefHolder> const &r);
42+
43+
WeakOpRef(WeakOpRef const &r);
44+
WeakOpRef(WeakOpRef &&r);
45+
~WeakOpRef();
46+
47+
WeakOpRef &operator=(WeakOpRef const &r);
48+
WeakOpRef &operator=(WeakOpRef &&r);
49+
50+
void swap(WeakOpRef &r);
51+
bool expired() const;
52+
long use_count() const { return holder ? holder.use_count() : 0; }
53+
54+
mlir::Operation *operator->() const;
55+
mlir::Operation &operator*() const;
56+
};
57+
2558
namespace detail {
2659
/// This is a "tag" used for mapping the properties storage in
2760
/// llvm::TrailingObjects.
@@ -210,7 +243,7 @@ class alignas(8) Operation final
210243
Operation *cloneWithoutRegions();
211244

212245
/// Returns the operation block that contains this operation.
213-
Block *getBlock() { return block; }
246+
Block *getBlock() { return blockHasWeakRefPair.getPointer(); }
214247

215248
/// Return the context this operation is associated with.
216249
MLIRContext *getContext() { return location->getContext(); }
@@ -227,11 +260,15 @@ class alignas(8) Operation final
227260

228261
/// Returns the region to which the instruction belongs. Returns nullptr if
229262
/// the instruction is unlinked.
230-
Region *getParentRegion() { return block ? block->getParent() : nullptr; }
263+
Region *getParentRegion() {
264+
return getBlock() ? getBlock()->getParent() : nullptr;
265+
}
231266

232267
/// Returns the closest surrounding operation that contains this operation
233268
/// or nullptr if this is a top-level operation.
234-
Operation *getParentOp() { return block ? block->getParentOp() : nullptr; }
269+
Operation *getParentOp() {
270+
return getBlock() ? getBlock()->getParentOp() : nullptr;
271+
}
235272

236273
/// Return the closest surrounding parent operation that is of type 'OpTy'.
237274
template <typename OpTy>
@@ -545,6 +582,7 @@ class alignas(8) Operation final
545582
AttrClass getAttrOfType(StringAttr name) {
546583
return llvm::dyn_cast_or_null<AttrClass>(getAttr(name));
547584
}
585+
548586
template <typename AttrClass>
549587
AttrClass getAttrOfType(StringRef name) {
550588
return llvm::dyn_cast_or_null<AttrClass>(getAttr(name));
@@ -559,13 +597,15 @@ class alignas(8) Operation final
559597
}
560598
return attrs.contains(name);
561599
}
600+
562601
bool hasAttr(StringRef name) {
563602
if (getPropertiesStorageSize()) {
564603
if (std::optional<Attribute> inherentAttr = getInherentAttr(name))
565604
return (bool)*inherentAttr;
566605
}
567606
return attrs.contains(name);
568607
}
608+
569609
template <typename AttrClass, typename NameT>
570610
bool hasAttrOfType(NameT &&name) {
571611
return static_cast<bool>(
@@ -585,6 +625,7 @@ class alignas(8) Operation final
585625
if (attributes.set(name, value) != value)
586626
attrs = attributes.getDictionary(getContext());
587627
}
628+
588629
void setAttr(StringRef name, Attribute value) {
589630
setAttr(StringAttr::get(getContext(), name), value);
590631
}
@@ -605,6 +646,7 @@ class alignas(8) Operation final
605646
attrs = attributes.getDictionary(getContext());
606647
return removedAttr;
607648
}
649+
608650
Attribute removeAttr(StringRef name) {
609651
return removeAttr(StringAttr::get(getContext(), name));
610652
}
@@ -626,6 +668,7 @@ class alignas(8) Operation final
626668
// Allow access to the constructor.
627669
friend Operation;
628670
};
671+
629672
using dialect_attr_range = iterator_range<dialect_attr_iterator>;
630673

631674
/// Return a range corresponding to the dialect attributes for this operation.
@@ -634,10 +677,12 @@ class alignas(8) Operation final
634677
return {dialect_attr_iterator(attrs.begin(), attrs.end()),
635678
dialect_attr_iterator(attrs.end(), attrs.end())};
636679
}
680+
637681
dialect_attr_iterator dialect_attr_begin() {
638682
auto attrs = getAttrs();
639683
return dialect_attr_iterator(attrs.begin(), attrs.end());
640684
}
685+
641686
dialect_attr_iterator dialect_attr_end() {
642687
auto attrs = getAttrs();
643688
return dialect_attr_iterator(attrs.end(), attrs.end());
@@ -705,6 +750,7 @@ class alignas(8) Operation final
705750
assert(index < getNumSuccessors());
706751
return getBlockOperands()[index].get();
707752
}
753+
708754
void setSuccessor(Block *block, unsigned index);
709755

710756
//===--------------------------------------------------------------------===//
@@ -892,12 +938,14 @@ class alignas(8) Operation final
892938
int getPropertiesStorageSize() const {
893939
return ((int)propertiesStorageSize) * 8;
894940
}
941+
895942
/// Returns the properties storage.
896943
OpaqueProperties getPropertiesStorage() {
897944
if (propertiesStorageSize)
898945
return getPropertiesStorageUnsafe();
899946
return {nullptr};
900947
}
948+
901949
OpaqueProperties getPropertiesStorage() const {
902950
if (propertiesStorageSize)
903951
return {reinterpret_cast<void *>(const_cast<detail::OpProperties *>(
@@ -933,6 +981,11 @@ class alignas(8) Operation final
933981
/// Compute a hash for the op properties (if any).
934982
llvm::hash_code hashProperties();
935983

984+
bool hasWeakReference() { return blockHasWeakRefPair.getInt(); }
985+
void setHasWeakReference(bool hasWeakRef) {
986+
blockHasWeakRefPair.setInt(hasWeakRef);
987+
}
988+
936989
private:
937990
//===--------------------------------------------------------------------===//
938991
// Ordering
@@ -1016,7 +1069,7 @@ class alignas(8) Operation final
10161069
/// requires a 'getParent() const' method. Once ilist_node removes this
10171070
/// constraint, we should drop the const to fit the rest of the MLIR const
10181071
/// model.
1019-
Block *getParent() const { return block; }
1072+
Block *getParent() const { return blockHasWeakRefPair.getPointer(); }
10201073

10211074
/// Expose a few methods explicitly for the debugger to call for
10221075
/// visualization.
@@ -1031,8 +1084,9 @@ class alignas(8) Operation final
10311084
}
10321085
#endif
10331086

1034-
/// The operation block that contains this operation.
1035-
Block *block = nullptr;
1087+
/// The operation block that contains this operation and a bit that signifies
1088+
/// if the operation has a weak reference.
1089+
llvm::PointerIntPair<Block *, /*IntBits=*/1, bool> blockHasWeakRefPair;
10361090

10371091
/// This holds information about the source location the operation was defined
10381092
/// or derived from.

mlir/lib/IR/MLIRContext.cpp

+39
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,9 @@ class MLIRContextImpl {
271271
/// destruction.
272272
DistinctAttributeAllocator distinctAttributeAllocator;
273273

274+
llvm::sys::SmartRWMutex<true> weakOperationRefsMutex;
275+
DenseMap<Operation *, std::weak_ptr<WeakOpRefHolder>> weakOperationReferences;
276+
274277
public:
275278
MLIRContextImpl(bool threadingIsEnabled)
276279
: threadingIsEnabled(threadingIsEnabled) {
@@ -393,6 +396,42 @@ void MLIRContext::executeActionInternal(function_ref<void()> actionFn,
393396

394397
bool MLIRContext::hasActionHandler() { return (bool)getImpl().actionHandler; }
395398

399+
WeakOpRef MLIRContext::acquireWeakOpRef(Operation *op) {
400+
{
401+
llvm::sys::SmartScopedReader<true> contextLock(
402+
impl->weakOperationRefsMutex);
403+
auto it = impl->weakOperationReferences.find(op);
404+
if (it != impl->weakOperationReferences.end()) {
405+
assert(op->hasWeakReference() &&
406+
"op should report having weak references");
407+
return {it->second.lock()};
408+
}
409+
}
410+
{
411+
ScopedWriterLock contextLock(impl->weakOperationRefsMutex,
412+
isMultithreadingEnabled());
413+
auto shared = std::make_shared<WeakOpRefHolder>(op);
414+
(void)impl->weakOperationReferences.insert({op, shared});
415+
op->setHasWeakReference(true);
416+
return {shared};
417+
}
418+
}
419+
420+
void MLIRContext::expireWeakRefs(Operation *op) {
421+
if (op && impl) {
422+
ScopedWriterLock lock(impl->weakOperationRefsMutex,
423+
isMultithreadingEnabled());
424+
if (auto it = impl->weakOperationReferences.find(op);
425+
it != impl->weakOperationReferences.end()) {
426+
if (!it->second.expired())
427+
it->second.reset();
428+
assert(it->second.expired() && "should be expired");
429+
impl->weakOperationReferences.erase(op);
430+
}
431+
op->setHasWeakReference(false);
432+
}
433+
}
434+
396435
//===----------------------------------------------------------------------===//
397436
// Diagnostic Handlers
398437
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)