diff --git a/setup.py b/setup.py index 277db603b..6179a4b01 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,8 @@ from setuptools import find_packages, setup +import torch + try: from setuptools.command.bdist_wheel import bdist_wheel as _bdist_wheel except BaseException: @@ -38,6 +40,10 @@ ROCM_VERSION = os.environ.get('ROCM_VERSION', None) SKIP_ROCM_VERSION_CHECK = os.environ.get('SKIP_ROCM_VERSION_CHECK', None) +if ROCM_VERSION is None and torch.version.hip: + ROCM_VERSION = ".".join(torch.version.hip.split(".")[:2]) # print(torch.version.hip) -> 6.3.42131-fa1d09cbd + os.environ["ROCM_VERSION"] = ROCM_VERSION + if ROCM_VERSION is not None and float(ROCM_VERSION) < 6.2 and not SKIP_ROCM_VERSION_CHECK: sys.exit( "GPTQModel's compatibility with ROCM versions below 6.2 has not been verified. If you wish to proceed, please set the SKIP_ROCM_VERSION_CHECK environment." @@ -101,8 +107,6 @@ def get_version_tag() -> str: if not BUILD_CUDA_EXT: return "cpu" - import torch - if ROCM_VERSION: return f"rocm{ROCM_VERSION}" @@ -124,9 +128,7 @@ def get_version_tag() -> str: if not os.getenv("CI"): with open('requirements.txt') as f: requirements = [line.strip() for line in f if line.strip()] - #subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"]) -import torch # noqa: E402 if TORCH_CUDA_ARCH_LIST is None: HAS_CUDA_V8 = any(torch.cuda.get_device_capability(i)[0] >= 8 for i in range(torch.cuda.device_count()))