Skip to content

Commit 8aee60c

Browse files
committed
Update example
1 parent dbb6542 commit 8aee60c

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

mnist/main.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ class Net(nn.Container):
3030
def __init__(self):
3131
super(Net, self).__init__(
3232
conv1 = nn.Conv2d(1, 20, 5, 5),
33-
pool1 = nn.MaxPooling2d(2, 2),
33+
pool1 = nn.MaxPool2d(2, 2),
3434
conv2 = nn.Conv2d(20, 50, 5, 5),
35-
pool2 = nn.MaxPooling2d(2, 2),
35+
pool2 = nn.MaxPool2d(2, 2),
3636
fc1 = nn.Linear(800, 500),
3737
fc2 = nn.Linear(500, 10),
3838
relu = nn.ReLU(),
@@ -73,14 +73,15 @@ def train(epoch):
7373

7474
def test(epoch):
7575
test_loss = 0
76-
batch_data = Variable(torch.cuda.FloatTensor(TEST_BATCH_SIZE, 1, 28, 28), requires_grad=False)
77-
batch_targets = Variable(torch.cuda.FloatTensor(TEST_BATCH_SIZE), requires_grad=False)
76+
batch_data = Variable(torch.cuda.FloatTensor(TEST_BATCH_SIZE, 1, 28, 28), volatile=True)
77+
batch_targets = Variable(torch.cuda.FloatTensor(TEST_BATCH_SIZE), volatile=True)
7878
for i in range(0, test_data.size(0), TEST_BATCH_SIZE):
7979
print('Testing model: {}/{}'.format(i, test_data.size(0)), end='\r')
8080
batch_data.data[:] = test_data[i:i+TEST_BATCH_SIZE]
8181
batch_targets.data[:] = test_labels[i:i+TEST_BATCH_SIZE]
82-
test_loss += criterion(model(batch_data), batch_targets).data[0]
82+
test_loss += criterion(model(batch_data), batch_targets)
8383

84+
test_loss = test_loss.data[0]
8485
test_loss /= (test_data.size(0) / TEST_BATCH_SIZE) # criterion averages over batch size
8586
print('TEST SET RESULTS:' + ' ' * 20)
8687
print('Average loss: {:.4f}'.format(test_loss))

0 commit comments

Comments
 (0)