@@ -125,6 +125,8 @@ def get_version_tag() -> str:
125125import torch # noqa: E402
126126
127127if TORCH_CUDA_ARCH_LIST is None :
128+ HAS_CUDA_V8 = any (torch .cuda .get_device_capability (i )[0 ] >= 8 for i in range (torch .cuda .device_count ()))
129+
128130 got_cuda_v6 = any (torch .cuda .get_device_capability (i )[0 ] >= 6 for i in range (torch .cuda .device_count ()))
129131 got_cuda_between_v6_and_v8 = any (6 <= torch .cuda .get_device_capability (i )[0 ] < 8 for i in range (torch .cuda .device_count ()))
130132
@@ -139,7 +141,8 @@ def get_version_tag() -> str:
139141 if BUILD_CUDA_EXT and not FORCE_BUILD :
140142 if got_cuda_between_v6_and_v8 :
141143 FORCE_BUILD = True
142-
144+ else :
145+ HAS_CUDA_V8 = not ROCM_VERSION and len ([arch for arch in TORCH_CUDA_ARCH_LIST .split () if float (arch .split ('+' )[0 ]) >= 8 ]) > 0
143146
144147if RELEASE_MODE == "1" :
145148 common_setup_kwargs ["version" ] += f"+{ get_version_tag ()} "
@@ -217,22 +220,23 @@ def get_version_tag() -> str:
217220 ),
218221 ]
219222
220- if sys .platform != "win32" :
221- # TODO: VC++: fatal error C1061: compiler limit : blocks nested too deeply
222- marlin_kernel = cpp_ext .CUDAExtension (
223- "gptqmodel_marlin_kernels" ,
224- [
225- "gptqmodel_ext/marlin/marlin_cuda.cpp" ,
226- "gptqmodel_ext/marlin/marlin_cuda_kernel.cu" ,
227- "gptqmodel_ext/marlin/marlin_repack.cu" ,
228- ],
229- extra_link_args = extra_link_args ,
230- extra_compile_args = extra_compile_args ,
231- )
223+ if sys .platform != "win32" :# TODO: VC++: fatal error C1061: compiler limit : blocks nested too deeply
232224 # https://rocm.docs.amd.com/projects/HIPIFY/en/docs-6.1.0/tables/CUDA_Device_API_supported_by_HIP.html
233225 # nv_bfloat16 and nv_bfloat162 (2x bf16) missing replacement in ROCm
234- if not ROCM_VERSION :
226+ if HAS_CUDA_V8 and not ROCM_VERSION :
227+ marlin_kernel = cpp_ext .CUDAExtension (
228+ "gptqmodel_marlin_kernels" ,
229+ [
230+ "gptqmodel_ext/marlin/marlin_cuda.cpp" ,
231+ "gptqmodel_ext/marlin/marlin_cuda_kernel.cu" ,
232+ "gptqmodel_ext/marlin/marlin_repack.cu" ,
233+ ],
234+ extra_link_args = extra_link_args ,
235+ extra_compile_args = extra_compile_args ,
236+ )
235237 extensions .append (marlin_kernel )
238+ elif not HAS_CUDA_V8 :
239+ print (f"marlin kernel only supports compute capability >= 8.0, there's no such cuda device, skipped." )
236240 extensions += [
237241 # TODO: VC++: error lnk2001 unresolved external symbol cublasHgemm
238242 cpp_ext .CUDAExtension (
0 commit comments