Skip to content

Commit 0801e1a

Browse files
check CC >= 8 for marlin, fixed #1092 (#1093)
* check cuda v8 for marlin
* check cuda 8 for installation
* update msg
* update skip marlin msg
* check rocm first
* check not ROCM_VERSION
1 parent ede890c commit 0801e1a

File tree

2 files changed

+21
-14
lines changed

2 files changed

+21
-14
lines changed

gptqmodel/nn_modules/qlinear/marlin.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
except ImportError as e:
3232
marlin_import_exception = e
3333

34+
3435
GPTQ_MARLIN_TILE = 16
3536
GPTQ_MARLIN_MIN_THREAD_N = 64
3637
GPTQ_MARLIN_MIN_THREAD_K = 128
@@ -307,6 +308,8 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
307308
def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
308309
if IS_ROCM:
309310
return False, RuntimeError("marlin kernel is not supported by rocm.")
311+
if not any(torch.cuda.get_device_capability(i)[0] >= 8 for i in range(torch.cuda.device_count())):
312+
return False, RuntimeError("marlin kernel requires Compute Capability >= 8.0.")
310313
if marlin_import_exception is not None:
311314
return False, marlin_import_exception
312315
return cls._validate(**args)

setup.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ def get_version_tag() -> str:
125125
import torch # noqa: E402
126126

127127
if TORCH_CUDA_ARCH_LIST is None:
128+
HAS_CUDA_V8 = any(torch.cuda.get_device_capability(i)[0] >= 8 for i in range(torch.cuda.device_count()))
129+
128130
got_cuda_v6 = any(torch.cuda.get_device_capability(i)[0] >= 6 for i in range(torch.cuda.device_count()))
129131
got_cuda_between_v6_and_v8 = any(6 <= torch.cuda.get_device_capability(i)[0] < 8 for i in range(torch.cuda.device_count()))
130132

@@ -139,7 +141,8 @@ def get_version_tag() -> str:
139141
if BUILD_CUDA_EXT and not FORCE_BUILD:
140142
if got_cuda_between_v6_and_v8:
141143
FORCE_BUILD = True
142-
144+
else:
145+
HAS_CUDA_V8 = not ROCM_VERSION and len([arch for arch in TORCH_CUDA_ARCH_LIST.split() if float(arch.split('+')[0]) >= 8]) > 0
143146

144147
if RELEASE_MODE == "1":
145148
common_setup_kwargs["version"] += f"+{get_version_tag()}"
@@ -217,22 +220,23 @@ def get_version_tag() -> str:
217220
),
218221
]
219222

220-
if sys.platform != "win32":
221-
# TODO: VC++: fatal error C1061: compiler limit : blocks nested too deeply
222-
marlin_kernel = cpp_ext.CUDAExtension(
223-
"gptqmodel_marlin_kernels",
224-
[
225-
"gptqmodel_ext/marlin/marlin_cuda.cpp",
226-
"gptqmodel_ext/marlin/marlin_cuda_kernel.cu",
227-
"gptqmodel_ext/marlin/marlin_repack.cu",
228-
],
229-
extra_link_args=extra_link_args,
230-
extra_compile_args=extra_compile_args,
231-
)
223+
if sys.platform != "win32":# TODO: VC++: fatal error C1061: compiler limit : blocks nested too deeply
232224
# https://rocm.docs.amd.com/projects/HIPIFY/en/docs-6.1.0/tables/CUDA_Device_API_supported_by_HIP.html
233225
# nv_bfloat16 and nv_bfloat162 (2x bf16) missing replacement in ROCm
234-
if not ROCM_VERSION:
226+
if HAS_CUDA_V8 and not ROCM_VERSION:
227+
marlin_kernel = cpp_ext.CUDAExtension(
228+
"gptqmodel_marlin_kernels",
229+
[
230+
"gptqmodel_ext/marlin/marlin_cuda.cpp",
231+
"gptqmodel_ext/marlin/marlin_cuda_kernel.cu",
232+
"gptqmodel_ext/marlin/marlin_repack.cu",
233+
],
234+
extra_link_args=extra_link_args,
235+
extra_compile_args=extra_compile_args,
236+
)
235237
extensions.append(marlin_kernel)
238+
elif not HAS_CUDA_V8:
239+
print(f"marlin kernel only supports compute capability >= 8.0, there's no such cuda device, skipped.")
236240
extensions += [
237241
# TODO: VC++: error lnk2001 unresolved external symbol cublasHgemm
238242
cpp_ext.CUDAExtension(

0 commit comments

Comments (0)