2
2
import os
3
3
import torch
4
4
import torch .nn as nn
5
- import torch .nn .cuda
5
+
6
+ try :
7
+ import torch .nn .cuda
8
+ cuda = True
9
+ except ImportError :
10
+ print ('Could not import CUDA, skipping' )
11
+ cuda = False
12
+
6
13
import torch .optim as optim
7
14
from torch .autograd import Variable
8
15
from tqdm import tqdm
@@ -51,9 +58,11 @@ def __call__(self, x):
51
58
x = self .relu (self .fc2 (x ))
52
59
return self .softmax (x )
53
60
54
- model = Net ().cuda ()
61
+ model = Net ()
62
+ if cuda == True :
63
+ model .cuda ()
55
64
56
- criterion = nn .ClassNLLCriterion ()
65
+ criterion = nn .NLLLoss ()
57
66
58
67
# Training settings
59
68
BATCH_SIZE = 150
@@ -63,8 +72,13 @@ def __call__(self, x):
63
72
optimizer = optim .SGD (model , lr = 1e-2 , momentum = 0.9 )
64
73
65
74
def train (epoch ):
66
- batch_data = Variable (torch .cuda .FloatTensor (BATCH_SIZE , 1 , 28 , 28 ), requires_grad = False )
67
- batch_targets = Variable (torch .cuda .FloatTensor (BATCH_SIZE ), requires_grad = False )
75
+ batch_data_t = torch .FloatTensor (BATCH_SIZE , 1 , 28 , 28 )
76
+ batch_targets_t = torch .LongTensor (BATCH_SIZE )
77
+ if cuda :
78
+ batch_data_t = batch_data_t .cuda ()
79
+ batch_targets_t = batch_targets_t .cuda ()
80
+ batch_data = Variable (batch_data_t , requires_grad = False )
81
+ batch_targets = Variable (batch_targets_t , requires_grad = False )
68
82
for i in range (0 , training_data .size (0 ), BATCH_SIZE ):
69
83
batch_data .data [:] = training_data [i :i + BATCH_SIZE ]
70
84
batch_targets .data [:] = training_labels [i :i + BATCH_SIZE ]
@@ -76,8 +90,13 @@ def train(epoch):
76
90
77
91
def test (epoch ):
78
92
test_loss = 0
79
- batch_data = Variable (torch .cuda .FloatTensor (TEST_BATCH_SIZE , 1 , 28 , 28 ), volatile = True )
80
- batch_targets = Variable (torch .cuda .FloatTensor (TEST_BATCH_SIZE ), volatile = True )
93
+ batch_data_t = torch .FloatTensor (TEST_BATCH_SIZE , 1 , 28 , 28 )
94
+ batch_targets_t = torch .LongTensor (TEST_BATCH_SIZE )
95
+ if cuda :
96
+ batch_data_t = batch_data_t .cuda ()
97
+ batch_targets_t = batch_targets_t .cuda ()
98
+ batch_data = Variable (batch_data_t , volatile = True )
99
+ batch_targets = Variable (batch_targets_t , volatile = True )
81
100
correct = 0
82
101
for i in range (0 , test_data .size (0 ), TEST_BATCH_SIZE ):
83
102
print ('Testing model: {}/{}' .format (i , test_data .size (0 )), end = '\r ' )
@@ -98,4 +117,3 @@ def test(epoch):
98
117
for epoch in range (1 , NUM_EPOCHS + 1 ):
99
118
train (epoch )
100
119
test (epoch )
101
-
0 commit comments