Skip to content

Commit 8465dbe

Browse files
authored
[BE] Fix all flake8 violations in smoke_test.py (#1553)
Namely: - `if(x):` -> `if x:` - `"dev\d+"` -> `"dev\\d+"` - Keep 2 newlines between functions - Add `assert foo is not None` to suppress "variable assigned but not used" warning
1 parent f6d12ba commit 8465dbe

File tree

1 file changed

+24
-16
lines changed

1 file changed

+24
-16
lines changed

test/smoke_test/smoke_test.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
},
3939
]
4040

41+
4142
class Net(nn.Module):
4243
def __init__(self):
4344
super(Net, self).__init__()
@@ -53,6 +54,7 @@ def forward(self, x):
5354
output = self.fc1(x)
5455
return output
5556

57+
5658
def check_version(package: str) -> None:
5759
# only makes sense to check nightly package where dates are known
5860
if channel == "nightly":
@@ -65,32 +67,33 @@ def check_version(package: str) -> None:
6567
else:
6668
print(f"Skip version check for channel {channel} as stable version is None")
6769

70+
6871
def check_nightly_binaries_date(package: str) -> None:
6972
from datetime import datetime, timedelta
7073
format_dt = '%Y%m%d'
7174

72-
torch_str = torch.__version__
73-
date_t_str = re.findall("dev\d+", torch.__version__)
75+
date_t_str = re.findall("dev\\d+", torch.__version__)
7476
date_t_delta = datetime.now() - datetime.strptime(date_t_str[0][3:], format_dt)
7577
if date_t_delta.days >= NIGHTLY_ALLOWED_DELTA:
7678
raise RuntimeError(
7779
f"the binaries are from {date_t_str} and are more than {NIGHTLY_ALLOWED_DELTA} days old!"
7880
)
7981

80-
if(package == "all"):
82+
if package == "all":
8183
for module in MODULES:
8284
imported_module = importlib.import_module(module["name"])
8385
module_version = imported_module.__version__
84-
date_m_str = re.findall("dev\d+", module_version)
86+
date_m_str = re.findall("dev\\d+", module_version)
8587
date_m_delta = datetime.now() - datetime.strptime(date_m_str[0][3:], format_dt)
8688
print(f"Nightly date check for {module['name']} version {module_version}")
8789
if date_m_delta.days > NIGHTLY_ALLOWED_DELTA:
8890
raise RuntimeError(
8991
f"Expected {module['name']} to be less then {NIGHTLY_ALLOWED_DELTA} days. But its {date_m_delta}"
9092
)
9193

94+
9295
def test_cuda_runtime_errors_captured() -> None:
93-
cuda_exception_missed=True
96+
cuda_exception_missed = True
9497
try:
9598
print("Testing test_cuda_runtime_errors_captured")
9699
torch._assert_async(torch.tensor(0, device="cuda"))
@@ -101,14 +104,15 @@ def test_cuda_runtime_errors_captured() -> None:
101104
cuda_exception_missed = False
102105
else:
103106
raise e
104-
if(cuda_exception_missed):
105-
raise RuntimeError( f"Expected CUDA RuntimeError but have not received!")
107+
if cuda_exception_missed:
108+
raise RuntimeError("Expected CUDA RuntimeError but have not received!")
109+
106110

107111
def smoke_test_cuda(package: str, runtime_error_check: str) -> None:
108112
if not torch.cuda.is_available() and is_cuda_system:
109113
raise RuntimeError(f"Expected CUDA {gpu_arch_ver}. However CUDA is not loaded.")
110114

111-
if(package == 'all' and is_cuda_system):
115+
if package == 'all' and is_cuda_system:
112116
for module in MODULES:
113117
imported_module = importlib.import_module(module["name"])
114118
# TBD for vision move extension module to private so it will
@@ -131,12 +135,10 @@ def smoke_test_cuda(package: str, runtime_error_check: str) -> None:
131135
print(f"cuDNN enabled? {torch.backends.cudnn.enabled}")
132136

133137
# torch.compile is available only on Linux and python 3.8-3.10
134-
if (sys.platform == "linux" or sys.platform == "linux2") and sys.version_info < (3, 11, 0) and channel == "release":
135-
smoke_test_compile()
136-
elif (sys.platform == "linux" or sys.platform == "linux2") and channel != "release":
138+
if sys.platform in ["linux", "linux2"] and (sys.version_info < (3, 11, 0) or channel != "release"):
137139
smoke_test_compile()
138140

139-
if(runtime_error_check == "enabled"):
141+
if runtime_error_check == "enabled":
140142
test_cuda_runtime_errors_captured()
141143

142144

@@ -148,6 +150,7 @@ def smoke_test_conv2d() -> None:
148150
m = nn.Conv2d(16, 33, 3, stride=2)
149151
# non-square kernels and unequal stride and with padding
150152
m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
153+
assert m is not None
151154
# non-square kernels and unequal stride and with padding and dilation
152155
basic_conv = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
153156
input = torch.randn(20, 16, 50, 100)
@@ -156,16 +159,19 @@ def smoke_test_conv2d() -> None:
156159
if is_cuda_system:
157160
print("Testing smoke_test_conv2d with cuda")
158161
conv = nn.Conv2d(3, 3, 3).cuda()
159-
x = torch.randn(1, 3, 24, 24).cuda()
162+
x = torch.randn(1, 3, 24, 24, device="cuda")
160163
with torch.cuda.amp.autocast():
161164
out = conv(x)
165+
assert out is not None
162166

163167
supported_dtypes = [torch.float16, torch.float32, torch.float64]
164168
for dtype in supported_dtypes:
165169
print(f"Testing smoke_test_conv2d with cuda for {dtype}")
166170
conv = basic_conv.to(dtype).cuda()
167171
input = torch.randn(20, 16, 50, 100, device="cuda").type(dtype)
168172
output = conv(input)
173+
assert output is not None
174+
169175

170176
def smoke_test_linalg() -> None:
171177
print("Testing smoke_test_linalg")
@@ -189,10 +195,13 @@ def smoke_test_linalg() -> None:
189195
A = torch.randn(20, 16, 50, 100, device="cuda").type(dtype)
190196
torch.linalg.svd(A)
191197

198+
192199
def smoke_test_compile() -> None:
193200
supported_dtypes = [torch.float16, torch.float32, torch.float64]
201+
194202
def foo(x: torch.Tensor) -> torch.Tensor:
195203
return torch.sin(x) + torch.cos(x)
204+
196205
for dtype in supported_dtypes:
197206
print(f"Testing smoke_test_compile for {dtype}")
198207
x = torch.rand(3, 3, device="cuda").type(dtype)
@@ -209,6 +218,7 @@ def foo(x: torch.Tensor) -> torch.Tensor:
209218
model = Net().to(device="cuda")
210219
x_pt2 = torch.compile(model, mode="max-autotune")(x)
211220

221+
212222
def smoke_test_modules():
213223
cwd = os.getcwd()
214224
for module in MODULES:
@@ -224,9 +234,7 @@ def smoke_test_modules():
224234
smoke_test_command, stderr=subprocess.STDOUT, shell=True,
225235
universal_newlines=True)
226236
except subprocess.CalledProcessError as exc:
227-
raise RuntimeError(
228-
f"Module {module['name']} FAIL: {exc.returncode} Output: {exc.output}"
229-
)
237+
raise RuntimeError(f"Module {module['name']} FAIL: {exc.returncode} Output: {exc.output}")
230238
else:
231239
print("Output: \n{}\n".format(output))
232240

0 commit comments

Comments
 (0)