|
10 | 10 |
|
11 | 11 | gpu_arch_ver = os.getenv("MATRIX_GPU_ARCH_VERSION") |
12 | 12 | gpu_arch_type = os.getenv("MATRIX_GPU_ARCH_TYPE") |
13 | | -# use installation env variable to tell if it is nightly channel |
14 | | -installation_str = os.getenv("MATRIX_INSTALLATION") |
| 13 | +channel = os.getenv("MATRIX_CHANNEL") |
| 14 | +stable_version = os.getenv("MATRIX_STABLE_VERSION") |
| 15 | + |
15 | 16 | is_cuda_system = gpu_arch_type == "cuda" |
16 | 17 | SCRIPT_DIR = Path(__file__).parent |
17 | 18 | NIGHTLY_ALLOWED_DELTA = 3 |
|
31 | 32 | }, |
32 | 33 | ] |
33 | 34 |
|
| 35 | +def check_version(package: str) -> None: |
| 36 | + # only makes sense to check nightly package where dates are known |
| 37 | + if channel == "nightly": |
| 38 | + check_nightly_binaries_date(options.package) |
| 39 | + else |
| 40 | + if torch.__version__ != stable_version: |
| 41 | + raise RuntimeError( |
| 42 | + f"Torch version mismatch, expected {stable_version} for channel {channel}. But its {torch.__version__}" |
| 43 | + ) |
| 44 | + |
34 | 45 | def check_nightly_binaries_date(package: str) -> None: |
35 | 46 | from datetime import datetime, timedelta |
36 | 47 | format_dt = '%Y%m%d' |
@@ -190,17 +201,13 @@ def main() -> None: |
190 | 201 | ) |
191 | 202 | options = parser.parse_args() |
192 | 203 | print(f"torch: {torch.__version__}") |
| 204 | + check_version(options.package) |
193 | 205 | smoke_test_conv2d() |
194 | 206 | smoke_test_linalg() |
195 | 207 |
|
196 | | - |
197 | 208 | if options.package == "all": |
198 | 209 | smoke_test_modules() |
199 | 210 |
|
200 | | - # only makes sense to check nightly package where dates are known |
201 | | - if installation_str.find("nightly") != -1: |
202 | | - check_nightly_binaries_date(options.package) |
203 | | - |
204 | 211 | smoke_test_cuda(options.package) |
205 | 212 |
|
206 | 213 |
|
|
0 commit comments