@@ -27,12 +27,12 @@ def smoke_test_torchvision_read_decode() -> None:
27
27
raise RuntimeError (f"Unexpected shape of img_png: { img_png .shape } " )
28
28
29
29
30
- def smoke_test_torchvision_resnet50_classify () -> None :
31
- img = read_image (str (SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg" ))
30
+ def smoke_test_torchvision_resnet50_classify (device : str = "cpu" ) -> None :
31
+ img = read_image (str (SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg" )). to ( device )
32
32
33
33
# Step 1: Initialize model with the best available weights
34
34
weights = ResNet50_Weights .DEFAULT
35
- model = resnet50 (weights = weights )
35
+ model = resnet50 (weights = weights ). to ( device )
36
36
model .eval ()
37
37
38
38
# Step 2: Initialize the inference transforms
@@ -47,7 +47,7 @@ def smoke_test_torchvision_resnet50_classify() -> None:
47
47
score = prediction [class_id ].item ()
48
48
category_name = weights .meta ["categories" ][class_id ]
49
49
expected_category = "German shepherd"
50
- print (f"{ category_name } : { 100 * score :.1f} %" )
50
+ print (f"{ category_name } ( { device } ) : { 100 * score :.1f} %" )
51
51
if category_name != expected_category :
52
52
raise RuntimeError (f"Failed ResNet50 classify { category_name } Expected: { expected_category } " )
53
53
@@ -57,6 +57,8 @@ def main() -> None:
57
57
smoke_test_torchvision ()
58
58
smoke_test_torchvision_read_decode ()
59
59
smoke_test_torchvision_resnet50_classify ()
60
+ if torch .cuda .is_available ():
61
+ smoke_test_torchvision_resnet50_classify ("cuda" )
60
62
61
63
62
64
if __name__ == "__main__" :
0 commit comments