Skip to content

Commit 6049f86

Browse files
authored
[Vulkan][Utils] Automatic platform/OS detection (huggingface#569)
To enable AMD gpus on macOS, we need this detection to let the compiler know that we would be needing moltenVK to use this GPU.
1 parent ff649b5 commit 6049f86

File tree

1 file changed

+30
-11
lines changed

1 file changed

+30
-11
lines changed

shark/iree_utils/vulkan_utils.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from os import linesep
1818
from shark.iree_utils._common import run_cmd
1919
import iree.runtime as ireert
20+
from sys import platform
2021

2122

2223
def get_vulkan_device_name():
@@ -31,11 +32,23 @@ def get_vulkan_device_name():
3132
return vulkaninfo_list[0]
3233

3334

35+
def get_os_name():
36+
if platform.startswith("linux"):
37+
return "linux"
38+
elif platform == "darwin":
39+
return "macos"
40+
elif platform == "win32":
41+
return "windows"
42+
else:
43+
print("Cannot detect OS type, defaulting to linux.")
44+
return "linux"
45+
46+
3447
def get_vulkan_triple_flag(extra_args=[]):
3548
if "-iree-vulkan-target-triple=" in " ".join(extra_args):
3649
print(f"Using target triple from command line args")
3750
return None
38-
51+
system_os = get_os_name()
3952
vulkan_device = get_vulkan_device_name()
4053
if all(x in vulkan_device for x in ("Apple", "M1")):
4154
print(f"Found {vulkan_device} Device. Using m1-moltenvk-macos")
@@ -44,20 +57,26 @@ def get_vulkan_triple_flag(extra_args=[]):
4457
print("Found Apple M2 Device. Using m1-moltenvk-macos")
4558
return "-iree-vulkan-target-triple=m1-moltenvk-macos"
4659
elif all(x in vulkan_device for x in ("A100", "SXM4")):
47-
print(f"Found {vulkan_device} Device. Using ampere-rtx3080-linux")
48-
return "-iree-vulkan-target-triple=ampere-rtx3080-linux"
60+
print(
61+
f"Found {vulkan_device} Device. Using ampere-rtx3080-{system_os}"
62+
)
63+
return f"-iree-vulkan-target-triple=ampere-rtx3080-{system_os}"
4964
elif all(x in vulkan_device for x in ("RTX", "3090")):
50-
print(f"Found {vulkan_device} Device. Using ampere-rtx3090-linux")
51-
return "-iree-vulkan-target-triple=ampere-rtx3090-linux"
65+
print(
66+
f"Found {vulkan_device} Device. Using ampere-rtx3090-{system_os}"
67+
)
68+
return f"-iree-vulkan-target-triple=ampere-rtx3090-{system_os}"
5269
elif all(x in vulkan_device for x in ("RTX", "4090")):
53-
print(f"Found {vulkan_device} Device. Using ampere-rtx3090-linux")
54-
return "-iree-vulkan-target-triple=ampere-rtx3090-linux"
70+
print(
71+
f"Found {vulkan_device} Device. Using ampere-rtx3090-{system_os}"
72+
)
73+
return f"-iree-vulkan-target-triple=ampere-rtx3090-{system_os}"
5574
elif all(x in vulkan_device for x in ("AMD", "7900")):
56-
print(f"Found {vulkan_device} Device. Using rdna3-7900-linux")
57-
return "-iree-vulkan-target-triple=rdna3-7900-linux"
75+
print(f"Found {vulkan_device} Device. Using rdna3-7900-{system_os}")
76+
return f"-iree-vulkan-target-triple=rdna3-7900-{system_os}"
5877
elif "AMD" in vulkan_device:
59-
print("Found AMD device. Using rdna2-unknown-linux")
60-
return "-iree-vulkan-target-triple=rdna2-unknown-linux"
78+
print(f"Found AMD device. Using rdna2-unknown-{system_os}")
79+
return f"-iree-vulkan-target-triple=rdna2-unknown-{system_os}"
6180
else:
6281
print(
6382
"""Optimized kernel for your target device is not added yet.

0 commit comments

Comments
 (0)