Skip to content

Commit b1220b3

Browse files
committed
Reapply "fix vector.splat -> vector.broadcast" (#170)
This reverts commit b6729e9.
1 parent b0e2b3b commit b1220b3

File tree

2 files changed

+33
-12
lines changed

2 files changed

+33
-12
lines changed

mlir/extras/runtime/refbackend.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
CONSUME_RETURN_CALLBACK_ATTR = "refbackend_consume_return_callback"
4242
refback_cb_attr = CONSUME_RETURN_CALLBACK_ATTR
4343

44+
_exec_engine_shared_libs = []
45+
4446
if ASYNC_RUNTIME_LIB_PATH := os.getenv("ASYNC_RUNTIME_LIB_PATH"):
4547
ASYNC_RUNTIME_LIB_PATH = Path(ASYNC_RUNTIME_LIB_PATH)
4648
else:
@@ -49,6 +51,9 @@
4951
/ f"{shlib_prefix()}mlir_async_runtime.{shlib_ext()}"
5052
)
5153

54+
if ASYNC_RUNTIME_LIB_PATH.exists():
55+
_exec_engine_shared_libs.append(ASYNC_RUNTIME_LIB_PATH)
56+
5257
if C_RUNNER_UTILS_LIB_PATH := os.getenv("C_RUNNER_UTILS_LIB_PATH"):
5358
C_RUNNER_UTILS_LIB_PATH = Path(C_RUNNER_UTILS_LIB_PATH)
5459
else:
@@ -57,6 +62,9 @@
5762
/ f"{shlib_prefix()}mlir_c_runner_utils.{shlib_ext()}"
5863
)
5964

65+
if C_RUNNER_UTILS_LIB_PATH.exists():
66+
_exec_engine_shared_libs.append(C_RUNNER_UTILS_LIB_PATH)
67+
6068
if RUNNER_UTILS_LIB_PATH := os.getenv("RUNNER_UTILS_LIB_PATH"):
6169
RUNNER_UTILS_LIB_PATH = Path(RUNNER_UTILS_LIB_PATH)
6270
else:
@@ -65,6 +73,20 @@
6573
/ f"{shlib_prefix()}mlir_runner_utils.{shlib_ext()}"
6674
)
6775

76+
if RUNNER_UTILS_LIB_PATH.exists():
77+
_exec_engine_shared_libs.append(RUNNER_UTILS_LIB_PATH)
78+
79+
if CUDA_RUNTIME_LIB_PATH := os.getenv("CUDA_RUNTIME_LIB_PATH"):
80+
CUDA_RUNTIME_LIB_PATH = Path(CUDA_RUNTIME_LIB_PATH)
81+
else:
82+
CUDA_RUNTIME_LIB_PATH = (
83+
Path(_mlir_libs.__file__).parent
84+
/ f"{shlib_prefix()}mlir_cuda_runtime.{shlib_ext()}"
85+
)
86+
87+
if CUDA_RUNTIME_LIB_PATH.exists():
88+
_exec_engine_shared_libs.append(CUDA_RUNTIME_LIB_PATH)
89+
6890

6991
def get_ctype_func(mlir_ret_types):
7092
ctypes_arg = [None]
@@ -202,14 +224,10 @@ def __init__(
202224
shared_lib_paths=None,
203225
):
204226
if shared_lib_paths is None:
205-
shared_lib_paths = []
227+
shared_lib_paths = set()
206228
if platform.system() != "Windows":
207-
shared_lib_paths += [
208-
ASYNC_RUNTIME_LIB_PATH,
209-
C_RUNNER_UTILS_LIB_PATH,
210-
RUNNER_UTILS_LIB_PATH,
211-
]
212-
self.shared_lib_paths = shared_lib_paths
229+
shared_lib_paths |= set(_exec_engine_shared_libs)
230+
self.shared_lib_paths = list(shared_lib_paths)
213231
self.return_func_types = None
214232
self.return_func_name = None
215233

tests/test_nvgpu_nvvm.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from mlir.extras.dialects.ext.gpu import smem_space
2828
from mlir.extras.dialects.ext.llvm import llvm_ptr_t
2929
from mlir.extras.runtime.passes import Pipeline, run_pipeline
30-
from mlir.extras.runtime.refbackend import LLVMJITBackend
30+
from mlir.extras.runtime.refbackend import LLVMJITBackend, CUDA_RUNTIME_LIB_PATH
3131

3232
# noinspection PyUnresolvedReferences
3333
from mlir.extras.testing import (
@@ -202,7 +202,8 @@ def payload():
202202
compute_linspace_val.emit()
203203

204204
@func
205-
def printMemrefF32(x: T.memref(T.f32())): ...
205+
def printMemrefF32(x: T.memref(T.f32())):
206+
...
206207

207208
printMemrefF32_.append(printMemrefF32)
208209

@@ -413,10 +414,11 @@ def main(module: any_op_t()):
413414
# CHECK: }
414415
# CHECK: }
415416

416-
filecheck_with_comments(mod)
417+
mod.operation.verify()
417418

419+
if CUDA_RUNTIME_LIB_PATH.exists():
420+
filecheck_with_comments(mod)
418421

419-
CUDA_RUNTIME_LIB_PATH = Path(_mlir_libs.__file__).parent / f"libmlir_cuda_runtime.so"
420422

421423
NVIDIA_GPU = False
422424
try:
@@ -553,7 +555,8 @@ def payload():
553555
compute_linspace_val.emit()
554556

555557
@func
556-
def printMemrefF32(x: T.memref(T.f32())): ...
558+
def printMemrefF32(x: T.memref(T.f32())):
559+
...
557560

558561
printMemrefF32_.append(printMemrefF32)
559562

0 commit comments

Comments
 (0)