38
38
},
39
39
]
40
40
41
+
41
42
class Net (nn .Module ):
42
43
def __init__ (self ):
43
44
super (Net , self ).__init__ ()
@@ -53,6 +54,7 @@ def forward(self, x):
53
54
output = self .fc1 (x )
54
55
return output
55
56
57
+
56
58
def check_version (package : str ) -> None :
57
59
# only makes sense to check nightly package where dates are known
58
60
if channel == "nightly" :
@@ -65,32 +67,33 @@ def check_version(package: str) -> None:
65
67
else :
66
68
print (f"Skip version check for channel { channel } as stable version is None" )
67
69
70
+
68
71
def check_nightly_binaries_date (package : str ) -> None :
69
72
from datetime import datetime , timedelta
70
73
format_dt = '%Y%m%d'
71
74
72
- torch_str = torch .__version__
73
- date_t_str = re .findall ("dev\d+" , torch .__version__ )
75
+ date_t_str = re .findall ("dev\\ d+" , torch .__version__ )
74
76
date_t_delta = datetime .now () - datetime .strptime (date_t_str [0 ][3 :], format_dt )
75
77
if date_t_delta .days >= NIGHTLY_ALLOWED_DELTA :
76
78
raise RuntimeError (
77
79
f"the binaries are from { date_t_str } and are more than { NIGHTLY_ALLOWED_DELTA } days old!"
78
80
)
79
81
80
- if ( package == "all" ) :
82
+ if package == "all" :
81
83
for module in MODULES :
82
84
imported_module = importlib .import_module (module ["name" ])
83
85
module_version = imported_module .__version__
84
- date_m_str = re .findall ("dev\d+" , module_version )
86
+ date_m_str = re .findall ("dev\\ d+" , module_version )
85
87
date_m_delta = datetime .now () - datetime .strptime (date_m_str [0 ][3 :], format_dt )
86
88
print (f"Nightly date check for { module ['name' ]} version { module_version } " )
87
89
if date_m_delta .days > NIGHTLY_ALLOWED_DELTA :
88
90
raise RuntimeError (
89
91
f"Expected { module ['name' ]} to be less then { NIGHTLY_ALLOWED_DELTA } days. But its { date_m_delta } "
90
92
)
91
93
94
+
92
95
def test_cuda_runtime_errors_captured () -> None :
93
- cuda_exception_missed = True
96
+ cuda_exception_missed = True
94
97
try :
95
98
print ("Testing test_cuda_runtime_errors_captured" )
96
99
torch ._assert_async (torch .tensor (0 , device = "cuda" ))
@@ -101,14 +104,15 @@ def test_cuda_runtime_errors_captured() -> None:
101
104
cuda_exception_missed = False
102
105
else :
103
106
raise e
104
- if (cuda_exception_missed ):
105
- raise RuntimeError ( f"Expected CUDA RuntimeError but have not received!" )
107
+ if cuda_exception_missed :
108
+ raise RuntimeError ("Expected CUDA RuntimeError but have not received!" )
109
+
106
110
107
111
def smoke_test_cuda (package : str , runtime_error_check : str ) -> None :
108
112
if not torch .cuda .is_available () and is_cuda_system :
109
113
raise RuntimeError (f"Expected CUDA { gpu_arch_ver } . However CUDA is not loaded." )
110
114
111
- if ( package == 'all' and is_cuda_system ) :
115
+ if package == 'all' and is_cuda_system :
112
116
for module in MODULES :
113
117
imported_module = importlib .import_module (module ["name" ])
114
118
# TBD for vision move extension module to private so it will
@@ -131,12 +135,10 @@ def smoke_test_cuda(package: str, runtime_error_check: str) -> None:
131
135
print (f"cuDNN enabled? { torch .backends .cudnn .enabled } " )
132
136
133
137
# torch.compile is available only on Linux and python 3.8-3.10
134
- if (sys .platform == "linux" or sys .platform == "linux2" ) and sys .version_info < (3 , 11 , 0 ) and channel == "release" :
135
- smoke_test_compile ()
136
- elif (sys .platform == "linux" or sys .platform == "linux2" ) and channel != "release" :
138
+ if sys .platform in ["linux" , "linux2" ] and (sys .version_info < (3 , 11 , 0 ) or channel != "release" ):
137
139
smoke_test_compile ()
138
140
139
- if ( runtime_error_check == "enabled" ) :
141
+ if runtime_error_check == "enabled" :
140
142
test_cuda_runtime_errors_captured ()
141
143
142
144
@@ -148,6 +150,7 @@ def smoke_test_conv2d() -> None:
148
150
m = nn .Conv2d (16 , 33 , 3 , stride = 2 )
149
151
# non-square kernels and unequal stride and with padding
150
152
m = nn .Conv2d (16 , 33 , (3 , 5 ), stride = (2 , 1 ), padding = (4 , 2 ))
153
+ assert m is not None
151
154
# non-square kernels and unequal stride and with padding and dilation
152
155
basic_conv = nn .Conv2d (16 , 33 , (3 , 5 ), stride = (2 , 1 ), padding = (4 , 2 ), dilation = (3 , 1 ))
153
156
input = torch .randn (20 , 16 , 50 , 100 )
@@ -156,16 +159,19 @@ def smoke_test_conv2d() -> None:
156
159
if is_cuda_system :
157
160
print ("Testing smoke_test_conv2d with cuda" )
158
161
conv = nn .Conv2d (3 , 3 , 3 ).cuda ()
159
- x = torch .randn (1 , 3 , 24 , 24 ). cuda ( )
162
+ x = torch .randn (1 , 3 , 24 , 24 , device = " cuda" )
160
163
with torch .cuda .amp .autocast ():
161
164
out = conv (x )
165
+ assert out is not None
162
166
163
167
supported_dtypes = [torch .float16 , torch .float32 , torch .float64 ]
164
168
for dtype in supported_dtypes :
165
169
print (f"Testing smoke_test_conv2d with cuda for { dtype } " )
166
170
conv = basic_conv .to (dtype ).cuda ()
167
171
input = torch .randn (20 , 16 , 50 , 100 , device = "cuda" ).type (dtype )
168
172
output = conv (input )
173
+ assert output is not None
174
+
169
175
170
176
def smoke_test_linalg () -> None :
171
177
print ("Testing smoke_test_linalg" )
@@ -189,10 +195,13 @@ def smoke_test_linalg() -> None:
189
195
A = torch .randn (20 , 16 , 50 , 100 , device = "cuda" ).type (dtype )
190
196
torch .linalg .svd (A )
191
197
198
+
192
199
def smoke_test_compile () -> None :
193
200
supported_dtypes = [torch .float16 , torch .float32 , torch .float64 ]
201
+
194
202
def foo (x : torch .Tensor ) -> torch .Tensor :
195
203
return torch .sin (x ) + torch .cos (x )
204
+
196
205
for dtype in supported_dtypes :
197
206
print (f"Testing smoke_test_compile for { dtype } " )
198
207
x = torch .rand (3 , 3 , device = "cuda" ).type (dtype )
@@ -209,6 +218,7 @@ def foo(x: torch.Tensor) -> torch.Tensor:
209
218
model = Net ().to (device = "cuda" )
210
219
x_pt2 = torch .compile (model , mode = "max-autotune" )(x )
211
220
221
+
212
222
def smoke_test_modules ():
213
223
cwd = os .getcwd ()
214
224
for module in MODULES :
@@ -224,9 +234,7 @@ def smoke_test_modules():
224
234
smoke_test_command , stderr = subprocess .STDOUT , shell = True ,
225
235
universal_newlines = True )
226
236
except subprocess .CalledProcessError as exc :
227
- raise RuntimeError (
228
- f"Module { module ['name' ]} FAIL: { exc .returncode } Output: { exc .output } "
229
- )
237
+ raise RuntimeError (f"Module { module ['name' ]} FAIL: { exc .returncode } Output: { exc .output } " )
230
238
else :
231
239
print ("Output: \n {}\n " .format (output ))
232
240
0 commit comments