Added free-threading CPython mode support in MLIR Python bindings #107103

Merged
merged 1 commit into from
Jan 12, 2025
21 changes: 20 additions & 1 deletion mlir/cmake/modules/AddMLIRPython.cmake
@@ -668,12 +668,31 @@ function(add_mlir_python_extension libname extname)
elseif(ARG_PYTHON_BINDINGS_LIBRARY STREQUAL "nanobind")
nanobind_add_module(${libname}
NB_DOMAIN mlir
FREE_THREADED
${ARG_SOURCES}
)

if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
# Avoids warnings from upstream nanobind.
target_compile_options(nanobind-static
set(nanobind_target "nanobind-static")
if (NOT TARGET ${nanobind_target})
# Get correct nanobind target name: nanobind-static-ft or something else
# It is set by nanobind_add_module function according to the passed options
get_property(all_targets DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY BUILDSYSTEM_TARGETS)
Comment on lines +677 to +681
Member

I'm seeing issues with this code when building in the downstream https://github.com/iree-org/iree project, using Python 3.13t. In my project, the add_mlir_python_extension function here is running from a different CMake source directory that does not include any nanobind targets, as they are defined in other source files:

Logs with some extra debugging messages:

  --   Building iree-dialects project at /home/nod/dev/projects/iree/llvm-external-projects/iree-dialects (into /home/nod/dev/projects/iree/compiler/build/b/llvm-external-projects/mlir-iree-dialects)
  --   libname IREEDialectsPythonModules.extension._mlirRegisterEverything.dso using 'nanobind'
  --   all_targets in /home/nod/dev/projects/iree/llvm-external-projects/iree-dialects/python: IREEDialectsPythonSources;IREEDialectsPythonExtensions;IREEDialectsPythonSources.Dialects;IREEDialectsPythonSources.Dialects.iree_input;IREEDialectsPythonSources.Dialects.iree_input.tablegen;IREEDialectsPythonSources.Dialects.iree_input.ops_gen;IREEDialectsPythonSources.Dialects.iree_structured_transform;IREEDialectsPythonSources.Dialects.iree_structured_transform.tablegen;IREEDialectsPythonSources.Dialects.iree_structured_transform.ops_gen;IREEDialectsPythonExtensions.Main;IREEDialectsAggregateCAPI;IREEDialectsAggregateCAPI.sources;IREEDialectsPythonModules;IREEDialectsPythonModules.extension._mlirRegisterEverything.dso
  CMake Error at compiler/build/b/lib/cmake/mlir/AddMLIRPython.cmake:696 (message):
    Could not find nanobind target to set compile options to
  Call Stack (most recent call first):
    compiler/build/b/lib/cmake/mlir/AddMLIRPython.cmake:235 (add_mlir_python_extension)
    compiler/build/b/lib/cmake/mlir/AddMLIRPython.cmake:256 (_process_target)
    llvm-external-projects/iree-dialects/python/CMakeLists.txt:82 (add_mlir_python_modules)

full logs from CI: https://github.com/iree-org/iree/actions/runs/12901589475/job/35973909192#step:12:53035

With Python 3.13t the target is named nanobind-static-ft, but this search misses it because it only inspects the current CMake source directory, where that target is not defined.


I can think of a few solutions, if this target lookup and options setting is truly required:

  1. Search through an explicit list of target names using if (TARGET)
  2. Iterate over all targets in the project, not just the targets in the current directory, using code like https://stackoverflow.com/a/62311397
  3. Iterate over targets in the directory known to MLIR (llvm-project/mlir/python)

Member

My downstream project is also setting MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES to take control over the Python dependency setup. Another solution is to only modify the target options if MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES is not set.

Contributor

CMake is mysterious to me, but would wjakob/nanobind#868 help here? All we're trying to do here is suppress the warnings.

Member

Maybe... I don't see warnings from nanobind when building in IREE, but IREE does set its own compile options and it imports using FetchContent instead of find_package.

Proposing a change/fix: #123997


# Iterate over the list of targets
foreach(target ${all_targets})
# Check if the target name matches the given string
if("${target}" MATCHES "nanobind-")
set(nanobind_target "${target}")
endif()
endforeach()

if (NOT TARGET ${nanobind_target})
message(FATAL_ERROR "Could not find nanobind target to set compile options to")
endif()
endif()
target_compile_options(${nanobind_target}
PRIVATE
-Wno-cast-qual
-Wno-zero-length-array
40 changes: 40 additions & 0 deletions mlir/docs/Bindings/Python.md
@@ -1187,3 +1187,43 @@ or nanobind and
utilities to connect to the rest of Python API. The bindings can be located in a
separate module or in the same module as attributes and types, and
loaded along with the dialect.

## Free-threading (No-GIL) support

Free-threading or no-GIL support refers to CPython interpreter (>=3.13) with Global Interpreter Lock made optional. For details on the topic, please check [PEP-703](https://peps.python.org/pep-0703/) and this [Python free-threading guide](https://py-free-threading.github.io/).

MLIR Python bindings are free-threading compatible with exceptions (discussed below) in the following sense: it is safe to work in multiple threads with **independent** contexts. Below we show an example code of safe usage:

```python
# python3.13t example.py
import concurrent.futures

import mlir.dialects.arith as arith
from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint


def func(py_value):
    with Context() as ctx:
        module = Module.create(loc=Location.file("foo.txt", 0, 0))

        dtype = IntegerType.get_signless(64)
        with InsertionPoint(module.body), Location.name("a"):
            arith.constant(dtype, py_value)

    return module


num_workers = 8
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
    futures = []
    for i in range(num_workers):
        futures.append(executor.submit(func, i))
    assert len(list(f.result() for f in futures)) == num_workers
```

The exceptions to the free-threading compatibility:
- IR printing is unsafe, e.g. when using `PassManager` with `PassManager.enable_ir_printing()` which calls thread-unsafe `llvm::raw_ostream`.
- Usage of `Location.emit_error` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
- Usage of `Module.dump` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
- Usage of `mlir.dialects.transform.interpreter` is unsafe.
- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe.
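
Editor's note: one possible way to live with the exceptions listed above is to funnel the unsafe calls through a single shared lock. The sketch below is not part of this PR or of the MLIR bindings. `call_serialized`, `_unsafe_call_lock` and `fake_dump` are hypothetical names, and `fake_dump` is only a stand-in for a call such as `Module.dump` so that the example runs without MLIR installed. Whether serialization is sufficient for every item in the list is an assumption; the documentation above only states that these calls are unsafe.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

# Hypothetical process-wide lock; not provided by the MLIR Python bindings.
_unsafe_call_lock = threading.Lock()


def call_serialized(fn, *args, **kwargs):
    """Run fn while holding the shared lock, so one thread is inside at a time."""
    with _unsafe_call_lock:
        return fn(*args, **kwargs)


# Stand-in for a thread-unsafe call. It records how many threads are inside it
# at the same moment so the effect of the lock can be observed.
_state_guard = threading.Lock()
_active = 0
max_active = 0


def fake_dump(text):
    global _active, max_active
    with _state_guard:
        _active += 1
        max_active = max(max_active, _active)
    rendered = f"module: {text}"
    with _state_guard:
        _active -= 1
    return rendered


with ThreadPoolExecutor(max_workers=8) as executor:
    results = list(executor.map(lambda i: call_serialized(fake_dump, i), range(32)))
```

With the real bindings the equivalent call would presumably look like `call_serialized(module.dump)`, while IR construction in independent contexts stays outside the lock.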
12 changes: 11 additions & 1 deletion mlir/lib/Bindings/Python/Globals.h
@@ -24,6 +24,7 @@ namespace mlir {
namespace python {

/// Globals that are always accessible once the extension has been initialized.
/// Methods of this class are thread-safe.
class PyGlobals {
public:
PyGlobals();
@@ -37,12 +38,18 @@ class PyGlobals {

/// Get and set the list of parent modules to search for dialect
/// implementation classes.
std::vector<std::string> &getDialectSearchPrefixes() {
std::vector<std::string> getDialectSearchPrefixes() {
nanobind::ft_lock_guard lock(mutex);
return dialectSearchPrefixes;
}
void setDialectSearchPrefixes(std::vector<std::string> newValues) {
nanobind::ft_lock_guard lock(mutex);
dialectSearchPrefixes.swap(newValues);
}
void addDialectSearchPrefix(std::string value) {
nanobind::ft_lock_guard lock(mutex);
dialectSearchPrefixes.push_back(std::move(value));
}

/// Loads a python module corresponding to the given dialect namespace.
/// No-ops if the module has already been loaded or is not found. Raises
@@ -109,6 +116,9 @@ class PyGlobals {

private:
static PyGlobals *instance;

nanobind::ft_mutex mutex;

/// Module name prefixes to search under for dialect implementation modules.
std::vector<std::string> dialectSearchPrefixes;
/// Map of dialect namespace to external dialect class object.
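
Editor's note: the accessor change in this file (the getter now returns a copy taken under the mutex, and appending goes through a dedicated locked method instead of mutating a returned reference) can be illustrated with a small Python analogue. This is a sketch only. `SearchPrefixes` and the prefix `my_project.dialects` are hypothetical, and `threading.Lock` stands in for `nanobind::ft_mutex`.

```python
import threading


class SearchPrefixes:
    """Hypothetical analogue of the dialect search prefix accessors in PyGlobals."""

    def __init__(self):
        self._mutex = threading.Lock()
        self._prefixes = []

    def get(self):
        # Return a copy made under the lock; callers never hold the shared list.
        with self._mutex:
            return list(self._prefixes)

    def set(self, new_values):
        with self._mutex:
            self._prefixes = list(new_values)

    def add(self, value):
        # Appending is a locked method rather than a mutation of whatever
        # get() returned.
        with self._mutex:
            self._prefixes.append(value)


prefixes = SearchPrefixes()
prefixes.set(["mlir.dialects"])
snapshot = prefixes.get()
snapshot.append("not.shared")  # changes only the caller's copy
prefixes.add("my_project.dialects")
```

The trade-off is an extra copy on every read in exchange for never exposing the shared container outside the lock.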
31 changes: 27 additions & 4 deletions mlir/lib/Bindings/Python/IRCore.cpp
@@ -243,9 +243,15 @@ static MlirBlock createBlock(const nb::sequence &pyArgTypes,

/// Wrapper for the global LLVM debugging flag.
struct PyGlobalDebugFlag {
static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
static void set(nb::object &o, bool enable) {
nb::ft_lock_guard lock(mutex);
mlirEnableGlobalDebug(enable);
}

static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); }
static bool get(const nb::object &) {
nb::ft_lock_guard lock(mutex);
return mlirIsGlobalDebugEnabled();
}

static void bind(nb::module_ &m) {
// Debug flags.
@@ -255,6 +261,7 @@ struct PyGlobalDebugFlag {
.def_static(
"set_types",
[](const std::string &type) {
nb::ft_lock_guard lock(mutex);
mlirSetGlobalDebugType(type.c_str());
},
"types"_a, "Sets specific debug types to be produced by LLVM")
@@ -263,11 +270,17 @@ struct PyGlobalDebugFlag {
pointers.reserve(types.size());
for (const std::string &str : types)
pointers.push_back(str.c_str());
nb::ft_lock_guard lock(mutex);
mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
});
}

private:
static nb::ft_mutex mutex;
};

nb::ft_mutex PyGlobalDebugFlag::mutex;

struct PyAttrBuilderMap {
static bool dunderContains(const std::string &attributeKind) {
return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
@@ -606,6 +619,7 @@ class PyOpOperandIterator {

PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
nb::gil_scoped_acquire acquire;
nb::ft_lock_guard lock(live_contexts_mutex);
auto &liveContexts = getLiveContexts();
liveContexts[context.ptr] = this;
}
@@ -615,7 +629,10 @@ PyMlirContext::~PyMlirContext() {
// forContext method, which always puts the associated handle into
// liveContexts.
nb::gil_scoped_acquire acquire;
getLiveContexts().erase(context.ptr);
{
nb::ft_lock_guard lock(live_contexts_mutex);
getLiveContexts().erase(context.ptr);
}
mlirContextDestroy(context);
}

@@ -632,6 +649,7 @@ nb::object PyMlirContext::createFromCapsule(nb::object capsule) {

PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
nb::gil_scoped_acquire acquire;
nb::ft_lock_guard lock(live_contexts_mutex);
auto &liveContexts = getLiveContexts();
auto it = liveContexts.find(context.ptr);
if (it == liveContexts.end()) {
@@ -647,12 +665,17 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
return PyMlirContextRef(it->second, std::move(pyRef));
}

nb::ft_mutex PyMlirContext::live_contexts_mutex;

PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
static LiveContextMap liveContexts;
return liveContexts;
}

size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
size_t PyMlirContext::getLiveCount() {
nb::ft_lock_guard lock(live_contexts_mutex);
return getLiveContexts().size();
}

size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }

18 changes: 16 additions & 2 deletions mlir/lib/Bindings/Python/IRModule.cpp
@@ -38,8 +38,11 @@ PyGlobals::PyGlobals() {
PyGlobals::~PyGlobals() { instance = nullptr; }

bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
if (loadedDialectModules.contains(dialectNamespace))
return true;
{
nb::ft_lock_guard lock(mutex);
if (loadedDialectModules.contains(dialectNamespace))
return true;
}
// Since re-entrancy is possible, make a copy of the search prefixes.
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
nb::object loaded = nb::none();
@@ -62,12 +65,14 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
return false;
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
// may have occurred, which may do anything.
nb::ft_lock_guard lock(mutex);
loadedDialectModules.insert(dialectNamespace);
return true;
}

void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
nb::callable pyFunc, bool replace) {
nb::ft_lock_guard lock(mutex);
nb::object &found = attributeBuilderMap[attributeKind];
if (found && !replace) {
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
@@ -81,6 +86,7 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,

void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
nb::callable typeCaster, bool replace) {
nb::ft_lock_guard lock(mutex);
nb::object &found = typeCasterMap[mlirTypeID];
if (found && !replace)
throw std::runtime_error("Type caster is already registered with caster: " +
@@ -90,6 +96,7 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,

void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
nb::callable valueCaster, bool replace) {
nb::ft_lock_guard lock(mutex);
nb::object &found = valueCasterMap[mlirTypeID];
if (found && !replace)
throw std::runtime_error("Value caster is already registered: " +
@@ -99,6 +106,7 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,

void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
nb::object pyClass) {
nb::ft_lock_guard lock(mutex);
nb::object &found = dialectClassMap[dialectNamespace];
if (found) {
throw std::runtime_error((llvm::Twine("Dialect namespace '") +
@@ -110,6 +118,7 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,

void PyGlobals::registerOperationImpl(const std::string &operationName,
nb::object pyClass, bool replace) {
nb::ft_lock_guard lock(mutex);
nb::object &found = operationClassMap[operationName];
if (found && !replace) {
throw std::runtime_error((llvm::Twine("Operation '") + operationName +
@@ -121,6 +130,7 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,

std::optional<nb::callable>
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
nb::ft_lock_guard lock(mutex);
const auto foundIt = attributeBuilderMap.find(attributeKind);
if (foundIt != attributeBuilderMap.end()) {
assert(foundIt->second && "attribute builder is defined");
@@ -133,6 +143,7 @@ std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
// Try to load dialect module.
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
nb::ft_lock_guard lock(mutex);
const auto foundIt = typeCasterMap.find(mlirTypeID);
if (foundIt != typeCasterMap.end()) {
assert(foundIt->second && "type caster is defined");
@@ -145,6 +156,7 @@ std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
// Try to load dialect module.
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
nb::ft_lock_guard lock(mutex);
const auto foundIt = valueCasterMap.find(mlirTypeID);
if (foundIt != valueCasterMap.end()) {
assert(foundIt->second && "value caster is defined");
@@ -158,6 +170,7 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
// Make sure dialect module is loaded.
if (!loadDialectModule(dialectNamespace))
return std::nullopt;
nb::ft_lock_guard lock(mutex);
const auto foundIt = dialectClassMap.find(dialectNamespace);
if (foundIt != dialectClassMap.end()) {
assert(foundIt->second && "dialect class is defined");
@@ -175,6 +188,7 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
if (!loadDialectModule(dialectNamespace))
return std::nullopt;

nb::ft_lock_guard lock(mutex);
auto foundIt = operationClassMap.find(operationName);
if (foundIt != operationClassMap.end()) {
assert(foundIt->second && "OpView is defined");
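
Editor's note: in `loadDialectModule` the mutex is taken only around the two accesses to `loadedDialectModules` and is released while the module is being loaded, which the source comments explain by re-entrancy. The Python sketch below mirrors that shape under the assumption of a non-reentrant lock, where holding the lock across a re-entrant load would deadlock. `DialectLoader` and `importer` are hypothetical names, and `threading.Lock` stands in for `nanobind::ft_mutex`.

```python
import threading


class DialectLoader:
    """Hypothetical analogue of the locking in PyGlobals::loadDialectModule."""

    def __init__(self, importer):
        self._mutex = threading.Lock()  # non-reentrant, like a plain mutex
        self._loaded = set()
        self._importer = importer  # may call back into load()

    def load(self, namespace):
        # First critical section: only the membership check is locked.
        with self._mutex:
            if namespace in self._loaded:
                return True
        # The lock is released here, so the importer may re-enter load().
        if not self._importer(self, namespace):
            return False
        # Second critical section: record the result.
        with self._mutex:
            self._loaded.add(namespace)
        return True

    def loaded_namespaces(self):
        with self._mutex:
            return set(self._loaded)


def importer(loader, namespace):
    # Loading "b" first requires "a", which re-enters load().
    if namespace == "b":
        return loader.load("a")
    return namespace == "a"


loader = DialectLoader(importer)
loaded_b = loader.load("b")
loaded_missing = loader.load("missing")
```

Two threads can still both run the importer for the same namespace with this scheme; the second insert is then simply redundant.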
1 change: 1 addition & 0 deletions mlir/lib/Bindings/Python/IRModule.h
@@ -260,6 +260,7 @@ class PyMlirContext {
// Note that this holds a handle, which does not imply ownership.
// Mappings will be removed when the context is destructed.
using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
static nanobind::ft_mutex live_contexts_mutex;
static LiveContextMap &getLiveContexts();

// Interns all live modules associated with this context. Modules tracked
9 changes: 2 additions & 7 deletions mlir/lib/Bindings/Python/MainModule.cpp
@@ -30,12 +30,8 @@ NB_MODULE(_mlir, m) {
.def_prop_rw("dialect_search_modules",
&PyGlobals::getDialectSearchPrefixes,
&PyGlobals::setDialectSearchPrefixes)
.def(
"append_dialect_search_prefix",
[](PyGlobals &self, std::string moduleName) {
self.getDialectSearchPrefixes().push_back(std::move(moduleName));
},
"module_name"_a)
.def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
"module_name"_a)
.def(
"_check_dialect_module_loaded",
[](PyGlobals &self, const std::string &dialectNamespace) {
@@ -76,7 +72,6 @@ NB_MODULE(_mlir, m) {
nanobind::cast<std::string>(opClass.attr("OPERATION_NAME"));
PyGlobals::get().registerOperationImpl(operationName, opClass,
replace);

// Dict-stuff the new opClass by name onto the dialect class.
nb::object opClassName = opClass.attr("__name__");
dialectClass.attr(opClassName) = opClass;
2 changes: 1 addition & 1 deletion mlir/python/requirements.txt
@@ -2,4 +2,4 @@ nanobind>=2.4, <3.0
numpy>=1.19.5, <=2.1.2
pybind11>=2.10.0, <=2.13.6
PyYAML>=5.4.0, <=6.0.1
ml_dtypes>=0.1.0, <=0.5.0 # provides several NumPy dtype extensions, including the bf16
ml_dtypes>=0.5.0, <=0.6.0 # provides several NumPy dtype extensions, including the bf16