Skip to content

Commit 7560c51

Browse files
committed
Display cuda version for module
1 parent 193cf07 commit 7560c51

File tree

1 file changed

+2
-13
lines changed

1 file changed

+2
-13
lines changed

test/smoke_test/smoke_test.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,18 +57,6 @@ def check_nightly_binaries_date(package: str) -> None:
5757
f"Expected {module['name']} to be less then {NIGHTLY_ALLOWED_DELTA} days. But its {date_m_delta}"
5858
)
5959

60-
def check_cuda_version(version: str, dlibary: str):
61-
if version is not None and torch.version.cuda is not None:
62-
version_str = str(version)
63-
m_version = f"{version_str[:-3]}.{version_str[-2]}"
64-
t_version = torch.version.cuda.split(".")
65-
t_version = f"{t_version[0]}.{t_version[1]}"
66-
if m_version != t_version:
67-
raise RuntimeError(
68-
"Detected that PyTorch and {dlibary} were compiled with different CUDA versions. "
69-
f"PyTorch has CUDA version {t_version} whereas {dlibary} has CUDA version {m_version}. "
70-
)
71-
7260
def smoke_test_cuda(package: str) -> None:
7361
if not torch.cuda.is_available() and is_cuda_system:
7462
raise RuntimeError(f"Expected CUDA {gpu_arch_ver}. However CUDA is not loaded.")
@@ -86,7 +74,8 @@ def smoke_test_cuda(package: str) -> None:
8674
for module in MODULES:
8775
imported_module = importlib.import_module(module["name"])
8876
if module["extension"] == "extension":
89-
imported_module.extension._check_cuda_version()
77+
version = imported_module.extension._check_cuda_version()
78+
print(f"{module['name']} CUDA: {version}")
9079
else:
9180
imported_module._extension._check_cuda_version()
9281

0 commit comments

Comments
 (0)