Skip to content

Commit 649e51d

Browse files
Joao Gomesfacebook-github-bot
Joao Gomes
authored andcommitted
[fbsync] Add cuda resnet50 test to smoke test (#7020)
Summary: * Add cuda resnet50 test * Fix path * Tune vision smoke test Reviewed By: YosuaMichael Differential Revision: D42046599 fbshipit-source-id: d5aaa4dcd7a07730347d3417f3f94eed70d5a91c
1 parent 68c4f56 commit 649e51d

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

test/smoke_test.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ def smoke_test_torchvision_read_decode() -> None:
2727
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
2828

2929

30-
def smoke_test_torchvision_resnet50_classify() -> None:
31-
img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg"))
30+
def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
31+
img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)
3232

3333
# Step 1: Initialize model with the best available weights
3434
weights = ResNet50_Weights.DEFAULT
35-
model = resnet50(weights=weights)
35+
model = resnet50(weights=weights).to(device)
3636
model.eval()
3737

3838
# Step 2: Initialize the inference transforms
@@ -47,7 +47,7 @@ def smoke_test_torchvision_resnet50_classify() -> None:
4747
score = prediction[class_id].item()
4848
category_name = weights.meta["categories"][class_id]
4949
expected_category = "German shepherd"
50-
print(f"{category_name}: {100 * score:.1f}%")
50+
print(f"{category_name} ({device}): {100 * score:.1f}%")
5151
if category_name != expected_category:
5252
raise RuntimeError(f"Failed ResNet50 classify {category_name} Expected: {expected_category}")
5353

@@ -57,6 +57,8 @@ def main() -> None:
5757
smoke_test_torchvision()
5858
smoke_test_torchvision_read_decode()
5959
smoke_test_torchvision_resnet50_classify()
60+
if torch.cuda.is_available():
61+
smoke_test_torchvision_resnet50_classify("cuda")
6062

6163

6264
if __name__ == "__main__":

0 commit comments

Comments
 (0)