Skip to content

Making CUDA optional for MNIST example #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 25, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
mnist/data
*.pyc
34 changes: 26 additions & 8 deletions mnist/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@
import os
import torch
import torch.nn as nn
import torch.nn.cuda

try:
import torch.nn.cuda
cuda = True
except ImportError:
print('Could not import CUDA, skipping')
cuda = False

import torch.optim as optim
from torch.autograd import Variable
from tqdm import tqdm
Expand Down Expand Up @@ -51,9 +58,11 @@ def __call__(self, x):
x = self.relu(self.fc2(x))
return self.softmax(x)

model = Net().cuda()
model = Net()
if cuda == True:
model.cuda()

criterion = nn.ClassNLLCriterion()
criterion = nn.NLLLoss()

# Training settings
BATCH_SIZE = 150
Expand All @@ -63,8 +72,13 @@ def __call__(self, x):
optimizer = optim.SGD(model, lr=1e-2, momentum=0.9)

def train(epoch):
batch_data = Variable(torch.cuda.FloatTensor(BATCH_SIZE, 1, 28, 28), requires_grad=False)
batch_targets = Variable(torch.cuda.FloatTensor(BATCH_SIZE), requires_grad=False)
batch_data_t = torch.FloatTensor(BATCH_SIZE, 1, 28, 28)
batch_targets_t = torch.LongTensor(BATCH_SIZE)
if cuda:
batch_data_t = batch_data_t.cuda()
batch_targets_t = batch_targets_t.cuda()
batch_data = Variable(batch_data_t, requires_grad=False)
batch_targets = Variable(batch_targets_t, requires_grad=False)
for i in range(0, training_data.size(0), BATCH_SIZE):
batch_data.data[:] = training_data[i:i+BATCH_SIZE]
batch_targets.data[:] = training_labels[i:i+BATCH_SIZE]
Expand All @@ -76,8 +90,13 @@ def train(epoch):

def test(epoch):
test_loss = 0
batch_data = Variable(torch.cuda.FloatTensor(TEST_BATCH_SIZE, 1, 28, 28), volatile=True)
batch_targets = Variable(torch.cuda.FloatTensor(TEST_BATCH_SIZE), volatile=True)
batch_data_t = torch.FloatTensor(TEST_BATCH_SIZE, 1, 28, 28)
batch_targets_t = torch.LongTensor(TEST_BATCH_SIZE)
if cuda:
batch_data_t = batch_data_t.cuda()
batch_targets_t = batch_targets_t.cuda()
batch_data = Variable(batch_data_t, volatile=True)
batch_targets = Variable(batch_targets_t, volatile=True)
correct = 0
for i in range(0, test_data.size(0), TEST_BATCH_SIZE):
print('Testing model: {}/{}'.format(i, test_data.size(0)), end='\r')
Expand All @@ -98,4 +117,3 @@ def test(epoch):
for epoch in range(1, NUM_EPOCHS+1):
train(epoch)
test(epoch)