Skip to content

Commit 77a6ec7

Browse files
committed
fix mnist example
1 parent 8aee60c commit 77a6ec7

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

mnist/main.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import print_function
2+
import os
23
import torch
34
import torch.nn as nn
45
import torch.nn.cuda
@@ -9,6 +10,9 @@
910
def print_header(msg):
1011
print('===>', msg)
1112

13+
if not os.path.exists('data/processed/training.pt'):
14+
import data
15+
1216
# Data
1317
print_header('Loading data')
1418
with open('data/processed/training.pt', 'rb') as f:
@@ -29,9 +33,9 @@ def print_header(msg):
2933
class Net(nn.Container):
3034
def __init__(self):
3135
super(Net, self).__init__(
32-
conv1 = nn.Conv2d(1, 20, 5, 5),
36+
conv1 = nn.Conv2d(1, 20, 5),
3337
pool1 = nn.MaxPool2d(2, 2),
34-
conv2 = nn.Conv2d(20, 50, 5, 5),
38+
conv2 = nn.Conv2d(20, 50, 5),
3539
pool2 = nn.MaxPool2d(2, 2),
3640
fc1 = nn.Linear(800, 500),
3741
fc2 = nn.Linear(500, 10),
@@ -56,35 +60,40 @@ def __call__(self, x):
5660
TEST_BATCH_SIZE = 1000
5761
NUM_EPOCHS = 2
5862

59-
optimizer = optim.SGD((model, criterion), lr=1e-2, momentum=0.9)
63+
optimizer = optim.SGD(model, lr=1e-2, momentum=0.9)
6064

6165
def train(epoch):
6266
batch_data = Variable(torch.cuda.FloatTensor(BATCH_SIZE, 1, 28, 28), requires_grad=False)
6367
batch_targets = Variable(torch.cuda.FloatTensor(BATCH_SIZE), requires_grad=False)
6468
for i in range(0, training_data.size(0), BATCH_SIZE):
6569
batch_data.data[:] = training_data[i:i+BATCH_SIZE]
6670
batch_targets.data[:] = training_labels[i:i+BATCH_SIZE]
67-
loss = optimizer.step(batch_data, batch_targets)
68-
model.zero_grad_parameters()
71+
loss = optimizer.step(lambda: criterion(model(batch_data), batch_targets))
6972

7073
print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}'.format(epoch,
7174
i+BATCH_SIZE, training_data.size(0),
72-
(i+BATCH_SIZE)/training_data.size(0)*100, loss))
75+
float(i+BATCH_SIZE)/training_data.size(0)*100, loss))
7376

7477
def test(epoch):
7578
test_loss = 0
7679
batch_data = Variable(torch.cuda.FloatTensor(TEST_BATCH_SIZE, 1, 28, 28), volatile=True)
7780
batch_targets = Variable(torch.cuda.FloatTensor(TEST_BATCH_SIZE), volatile=True)
81+
correct = 0
7882
for i in range(0, test_data.size(0), TEST_BATCH_SIZE):
7983
print('Testing model: {}/{}'.format(i, test_data.size(0)), end='\r')
8084
batch_data.data[:] = test_data[i:i+TEST_BATCH_SIZE]
8185
batch_targets.data[:] = test_labels[i:i+TEST_BATCH_SIZE]
82-
test_loss += criterion(model(batch_data), batch_targets)
86+
output = model(batch_data)
87+
test_loss += criterion(output, batch_targets)
88+
pred = output.data.max(1)[1]
89+
correct += pred.long().eq(batch_targets.data.long()).sum()
8390

8491
test_loss = test_loss.data[0]
8592
test_loss /= (test_data.size(0) / TEST_BATCH_SIZE) # criterion averages over batch size
8693
print('TEST SET RESULTS:' + ' ' * 20)
87-
print('Average loss: {:.4f}'.format(test_loss))
94+
print('Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
95+
test_loss, correct, test_data.size(0),
96+
float(correct)/test_data.size(0)*100))
8897

8998
for epoch in range(1, NUM_EPOCHS+1):
9099
train(epoch)

0 commit comments

Comments
 (0)