Skip to content

Commit ac8a4cf

Browse files
authored
Add cuda and date check to smoke test (#1145)
1 parent b212046 commit ac8a4cf

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

.github/actions/validate-windows-binary/action.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ runs:
3838
env:
3939
GPU_ARCH_VER: ${{ inputs.gpu_arch_ver }}
4040
GPU_ARCH_TYPE: ${{ inputs.gpu_arch_type }}
41+
INSTALLATION: ${{ inputs.installation }}
4142
CUDA_VER: ${{ inputs.desired_cuda }}
4243
run: |
4344
conda install numpy pillow python=${{ inputs.python_version }}

test/smoke_test/smoke_test.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,57 @@
11
import os
2+
import re
23
import sys
34
import torch
5+
# the following import would invoke
6+
# _check_cuda_version()
7+
# via torchvision.extension._check_cuda_version()
48
import torchvision
59
import torchaudio
610
from pathlib import Path
711

812
gpu_arch_ver = os.getenv("GPU_ARCH_VER")
913
gpu_arch_type = os.getenv("GPU_ARCH_TYPE")
14+
# use installation env variable to tell if it is nightly channel
15+
installation_str = os.getenv("INSTALLATION")
1016
is_cuda_system = gpu_arch_type == "cuda"
1117
SCRIPT_DIR = Path(__file__).parent
1218

19+
# helper function to return the conda list output, e.g.
20+
# torchaudio 0.13.0.dev20220922 py39_cu102 pytorch-nightly
21+
def get_anaconda_output_for_package(pkg_name_str):
22+
import subprocess as sp
23+
24+
# ignore the header row:
25+
# Name Version Build Channel
26+
cmd = 'conda list -f ' + pkg_name_str
27+
output = sp.getoutput(cmd)
28+
# Get the last line only
29+
return output.strip().split('\n')[-1]
30+
31+
def check_nightly_binaries_date() -> None:
32+
torch_str = torch.__version__
33+
ta_str = torchaudio.__version__
34+
tv_str = torchvision.__version__
35+
36+
date_t_str = re.findall('dev\d+', torch.__version__ )
37+
date_ta_str = re.findall('dev\d+', torchaudio.__version__ )
38+
date_tv_str = re.findall('dev\d+', torchvision.__version__ )
39+
40+
# check that the above three lists are equal and none of them is empty
41+
if not date_t_str or not date_t_str == date_ta_str == date_tv_str:
42+
raise RuntimeError(f"Expected torch, torchaudio, torchvision to be the same date. But they are from {date_t_str}, {date_ta_str}, {date_tv_str} respectively")
43+
44+
# check that the date is recent, at this point, date_torch_str is not empty
45+
binary_date_str = date_t_str[0][3:]
46+
from datetime import datetime
47+
48+
binary_date_obj = datetime.strptime(binary_date_str, '%Y%m%d').date()
49+
today_obj = datetime.today().date()
50+
delta = today_obj - binary_date_obj
51+
if delta.days >= 2:
52+
raise RuntimeError(f"the binaries are from {binary_date_obj} and are more than 2 days old!")
53+
54+
1355
def smoke_test_cuda() -> None:
1456
if(not torch.cuda.is_available() and is_cuda_system):
1557
raise RuntimeError(f"Expected CUDA {gpu_arch_ver}. However CUDA is not loaded.")
@@ -19,6 +61,19 @@ def smoke_test_cuda() -> None:
1961
print(f"torch cuda: {torch.version.cuda}")
2062
# todo add cudnn version validation
2163
print(f"torch cudnn: {torch.backends.cudnn.version()}")
64+
print(f"cuDNN enabled? {torch.backends.cudnn.enabled}")
65+
66+
if installation_str.find('nightly') != -1:
67+
# just print out cuda version, as version check were already performed during import
68+
print(f"torchvision cuda: {torch.ops.torchvision._cuda_version()}")
69+
print(f"torchaudio cuda: {torch.ops.torchaudio.cuda_version()}")
70+
else:
71+
# torchaudio runtime added the cuda verison check on 09/23/2022 via
72+
# https://github.com/pytorch/audio/pull/2707
73+
# so relying on anaconda output for pytorch-test and pytorch channel
74+
torchaudio_allstr = get_anaconda_output_for_package(torchaudio.__name__)
75+
if is_cuda_system and 'cu'+str(gpu_arch_ver).replace(".", "") not in torchaudio_allstr:
76+
raise RuntimeError(f"CUDA version issue. Loaded: {torchaudio_allstr} Expected: {gpu_arch_ver}")
2277

2378
def smoke_test_conv2d() -> None:
2479
import torch.nn as nn
@@ -95,6 +150,11 @@ def main() -> None:
95150
print(f"torchvision: {torchvision.__version__}")
96151
print(f"torchaudio: {torchaudio.__version__}")
97152
smoke_test_cuda()
153+
154+
# only makes sense to check nightly package where dates are known
155+
if installation_str.find('nightly') != -1:
156+
check_nightly_binaries_date()
157+
98158
smoke_test_conv2d()
99159
smoke_test_torchaudio()
100160
smoke_test_torchvision()

0 commit comments

Comments
 (0)