Skip to content

Commit 3b93537

Browse files
committed
Option to run resnet classifier on specific device
1 parent 6bc0bc2 commit 3b93537

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

test/smoke_test/smoke_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,15 +138,15 @@ def smoke_test_torchvision_read_decode() -> None:
138138
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
139139

140140

141-
def smoke_test_torchvision_resnet50_classify() -> None:
141+
def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
142142
from torchvision.io import read_image
143143
from torchvision.models import resnet50, ResNet50_Weights
144144

145-
img = read_image(str(SCRIPT_DIR / "assets" / "dog2.jpg"))
145+
img = read_image(str(SCRIPT_DIR / "assets" / "dog2.jpg")).to(device)
146146

147147
# Step 1: Initialize model with the best available weights
148148
weights = ResNet50_Weights.DEFAULT
149-
model = resnet50(weights=weights)
149+
model = resnet50(weights=weights).to(device)
150150
model.eval()
151151

152152
# Step 2: Initialize the inference transforms

0 commit comments

Comments
 (0)