Skip to content

Commit dbd0d52

Browse files
committed
Reapply "fix vector.splat -> vector.broadcast" (#170)
This reverts commit b6729e9.
1 parent b0e2b3b commit dbd0d52

File tree

2 files changed

+225
-188
lines changed

2 files changed

+225
-188
lines changed

mlir/extras/runtime/refbackend.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
CONSUME_RETURN_CALLBACK_ATTR = "refbackend_consume_return_callback"
4242
refback_cb_attr = CONSUME_RETURN_CALLBACK_ATTR
4343

44+
_exec_engine_shared_libs = []
45+
4446
if ASYNC_RUNTIME_LIB_PATH := os.getenv("ASYNC_RUNTIME_LIB_PATH"):
4547
ASYNC_RUNTIME_LIB_PATH = Path(ASYNC_RUNTIME_LIB_PATH)
4648
else:
@@ -49,6 +51,9 @@
4951
/ f"{shlib_prefix()}mlir_async_runtime.{shlib_ext()}"
5052
)
5153

54+
if ASYNC_RUNTIME_LIB_PATH.exists():
55+
_exec_engine_shared_libs.append(ASYNC_RUNTIME_LIB_PATH)
56+
5257
if C_RUNNER_UTILS_LIB_PATH := os.getenv("C_RUNNER_UTILS_LIB_PATH"):
5358
C_RUNNER_UTILS_LIB_PATH = Path(C_RUNNER_UTILS_LIB_PATH)
5459
else:
@@ -57,6 +62,9 @@
5762
/ f"{shlib_prefix()}mlir_c_runner_utils.{shlib_ext()}"
5863
)
5964

65+
if C_RUNNER_UTILS_LIB_PATH.exists():
66+
_exec_engine_shared_libs.append(C_RUNNER_UTILS_LIB_PATH)
67+
6068
if RUNNER_UTILS_LIB_PATH := os.getenv("RUNNER_UTILS_LIB_PATH"):
6169
RUNNER_UTILS_LIB_PATH = Path(RUNNER_UTILS_LIB_PATH)
6270
else:
@@ -65,6 +73,21 @@
6573
/ f"{shlib_prefix()}mlir_runner_utils.{shlib_ext()}"
6674
)
6775

76+
if RUNNER_UTILS_LIB_PATH.exists():
77+
_exec_engine_shared_libs.append(RUNNER_UTILS_LIB_PATH)
78+
79+
if CUDA_RUNTIME_LIB_PATH := os.getenv("CUDA_RUNTIME_LIB_PATH"):
80+
CUDA_RUNTIME_LIB_PATH = Path(CUDA_RUNTIME_LIB_PATH)
81+
else:
82+
CUDA_RUNTIME_LIB_PATH = (
83+
Path(_mlir_libs.__file__).parent
84+
/ f"{shlib_prefix()}mlir_cuda_runtime.{shlib_ext()}"
85+
)
86+
87+
# TODO(max): for some reason adding this lib to execengine causes a segfault (or something)
88+
# if CUDA_RUNTIME_LIB_PATH.exists():
89+
# _exec_engine_shared_libs.append(CUDA_RUNTIME_LIB_PATH)
90+
6891

6992
def get_ctype_func(mlir_ret_types):
7093
ctypes_arg = [None]
@@ -202,14 +225,10 @@ def __init__(
202225
shared_lib_paths=None,
203226
):
204227
if shared_lib_paths is None:
205-
shared_lib_paths = []
228+
shared_lib_paths = set()
206229
if platform.system() != "Windows":
207-
shared_lib_paths += [
208-
ASYNC_RUNTIME_LIB_PATH,
209-
C_RUNNER_UTILS_LIB_PATH,
210-
RUNNER_UTILS_LIB_PATH,
211-
]
212-
self.shared_lib_paths = shared_lib_paths
230+
shared_lib_paths |= set(_exec_engine_shared_libs)
231+
self.shared_lib_paths = list(shared_lib_paths)
213232
self.return_func_types = None
214233
self.return_func_name = None
215234

0 commit comments

Comments
 (0)