Skip to content

Commit c00773a

Browse files
committed
[mlir][python] wip remove liveOpeartions
1 parent 1e9324a commit c00773a

File tree

11 files changed

+174
-220
lines changed

11 files changed

+174
-220
lines changed

mlir/include/mlir-c/IR.h

+2
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,8 @@ MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module);
323323
/// The returned module is null when the input operation was not a ModuleOp.
324324
MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op);
325325

326+
MLIR_CAPI_EXPORTED bool mlirModuleEqual(MlirModule mod, MlirModule other);
327+
326328
//===----------------------------------------------------------------------===//
327329
// Operation state.
328330
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRCore.cpp

+22-116
Original file line numberDiff line numberDiff line change
@@ -634,58 +634,6 @@ PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
634634

635635
size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
636636

637-
size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
638-
639-
std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
640-
std::vector<PyOperation *> liveObjects;
641-
for (auto &entry : liveOperations)
642-
liveObjects.push_back(entry.second.second);
643-
return liveObjects;
644-
}
645-
646-
size_t PyMlirContext::clearLiveOperations() {
647-
for (auto &op : liveOperations)
648-
op.second.second->setInvalid();
649-
size_t numInvalidated = liveOperations.size();
650-
liveOperations.clear();
651-
return numInvalidated;
652-
}
653-
654-
void PyMlirContext::clearOperation(MlirOperation op) {
655-
auto it = liveOperations.find(op.ptr);
656-
if (it != liveOperations.end()) {
657-
it->second.second->setInvalid();
658-
liveOperations.erase(it);
659-
}
660-
}
661-
662-
void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
663-
typedef struct {
664-
PyOperation &rootOp;
665-
bool rootSeen;
666-
} callBackData;
667-
callBackData data{op.getOperation(), false};
668-
// Mark all ops below the op that the passmanager will be rooted
669-
// at (but not op itself - note the preorder) as invalid.
670-
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
671-
void *userData) {
672-
callBackData *data = static_cast<callBackData *>(userData);
673-
if (LLVM_LIKELY(data->rootSeen))
674-
data->rootOp.getOperation().getContext()->clearOperation(op);
675-
else
676-
data->rootSeen = true;
677-
return MlirWalkResult::MlirWalkResultAdvance;
678-
};
679-
mlirOperationWalk(op.getOperation(), invalidatingCallback,
680-
static_cast<void *>(&data), MlirWalkPreOrder);
681-
}
682-
void PyMlirContext::clearOperationsInside(MlirOperation op) {
683-
PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
684-
clearOperationsInside(opRef->getOperation());
685-
}
686-
687-
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
688-
689637
pybind11::object PyMlirContext::contextEnter() {
690638
return PyThreadContextEntry::pushContext(*this);
691639
}
@@ -1055,39 +1003,21 @@ PyLocation &DefaultingPyLocation::resolve() {
10551003
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
10561004
: BaseContextObject(std::move(contextRef)), module(module) {}
10571005

1058-
PyModule::~PyModule() {
1059-
py::gil_scoped_acquire acquire;
1060-
auto &liveModules = getContext()->liveModules;
1061-
assert(liveModules.count(module.ptr) == 1 &&
1062-
"destroying module not in live map");
1063-
liveModules.erase(module.ptr);
1064-
mlirModuleDestroy(module);
1065-
}
1006+
PyModule::~PyModule() { mlirModuleDestroy(module); }
10661007

10671008
PyModuleRef PyModule::forModule(MlirModule module) {
10681009
MlirContext context = mlirModuleGetContext(module);
10691010
PyMlirContextRef contextRef = PyMlirContext::forContext(context);
10701011

1071-
py::gil_scoped_acquire acquire;
1072-
auto &liveModules = contextRef->liveModules;
1073-
auto it = liveModules.find(module.ptr);
1074-
if (it == liveModules.end()) {
1075-
// Create.
1076-
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1077-
// Note that the default return value policy on cast is automatic_reference,
1078-
// which does not take ownership (delete will not be called).
1079-
// Just be explicit.
1080-
py::object pyRef =
1081-
py::cast(unownedModule, py::return_value_policy::take_ownership);
1082-
unownedModule->handle = pyRef;
1083-
liveModules[module.ptr] =
1084-
std::make_pair(unownedModule->handle, unownedModule);
1085-
return PyModuleRef(unownedModule, std::move(pyRef));
1086-
}
1087-
// Use existing.
1088-
PyModule *existing = it->second.second;
1089-
py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
1090-
return PyModuleRef(existing, std::move(pyRef));
1012+
// Create.
1013+
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1014+
// Note that the default return value policy on cast is automatic_reference,
1015+
// which does not take ownership (delete will not be called).
1016+
// Just be explicit.
1017+
py::object pyRef =
1018+
py::cast(unownedModule, py::return_value_policy::take_ownership);
1019+
unownedModule->handle = pyRef;
1020+
return PyModuleRef(unownedModule, std::move(pyRef));
10911021
}
10921022

10931023
py::object PyModule::createFromCapsule(py::object capsule) {
@@ -1112,10 +1042,6 @@ PyOperation::~PyOperation() {
11121042
// If the operation has already been invalidated there is nothing to do.
11131043
if (!valid)
11141044
return;
1115-
auto &liveOperations = getContext()->liveOperations;
1116-
assert(liveOperations.count(operation.ptr) == 1 &&
1117-
"destroying operation not in live map");
1118-
liveOperations.erase(operation.ptr);
11191045
if (!isAttached()) {
11201046
mlirOperationDestroy(operation);
11211047
}
@@ -1124,7 +1050,6 @@ PyOperation::~PyOperation() {
11241050
PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
11251051
MlirOperation operation,
11261052
py::object parentKeepAlive) {
1127-
auto &liveOperations = contextRef->liveOperations;
11281053
// Create.
11291054
PyOperation *unownedOperation =
11301055
new PyOperation(std::move(contextRef), operation);
@@ -1137,34 +1062,20 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
11371062
if (parentKeepAlive) {
11381063
unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
11391064
}
1140-
liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
11411065
return PyOperationRef(unownedOperation, std::move(pyRef));
11421066
}
11431067

11441068
PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
11451069
MlirOperation operation,
11461070
py::object parentKeepAlive) {
1147-
auto &liveOperations = contextRef->liveOperations;
1148-
auto it = liveOperations.find(operation.ptr);
1149-
if (it == liveOperations.end()) {
1150-
// Create.
1151-
return createInstance(std::move(contextRef), operation,
1152-
std::move(parentKeepAlive));
1153-
}
1154-
// Use existing.
1155-
PyOperation *existing = it->second.second;
1156-
py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
1157-
return PyOperationRef(existing, std::move(pyRef));
1071+
// Create.
1072+
return createInstance(std::move(contextRef), operation,
1073+
std::move(parentKeepAlive));
11581074
}
11591075

11601076
PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
11611077
MlirOperation operation,
11621078
py::object parentKeepAlive) {
1163-
auto &liveOperations = contextRef->liveOperations;
1164-
assert(liveOperations.count(operation.ptr) == 0 &&
1165-
"cannot create detached operation that already exists");
1166-
(void)liveOperations;
1167-
11681079
PyOperationRef created = createInstance(std::move(contextRef), operation,
11691080
std::move(parentKeepAlive));
11701081
created->attached = false;
@@ -1530,9 +1441,6 @@ void PyOperation::erase() {
15301441
// TODO: Fix memory hazards when erasing a tree of operations for which a deep
15311442
// Python reference to a child operation is live. All children should also
15321443
// have their `valid` bit set to false.
1533-
auto &liveOperations = getContext()->liveOperations;
1534-
if (liveOperations.count(operation.ptr))
1535-
liveOperations.erase(operation.ptr);
15361444
mlirOperationDestroy(operation);
15371445
valid = false;
15381446
}
@@ -2274,7 +2182,6 @@ class PyBlockArgumentList
22742182
: public Sliceable<PyBlockArgumentList, PyBlockArgument> {
22752183
public:
22762184
static constexpr const char *pyClassName = "BlockArgumentList";
2277-
using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
22782185

22792186
PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
22802187
intptr_t startIndex = 0, intptr_t length = -1,
@@ -2598,14 +2505,6 @@ void mlir::python::populateIRCore(py::module &m) {
25982505
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
25992506
return ref.releaseObject();
26002507
})
2601-
.def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2602-
.def("_get_live_operation_objects",
2603-
&PyMlirContext::getLiveOperationObjects)
2604-
.def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2605-
.def("_clear_live_operations_inside",
2606-
py::overload_cast<MlirOperation>(
2607-
&PyMlirContext::clearOperationsInside))
2608-
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
26092508
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
26102509
&PyMlirContext::getCapsule)
26112510
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
@@ -2915,7 +2814,13 @@ void mlir::python::populateIRCore(py::module &m) {
29152814
// Defer to the operation's __str__.
29162815
return self.attr("operation").attr("__str__")();
29172816
},
2918-
kOperationStrDunderDocstring);
2817+
kOperationStrDunderDocstring)
2818+
.def(
2819+
"__eq__",
2820+
[](PyModule &self, PyModule &other) {
2821+
return mlirModuleEqual(self.get(), other.get());
2822+
},
2823+
"other"_a);
29192824

29202825
//----------------------------------------------------------------------------
29212826
// Mapping of Operation.
@@ -2927,7 +2832,8 @@ void mlir::python::populateIRCore(py::module &m) {
29272832
})
29282833
.def("__eq__",
29292834
[](PyOperationBase &self, PyOperationBase &other) {
2930-
return &self.getOperation() == &other.getOperation();
2835+
return mlirOperationEqual(self.getOperation().get(),
2836+
other.getOperation().get());
29312837
})
29322838
.def("__eq__",
29332839
[](PyOperationBase &self, py::object other) { return false; })

mlir/lib/Bindings/Python/IRModule.h

-44
Original file line numberDiff line numberDiff line change
@@ -201,34 +201,6 @@ class PyMlirContext {
201201
/// Gets the count of live context objects. Used for testing.
202202
static size_t getLiveCount();
203203

204-
/// Get a list of Python objects which are still in the live context map.
205-
std::vector<PyOperation *> getLiveOperationObjects();
206-
207-
/// Gets the count of live operations associated with this context.
208-
/// Used for testing.
209-
size_t getLiveOperationCount();
210-
211-
/// Clears the live operations map, returning the number of entries which were
212-
/// invalidated. To be used as a safety mechanism so that API end-users can't
213-
/// corrupt by holding references they shouldn't have accessed in the first
214-
/// place.
215-
size_t clearLiveOperations();
216-
217-
/// Removes an operation from the live operations map and sets it invalid.
218-
/// This is useful for when some non-bindings code destroys the operation and
219-
/// the bindings need to made aware. For example, in the case when pass
220-
/// manager is run.
221-
void clearOperation(MlirOperation op);
222-
223-
/// Clears all operations nested inside the given op using
224-
/// `clearOperation(MlirOperation)`.
225-
void clearOperationsInside(PyOperationBase &op);
226-
void clearOperationsInside(MlirOperation op);
227-
228-
/// Gets the count of live modules associated with this context.
229-
/// Used for testing.
230-
size_t getLiveModuleCount();
231-
232204
/// Enter and exit the context manager.
233205
pybind11::object contextEnter();
234206
void contextExit(const pybind11::object &excType,
@@ -255,22 +227,6 @@ class PyMlirContext {
255227
using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
256228
static LiveContextMap &getLiveContexts();
257229

258-
// Interns all live modules associated with this context. Modules tracked
259-
// in this map are valid. When a module is invalidated, it is removed
260-
// from this map, and while it still exists as an instance, any
261-
// attempt to access it will raise an error.
262-
using LiveModuleMap =
263-
llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>;
264-
LiveModuleMap liveModules;
265-
266-
// Interns all live operations associated with this context. Operations
267-
// tracked in this map are valid. When an operation is invalidated, it is
268-
// removed from this map, and while it still exists as an instance, any
269-
// attempt to access it will raise an error.
270-
using LiveOperationMap =
271-
llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>;
272-
LiveOperationMap liveOperations;
273-
274230
bool emitErrorDiagnostics = false;
275231

276232
MlirContext context;

mlir/lib/Bindings/Python/Pass.cpp

+2-6
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
117117
"ValueError if the pipeline can't be parsed.")
118118
.def(
119119
"run",
120-
[](PyPassManager &passManager, PyOperationBase &op,
121-
bool invalidateOps) {
122-
if (invalidateOps) {
123-
op.getOperation().getContext()->clearOperationsInside(op);
124-
}
120+
[](PyPassManager &passManager, PyOperationBase &op) {
125121
// Actually run the pass manager.
126122
PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
127123
MlirLogicalResult status = mlirPassManagerRunOnOp(
@@ -130,7 +126,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
130126
throw MLIRError("Failure while executing pass pipeline",
131127
errors.take());
132128
},
133-
"operation"_a, "invalidate_ops"_a = true,
129+
"operation"_a,
134130
"Run the pass manager on the provided operation, raising an "
135131
"MLIRError on failure.")
136132
.def(

mlir/lib/Bindings/Python/TransformInterpreter.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ static void populateTransformInterpreterSubmodule(py::module &m) {
6868
// root. This is awkward, but we don't have access to PyMlirContext
6969
// object here otherwise.
7070
py::object obj = py::cast(payloadRoot);
71-
obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);
7271

7372
MlirLogicalResult result = mlirTransformApplyNamedSequence(
7473
payloadRoot, transformRoot, transformModule, options.options);

mlir/lib/CAPI/IR/IR.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,10 @@ MlirModule mlirModuleFromOperation(MlirOperation op) {
332332
return wrap(dyn_cast<ModuleOp>(unwrap(op)));
333333
}
334334

335+
bool mlirModuleEqual(MlirModule mod, MlirModule other) {
336+
return unwrap(mod) == unwrap(other);
337+
}
338+
335339
//===----------------------------------------------------------------------===//
336340
// Operation state API.
337341
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)