Skip to content

Commit 64686f1

Browse files
authored
Update main.py
Using cuda.device_count instead of hardcoding the number of GPUs.
1 parent 6b8115a commit 64686f1

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363

6464
if use_cuda:
6565
net.cuda()
66-
net = torch.nn.DataParallel(net, device_ids=[0,1,2,3])
66+
net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
6767
cudnn.benchmark = True
6868

6969
criterion = nn.CrossEntropyLoss()

0 commit comments

Comments
 (0)