3
3
import torch
4
4
import torchvision
5
5
import torchaudio
6
+ from pathlib import Path
6
7
7
- def smoke_test_cuda () -> None :
8
- gpu_arch_ver = os .getenv ('GPU_ARCH_VER' )
9
- gpu_arch_type = os . getenv ( 'GPU_ARCH_TYPE' )
10
- is_cuda_system = gpu_arch_type == "cuda"
8
+ gpu_arch_ver = os . getenv ( "GPU_ARCH_VER" )
9
+ gpu_arch_type = os .getenv ("GPU_ARCH_TYPE" )
10
+ is_cuda_system = gpu_arch_type == "cuda"
11
+ SCRIPT_DIR = Path ( __file__ ). parent
11
12
13
+ def smoke_test_cuda () -> None :
12
14
if (not torch .cuda .is_available () and is_cuda_system ):
13
- print (f"Expected CUDA { gpu_arch_ver } . However CUDA is not loaded." )
14
- sys .exit (1 )
15
+ raise RuntimeError (f"Expected CUDA { gpu_arch_ver } . However CUDA is not loaded." )
15
16
if (torch .cuda .is_available ()):
16
17
if (torch .version .cuda != gpu_arch_ver ):
17
- print (f"Wrong CUDA version. Loaded: { torch .version .cuda } Expected: { gpu_arch_ver } " )
18
- sys .exit (1 )
19
- y = torch .randn ([3 ,5 ]).cuda ()
18
+ raise RuntimeError (f"Wrong CUDA version. Loaded: { torch .version .cuda } Expected: { gpu_arch_ver } " )
20
19
print (f"torch cuda: { torch .version .cuda } " )
21
- #todo add cudnn version validation
20
+ # todo add cudnn version validation
22
21
print (f"torch cudnn: { torch .backends .cudnn .version ()} " )
23
22
23
+ def smoke_test_conv2d () -> None :
24
+ import torch .nn as nn
25
+ print ("Calling smoke_test_conv2d" )
26
+ # With square kernels and equal stride
27
+ m = nn .Conv2d (16 , 33 , 3 , stride = 2 )
28
+ # non-square kernels and unequal stride and with padding
29
+ m = nn .Conv2d (16 , 33 , (3 , 5 ), stride = (2 , 1 ), padding = (4 , 2 ))
30
+ # non-square kernels and unequal stride and with padding and dilation
31
+ m = nn .Conv2d (16 , 33 , (3 , 5 ), stride = (2 , 1 ), padding = (4 , 2 ), dilation = (3 , 1 ))
32
+ input = torch .randn (20 , 16 , 50 , 100 )
33
+ output = m (input )
34
+ if (is_cuda_system ):
35
+ print ("Testing smoke_test_conv2d with cuda" )
36
+ conv = nn .Conv2d (3 , 3 , 3 ).cuda ()
37
+ x = torch .randn (1 , 3 , 24 , 24 ).cuda ()
38
+ with torch .cuda .amp .autocast ():
39
+ out = conv (x )
40
+
24
41
def smoke_test_torchvision () -> None :
25
- import torchvision .datasets as dset
26
- import torchvision .transforms
27
- print ('Is torchvision useable?' , all (x is not None for x in [torch .ops .image .decode_png , torch .ops .torchvision .roi_align ]))
42
+ print ("Is torchvision useable?" , all (x is not None for x in [torch .ops .image .decode_png , torch .ops .torchvision .roi_align ]))
43
+
44
+ def smoke_test_torchvision_read_decode () -> None :
45
+ from torchvision .io import read_image
46
+ img_jpg = read_image (str (SCRIPT_DIR / "assets" / "rgb_pytorch.jpg" ))
47
+ if img_jpg .ndim != 3 or img_jpg .numel () < 100 :
48
+ raise RuntimeError (f"Unexpected shape of img_jpg: { img_jpg .shape } " )
49
+ img_png = read_image (str (SCRIPT_DIR / "assets" / "rgb_pytorch.png" ))
50
+ if img_png .ndim != 3 or img_png .numel () < 100 :
51
+ raise RuntimeError (f"Unexpected shape of img_png: { img_png .shape } " )
52
+
53
+ def smoke_test_torchvision_resnet50_classify () -> None :
54
+ from torchvision .io import read_image
55
+ from torchvision .models import resnet50 , ResNet50_Weights
56
+
57
+ img = read_image (str (SCRIPT_DIR / "assets" / "dog2.jpg" ))
58
+
59
+ # Step 1: Initialize model with the best available weights
60
+ weights = ResNet50_Weights .DEFAULT
61
+ model = resnet50 (weights = weights )
62
+ model .eval ()
63
+
64
+ # Step 2: Initialize the inference transforms
65
+ preprocess = weights .transforms ()
66
+
67
+ # Step 3: Apply inference preprocessing transforms
68
+ batch = preprocess (img ).unsqueeze (0 )
69
+
70
+ # Step 4: Use the model and print the predicted category
71
+ prediction = model (batch ).squeeze (0 ).softmax (0 )
72
+ class_id = prediction .argmax ().item ()
73
+ score = prediction [class_id ].item ()
74
+ category_name = weights .meta ["categories" ][class_id ]
75
+ expected_category = "German shepherd"
76
+ print (f"{ category_name } : { 100 * score :.1f} %" )
77
+ if (category_name != expected_category ):
78
+ raise RuntimeError (f"Failed ResNet50 classify { category_name } Expected: { expected_category } " )
79
+
28
80
29
81
def smoke_test_torchaudio () -> None :
30
82
import torchaudio .compliance .kaldi # noqa: F401
@@ -36,14 +88,18 @@ def smoke_test_torchaudio() -> None:
36
88
import torchaudio .transforms # noqa: F401
37
89
import torchaudio .utils # noqa: F401
38
90
91
+
39
92
def main () -> None :
40
93
#todo add torch, torchvision and torchaudio tests
41
94
print (f"torch: { torch .__version__ } " )
42
95
print (f"torchvision: { torchvision .__version__ } " )
43
96
print (f"torchaudio: { torchaudio .__version__ } " )
44
97
smoke_test_cuda ()
45
- smoke_test_torchvision ()
98
+ smoke_test_conv2d ()
46
99
smoke_test_torchaudio ()
100
+ smoke_test_torchvision ()
101
+ smoke_test_torchvision_read_decode ()
102
+ smoke_test_torchvision_resnet50_classify ()
47
103
48
104
if __name__ == "__main__" :
49
105
main ()
0 commit comments