Skip to content

Commit f1d9dbc

Browse files
authored
Merge pull request #3 from clementfarabet/master
Making CUDA optional for MNIST example
2 parents 30cea27 + 960518e commit f1d9dbc

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
mnist/data
2+
*.pyc

mnist/main.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@
22
import os
33
import torch
44
import torch.nn as nn
5-
import torch.nn.cuda
5+
6+
try:
7+
import torch.nn.cuda
8+
cuda = True
9+
except ImportError:
10+
print('Could not import CUDA, skipping')
11+
cuda = False
12+
613
import torch.optim as optim
714
from torch.autograd import Variable
815
from tqdm import tqdm
@@ -51,9 +58,11 @@ def __call__(self, x):
5158
x = self.relu(self.fc2(x))
5259
return self.softmax(x)
5360

54-
model = Net().cuda()
61+
model = Net()
62+
if cuda == True:
63+
model.cuda()
5564

56-
criterion = nn.ClassNLLCriterion()
65+
criterion = nn.NLLLoss()
5766

5867
# Training settings
5968
BATCH_SIZE = 150
@@ -63,8 +72,13 @@ def __call__(self, x):
6372
optimizer = optim.SGD(model, lr=1e-2, momentum=0.9)
6473

6574
def train(epoch):
66-
batch_data = Variable(torch.cuda.FloatTensor(BATCH_SIZE, 1, 28, 28), requires_grad=False)
67-
batch_targets = Variable(torch.cuda.FloatTensor(BATCH_SIZE), requires_grad=False)
75+
batch_data_t = torch.FloatTensor(BATCH_SIZE, 1, 28, 28)
76+
batch_targets_t = torch.LongTensor(BATCH_SIZE)
77+
if cuda:
78+
batch_data_t = batch_data_t.cuda()
79+
batch_targets_t = batch_targets_t.cuda()
80+
batch_data = Variable(batch_data_t, requires_grad=False)
81+
batch_targets = Variable(batch_targets_t, requires_grad=False)
6882
for i in range(0, training_data.size(0), BATCH_SIZE):
6983
batch_data.data[:] = training_data[i:i+BATCH_SIZE]
7084
batch_targets.data[:] = training_labels[i:i+BATCH_SIZE]
@@ -76,8 +90,13 @@ def train(epoch):
7690

7791
def test(epoch):
7892
test_loss = 0
79-
batch_data = Variable(torch.cuda.FloatTensor(TEST_BATCH_SIZE, 1, 28, 28), volatile=True)
80-
batch_targets = Variable(torch.cuda.FloatTensor(TEST_BATCH_SIZE), volatile=True)
93+
batch_data_t = torch.FloatTensor(TEST_BATCH_SIZE, 1, 28, 28)
94+
batch_targets_t = torch.LongTensor(TEST_BATCH_SIZE)
95+
if cuda:
96+
batch_data_t = batch_data_t.cuda()
97+
batch_targets_t = batch_targets_t.cuda()
98+
batch_data = Variable(batch_data_t, volatile=True)
99+
batch_targets = Variable(batch_targets_t, volatile=True)
81100
correct = 0
82101
for i in range(0, test_data.size(0), TEST_BATCH_SIZE):
83102
print('Testing model: {}/{}'.format(i, test_data.size(0)), end='\r')
@@ -98,4 +117,3 @@ def test(epoch):
98117
for epoch in range(1, NUM_EPOCHS+1):
99118
train(epoch)
100119
test(epoch)
101-

0 commit comments

Comments
 (0)