gptqmodel/nn_modules/qlinear/marlin.py (3 changes: 3 additions, 0 deletions)

```diff
@@ -31,6 +31,7 @@
 except ImportError as e:
     marlin_import_exception = e
 
+
 GPTQ_MARLIN_TILE = 16
 GPTQ_MARLIN_MIN_THREAD_N = 64
 GPTQ_MARLIN_MIN_THREAD_K = 128
@@ -307,6 +308,8 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
     def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
         if IS_ROCM:
             return False, RuntimeError("marlin kernel is not supported by rocm.")
+        if not any(torch.cuda.get_device_capability(i)[0] >= 8 for i in range(torch.cuda.device_count())):
+            return False, RuntimeError("marlin kernel requires Compute Capability >= 8.0.")
         if marlin_import_exception is not None:
             return False, marlin_import_exception
         return cls._validate(**args)
```
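For reference, a minimal standalone sketch of the new device gate (the helper name here is hypothetical): `torch.cuda.get_device_capability(i)` returns a `(major, minor)` tuple per device, so Marlin is accepted only when some visible device reports a major version of 8 or higher.

```python
import torch

def any_device_is_sm80_plus() -> bool:
    # Same predicate as the new validate() check: scan every visible
    # CUDA device and pass if any reports compute capability major >= 8.
    return any(
        torch.cuda.get_device_capability(i)[0] >= 8
        for i in range(torch.cuda.device_count())
    )

# e.g. an A100 reports (8, 0) and passes; a T4 reports (7, 5) and fails.
print("marlin eligible:", torch.cuda.is_available() and any_device_is_sm80_plus())
```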
setup.py (32 changes: 18 additions, 14 deletions)

```diff
@@ -125,6 +125,8 @@ def get_version_tag() -> str:
 import torch  # noqa: E402
 
 if TORCH_CUDA_ARCH_LIST is None:
+    HAS_CUDA_V8 = any(torch.cuda.get_device_capability(i)[0] >= 8 for i in range(torch.cuda.device_count()))
+
     got_cuda_v6 = any(torch.cuda.get_device_capability(i)[0] >= 6 for i in range(torch.cuda.device_count()))
     got_cuda_between_v6_and_v8 = any(6 <= torch.cuda.get_device_capability(i)[0] < 8 for i in range(torch.cuda.device_count()))
 
@@ -139,7 +141,8 @@ def get_version_tag() -> str:
     if BUILD_CUDA_EXT and not FORCE_BUILD:
         if got_cuda_between_v6_and_v8:
             FORCE_BUILD = True
-
+else:
+    HAS_CUDA_V8 = not ROCM_VERSION and len([arch for arch in TORCH_CUDA_ARCH_LIST.split() if float(arch.split('+')[0]) >= 8]) > 0
 
 if RELEASE_MODE == "1":
     common_setup_kwargs["version"] += f"+{get_version_tag()}"
```
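When `TORCH_CUDA_ARCH_LIST` is set, `HAS_CUDA_V8` is derived from the listed architectures instead of querying devices. A small sketch of that parsing under an assumed sample value (arch entries may carry a `+PTX` suffix, which is why each entry is split on `'+'` before the float comparison):

```python
# Assumed sample value; "8.0+PTX" shows why the '+' suffix is stripped.
TORCH_CUDA_ARCH_LIST = "6.1 7.5 8.0+PTX"
ROCM_VERSION = None

HAS_CUDA_V8 = not ROCM_VERSION and len(
    [arch for arch in TORCH_CUDA_ARCH_LIST.split() if float(arch.split('+')[0]) >= 8]
) > 0
print(HAS_CUDA_V8)  # True: "8.0+PTX" parses as 8.0
```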
```diff
@@ -217,22 +220,23 @@ def get_version_tag() -> str:
     ),
 ]
 
-if sys.platform != "win32":
-    # TODO: VC++: fatal error C1061: compiler limit : blocks nested too deeply
-    marlin_kernel = cpp_ext.CUDAExtension(
-        "gptqmodel_marlin_kernels",
-        [
-            "gptqmodel_ext/marlin/marlin_cuda.cpp",
-            "gptqmodel_ext/marlin/marlin_cuda_kernel.cu",
-            "gptqmodel_ext/marlin/marlin_repack.cu",
-        ],
-        extra_link_args=extra_link_args,
-        extra_compile_args=extra_compile_args,
-    )
+if sys.platform != "win32":  # TODO: VC++: fatal error C1061: compiler limit : blocks nested too deeply
     # https://rocm.docs.amd.com/projects/HIPIFY/en/docs-6.1.0/tables/CUDA_Device_API_supported_by_HIP.html
     # nv_bfloat16 and nv_bfloat162 (2x bf16) missing replacement in ROCm
-    if not ROCM_VERSION:
+    if HAS_CUDA_V8 and not ROCM_VERSION:
+        marlin_kernel = cpp_ext.CUDAExtension(
+            "gptqmodel_marlin_kernels",
+            [
+                "gptqmodel_ext/marlin/marlin_cuda.cpp",
+                "gptqmodel_ext/marlin/marlin_cuda_kernel.cu",
+                "gptqmodel_ext/marlin/marlin_repack.cu",
+            ],
+            extra_link_args=extra_link_args,
+            extra_compile_args=extra_compile_args,
+        )
         extensions.append(marlin_kernel)
+    elif not HAS_CUDA_V8:
+        print("marlin kernel only supports compute capability >= 8.0, no such cuda device found, skipped.")
 extensions += [
     # TODO: VC++: error lnk2001 unresolved external symbol cublasHgemm
     cpp_ext.CUDAExtension(
```
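Net effect: the Marlin extension is compiled only on non-Windows, non-ROCm builds that target at least one SM 8.0 architecture, and the `validate()` gate additionally rejects it at runtime on older GPUs. As a usage note, a hedged sketch for probing whether the optional kernel module made it into an installed build; it mirrors the import guard at the top of marlin.py, and `gptqmodel_marlin_kernels` is the extension name declared in setup.py above:

```python
# Check whether the optional Marlin extension was built and is importable.
try:
    import gptqmodel_marlin_kernels  # noqa: F401
    marlin_available = True
except ImportError as exc:
    # Expected on Windows, ROCm, or builds without an SM 8.0+ target.
    marlin_available = False
    print(f"marlin kernels unavailable: {exc}")
```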