Commit ac9f85d

Switch to torch rnn, using cudnn
1 parent bd4d385 commit ac9f85d

File tree

3 files changed: +71 -166 lines changed

  word_language_model/data.py
  word_language_model/main.py
  word_language_model/rnn_modules.py (deleted)


word_language_model/data.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ def __init__(self):
     def addword(self, word):
         if word not in self.word2idx:
             self.idx2word.append(word)
-            self.word2idx[word] = len(self.idx2word)
+            self.word2idx[word] = len(self.idx2word) - 1

         return self.word2idx[word]

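For context, a minimal standalone sketch of the corrected method (the attribute names come from data.py, the rest of the class is assumed); with the fix, the stored index matches the position the word was just appended at:

class Dictionary(object):                       # hypothetical standalone copy
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def addword(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            # len(self.idx2word) - 1 is the index of the word just appended
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

# addword('the') now returns 0 and idx2word[0] == 'the'; the old code stored 1 here.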

word_language_model/main.py

Lines changed: 70 additions & 56 deletions
@@ -10,20 +10,19 @@
 import torch.nn as nn
 from torch.autograd import Variable

-from rnn_modules import *
 import data

 parser = argparse.ArgumentParser(description='PyTorch PTB Language Model')

 # Data parameters
 parser.add_argument('-data'      , type=str,   default='./data/penn', help='Location of the data corpus')
 # Model parameters.
-parser.add_argument('-model'     , type=str,   default='LSTM',        help='Type of recurrent net. RNN, LSTM, or GRU.')
+parser.add_argument('-model'     , type=str,   default='LSTM',        help='Type of recurrent net. RNN_TANH, RNN_RELU, LSTM, or GRU.')
 parser.add_argument('-emsize'    , type=int,   default=200,           help='Size of word embeddings')
 parser.add_argument('-nhid'      , type=int,   default=200,           help='Number of hidden units per layer.')
 parser.add_argument('-nlayers'   , type=int,   default=2,             help='Number of layers.')
 # Optimization parameters.
-parser.add_argument('-lr'        , type=float, default=1,             help='Initial learning rate.')
+parser.add_argument('-lr'        , type=float, default=20,            help='Initial learning rate.')
 parser.add_argument('-clip'      , type=float, default=0.5,           help='Gradient clipping.')
 parser.add_argument('-maxepoch'  , type=int,   default=6,             help='Upper epoch limit.')
 parser.add_argument('-batchsize' , type=int,   default=20,            help='Batch size.')
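As an aside (not part of the commit), a quick way to sanity-check the widened -model flag, using only the parser defined above; the sample values are made up:

# Hypothetical smoke test of the argument parser above
args = parser.parse_args(['-model', 'GRU', '-lr', '20'])
assert args.model in ('RNN_TANH', 'RNN_RELU', 'LSTM', 'GRU')
assert args.lr == 20.0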
@@ -49,18 +48,19 @@

 corpus = data.Corpus(args.data)

-def batchify(data, bsz, bptt):
-    nbatch = int(math.floor(data.size(0) / bsz / bptt))
-    data = data.narrow(0, 0, nbatch * bptt * bsz)
+def batchify(data, bsz):
+    nbatch = int(math.floor(data.size(0) / bsz))
+    data = data.narrow(0, 0, nbatch * bsz)
     data = data.view(bsz, -1).t().contiguous()
     if args.cuda:
         data = data.cuda()
     return data

-train = batchify(corpus.train, args.batchsize, args.bptt)
-valid = batchify(corpus.valid, 10, 1)
-test = batchify(corpus.test, 10, 1)
-
+eval_bsz = 10
+train = batchify(corpus.train, args.batchsize)
+valid = batchify(corpus.valid, eval_bsz)
+test = batchify(corpus.test, eval_bsz)
+#train = train[:123*args.bptt]
 bptt = args.bptt
 bsz = args.batchsize

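A shape walk-through of the new batchify with made-up numbers (26 tokens, batch size 4); only torch itself is assumed:

import torch

stream = torch.LongTensor(list(range(26)))    # pretend token ids
bsz = 4
nbatch = 26 // bsz                            # 6 full time steps per stream
trimmed = stream.narrow(0, 0, nbatch * bsz)   # drop the 2 leftover tokens
batched = trimmed.view(bsz, -1).t().contiguous()
print(batched.size())                         # 6 x 4: rows are time steps, columns are parallel streams

Note that, unlike the old version, batchify no longer depends on bptt; the training loop below slices out bptt-sized chunks itself.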

@@ -73,22 +73,10 @@ class RNNModel(nn.Container):
     and a decoder. Runs one RNN step at a time.
     """

-    @staticmethod
-    def name2module(name):
-        if name == 'RNN':
-            return RNN
-        elif name == 'LSTM':
-            return LSTM
-        elif name == 'GRU':
-            return GRU
-        else:
-            error("Unknown RNN module: " + name)
-
     def __init__(self, rnnType, ntoken, ninp, nhid, nlayers):
-        rnnModule = RNNModel.name2module(rnnType)
         super(RNNModel, self).__init__(
             encoder = nn.sparse.Embedding(ntoken, ninp),
-            rnn = StackedRNN(rnnModule, ninp, nhid, nlayers),
+            rnn = nn.rnn.RNNBase(rnnType, ninp, nhid, nlayers, bias=False),
             decoder = nn.Linear(nhid, ntoken),
         )

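For readers on the later public torch.nn API, a hedged sketch of the same encoder/RNN/decoder stack; nn.LSTM, nn.GRU, nn.RNN and the nonlinearity keyword are assumptions about that API, not what this commit uses (it calls nn.rnn.RNNBase directly):

import torch.nn as nn

class TinyRNNModel(nn.Module):
    """Hypothetical later-API equivalent of the container above."""
    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers):
        super(TinyRNNModel, self).__init__()
        self.encoder = nn.Embedding(ntoken, ninp)
        if rnn_type in ('LSTM', 'GRU'):
            # cudnn-backed multi-layer recurrent modules
            self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers)
        else:
            # 'RNN_TANH' / 'RNN_RELU' pick the plain RNN nonlinearity
            nonlinearity = 'tanh' if rnn_type == 'RNN_TANH' else 'relu'
            self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity)
        self.decoder = nn.Linear(nhid, ntoken)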

@@ -99,37 +87,48 @@ def __init__(self, rnnType, ntoken, ninp, nhid, nlayers):
         self.decoder.bias.data.fill_(0)
         self.decoder.weight.data.uniform_(-initrange, initrange)

-    def forward(self, hidden, input):
+    def forward(self, input, hidden):
         emb = self.encoder(input)
-        hidden, output = self.rnn(hidden, emb)
-        decoded = self.decoder(output)
-        return hidden, decoded
+        output, hidden = self.rnn(emb, hidden)
+        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
+        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

-    def initHidden(self, bsz):
-        return self.rnn.initHidden(bsz)
-
-model = RNNModel(args.model, corpus.dic.ntokens(), args.emsize, args.nhid, args.nlayers)
+ntokens = corpus.dic.ntokens()
+model = RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers)
 if args.cuda:
     model.cuda()

 criterion = nn.CrossEntropyLoss()

+def initHidden(model, bsz):
+    weight = next(model.parameters()).data
+    if args.model == 'LSTM':
+        return (Variable(weight.new(args.nlayers, bsz, args.nhid).zero_()),
+                Variable(weight.new(args.nlayers, bsz, args.nhid).zero_()))
+    else:
+        return Variable(weight.new(args.nlayers, bsz, args.nhid).zero_())
+
+
 ########################################
 # TRAINING
 ########################################

 lr = args.lr
 clip = args.clip
-reportinterval = args.reportint * args.batchsize
+reportinterval = args.reportint
+

 # Perform the forward pass only.
-def evaluate(model, data, criterion):
+def evaluate(model, data, criterion, bsz):
     loss = 0
-    hidden = model.initHidden(data.size(1))
+    hidden = initHidden(model, bsz)
     # Loop over validation data.
-    for i in range(0, data.size(0) - 1):
-        hidden, output = model(hidden, Variable(data[i], requires_grad=False))
-        loss += criterion(output, Variable(data[i+1], requires_grad=False)).data
+    for i in range(0, data.size(0) - 1, bptt):
+        seq_len = min(bptt, data.size(0) - 1 - i)
+        output, hidden = model(Variable(data[i:i+seq_len], requires_grad=False), hidden)
+        targets = data[i+1:i+seq_len+1].view(-1)
+        loss += bptt * criterion(output.view(seq_len*bsz, -1), Variable(targets, requires_grad=False)).data
+        hidden = repackageHidden(hidden)

     return loss[0] / data.size(0)
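The reshuffling in the new forward() and evaluate() is shape bookkeeping for nn.Linear and CrossEntropyLoss, which expect 2-D input. With hypothetical sizes seq_len=35, bsz=20 (the default) and nhid=200:

# rnn output:             seq_len x bsz x nhid       -> 35 x 20 x 200
# viewed for nn.Linear:   (seq_len*bsz) x nhid       -> 700 x 200
# decoded logits:         (seq_len*bsz) x ntokens    -> 700 x ntokens
# returned, viewed back:  seq_len x bsz x ntokens    -> 35 x 20 x ntokens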

@@ -157,46 +156,61 @@ def repackageHidden(h):
     total_loss = 0
     epoch_start_time = time.time()
     # Start with an initial hidden state.
-    hidden = model.initHidden(bsz)
+    hidden = initHidden(model, bsz)
+
     # Loop over the training data.
     loss = 0
     i = 0
     model.zero_grad()

     total_loss = 0
     start_time = epoch_start_time = time.time()
-    while i < train.size(0) - 1:
-        hidden, output = model(hidden, Variable(train[i], requires_grad=False))
-        loss += criterion(output, Variable(train[i+1], requires_grad=False))
-        i += 1
+    ntokens = corpus.dic.ntokens()
+    for batch, i in enumerate(range(0, train.size(0) - 1, bptt)):
+        seq_len = min(bptt, train.size(0) - 1 - i)
+        output, hidden = model(Variable(train[i:i+seq_len], requires_grad=False), hidden)
+        targets = train[i+1:i+seq_len+1].view(-1)
+        loss = criterion(output.view(-1, ntokens), Variable(targets, requires_grad=False))
+
+        # FIXME: this is the result of a double bug
+        # bug #1: you can't have dangling nodes in the graph to call backward
+        # bug #2: hidden.sum() doesn't work, gives me an error in backward, which I can't reproduce in a simple way
+        # File "/data/users/alerer/pytorch/pytorch/torch/autograd/variable.py", line 82, in backward
+        #   self._execution_engine.run_backward(self, gradient, retain_variables)
+        # File "/data/users/alerer/pytorch/pytorch/torch/autograd/functions/reduce.py", line 27, in backward
+        #   return grad_output.new(*self.input_size).fill_(grad_output[0])
+        # ValueError: fill_ recieved an invalid combination of argument types - got (torch.cuda.FloatTensor), but expected (float value)
+        if args.model == 'LSTM':
+            loss += 0*hidden[0].sum(0).sum(1).sum(2)
+            loss += 0*hidden[1].sum(0).sum(1).sum(2)
+        else:
+            loss += 0*hidden.sum(0).sum(1).sum(2)

-        if i % bptt == 0:
-            loss.backward()
+        loss.backward()

-            clipped_lr = lr * clipGradient(model, args.clip)
+        clipped_lr = lr * clipGradient(model, args.clip)

-            for p in model.parameters():
-                p.data.sub_(p.grad.mul(clipped_lr))
+        for p in model.parameters():
+            p.data.sub_(p.grad.mul(clipped_lr))

-            hidden = repackageHidden(hidden)
-            model.zero_grad()
-            total_loss += loss.data
-            loss = 0
+        hidden = repackageHidden(hidden)
+        model.zero_grad()
+        total_loss += loss.data
+        loss = 0

-        if i % reportinterval == 0:
+        if batch % reportinterval == 0 and batch > 0:
             cur_loss = total_loss[0] / reportinterval
             elapsed = time.time() - start_time
             print(
                 ('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.6f} | ms/batch {:5.2f} | '
                 + 'train loss {:5.2f} | train ppl {:8.2f}').format(
-                epoch, i // bptt), train.size(0) // bptt, lr,
-                elapsed * 1000 / reportinterval * bptt,
+                epoch, batch, train.size(0) // bptt, lr, elapsed * 1000 / reportinterval,
                 cur_loss, math.exp(cur_loss)
             ))
             total_loss = 0
             start_time = time.time()

-    val_loss = evaluate(model, valid, criterion)
+    val_loss = evaluate(model, valid, criterion, eval_bsz)

     print(
         '| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | valid ppl {:8.2f}'.format(
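repackageHidden is defined just above this hunk (see the hunk header) and is not shown in the diff; a typical implementation, sketched here rather than quoted, rewraps the hidden state in fresh Variables so that backward() stops at the bptt boundary instead of tracing back through every earlier batch:

def repackageHidden(h):
    # Detach the hidden state from the graph that produced it
    # (uses the Variable imported at the top of main.py).
    if type(h) == Variable:
        return Variable(h.data)
    else:
        return tuple(repackageHidden(v) for v in h)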
@@ -210,7 +224,7 @@ def repackageHidden(h):
     prev_loss = val_loss

 # Run on test data.
-test_loss = evaluate(model, test, criterion)
+test_loss = evaluate(model, test, criterion, eval_bsz)
 print(
     '| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
         test_loss, math.exp(test_loss)
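The reported perplexity is simply the exponential of the average cross-entropy loss; for example, a test loss of about 4.61 corresponds to a perplexity of roughly exp(4.61) ≈ 100.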

word_language_model/rnn_modules.py

Lines changed: 0 additions & 109 deletions
This file was deleted.
