File tree 1 file changed +3
-3
lines changed 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -138,15 +138,15 @@ def smoke_test_torchvision_read_decode() -> None:
138
138
raise RuntimeError (f"Unexpected shape of img_png: { img_png .shape } " )
139
139
140
140
141
- def smoke_test_torchvision_resnet50_classify () -> None :
141
+ def smoke_test_torchvision_resnet50_classify (device : str = "cpu" ) -> None :
142
142
from torchvision .io import read_image
143
143
from torchvision .models import resnet50 , ResNet50_Weights
144
144
145
- img = read_image (str (SCRIPT_DIR / "assets" / "dog2.jpg" ))
145
+ img = read_image (str (SCRIPT_DIR / "assets" / "dog2.jpg" )). to ( device )
146
146
147
147
# Step 1: Initialize model with the best available weights
148
148
weights = ResNet50_Weights .DEFAULT
149
- model = resnet50 (weights = weights )
149
+ model = resnet50 (weights = weights ). to ( device )
150
150
model .eval ()
151
151
152
152
# Step 2: Initialize the inference transforms
You can’t perform that action at this time.
0 commit comments