Commit 2d47f3f

Add multiple layers to RNN
1 parent c456de8 commit 2d47f3f

2 files changed: +51 −26 lines changed


word_language_model/main.py

Lines changed: 33 additions & 26 deletions
@@ -16,23 +16,24 @@
 parser = argparse.ArgumentParser(description='PyTorch PTB Language Model')
 
 # Data parameters
-parser.add_argument('-data'      , type=str, default='./data/penn', help='Location of the data corpus')
+parser.add_argument('-data'      , type=str, default='./data/penn', help='Location of the data corpus')
 # Model parameters.
-parser.add_argument('-model'     , type=str, default='RNN',        help='Type of recurrent net. RNN, LSTM, or GRU.')
-parser.add_argument('-emsize'    , type=int, default=200,          help='Size of word embeddings')
-parser.add_argument('-nhid'      , type=int, default=200,          help='Number of hidden units per layer.')
+parser.add_argument('-model'     , type=str, default='LSTM',       help='Type of recurrent net. RNN, LSTM, or GRU.')
+parser.add_argument('-emsize'    , type=int, default=200,          help='Size of word embeddings')
+parser.add_argument('-nhid'      , type=int, default=200,          help='Number of hidden units per layer.')
+parser.add_argument('-nlayers'   , type=int, default=2,            help='Number of layers.')
 # Optimization parameters.
-parser.add_argument('-lr'        , type=float, default=20,         help='Initial learning rate.')
-parser.add_argument('-clip'      , type=float, default=0.5,        help='Gradient clipping.')
-parser.add_argument('-maxepoch'  , type=int, default=6,            help='Upper epoch limit.')
-parser.add_argument('-batchsize' , type=int, default=20,           help='Batch size.')
-parser.add_argument('-bptt'      , type=int, default=20,           help='Sequence length.')
+parser.add_argument('-lr'        , type=float, default=20,         help='Initial learning rate.')
+parser.add_argument('-clip'      , type=float, default=0.5,        help='Gradient clipping.')
+parser.add_argument('-maxepoch'  , type=int, default=6,            help='Upper epoch limit.')
+parser.add_argument('-batchsize' , type=int, default=20,           help='Batch size.')
+parser.add_argument('-bptt'      , type=int, default=20,           help='Sequence length.')
 # Device parameters.
-parser.add_argument('-seed'      , type=int, default=1111,         help='Random seed.')
-parser.add_argument('-cuda'      , action='store_true',            help='Use CUDA.')
+parser.add_argument('-seed'      , type=int, default=1111,         help='Random seed.')
+parser.add_argument('-cuda'      , action='store_true',            help='Use CUDA.')
 # Misc parameters.
-parser.add_argument('-reportint' , type=int, default=1000,         help='Report interval.')
-parser.add_argument('-save'      , type=str, default='model.pt',   help='Path to save the final model.')
+parser.add_argument('-reportint' , type=int, default=200,          help='Report interval.')
+parser.add_argument('-save'      , type=str, default='model.pt',   help='Path to save the final model.')
 args = parser.parse_args()
 
 # Set the random seed manually for reproducibility.
@@ -59,8 +60,6 @@ def batchify(data, bsz, bptt):
 train = batchify(corpus.train, args.batchsize, args.bptt)
 valid = batchify(corpus.valid, 10, 1)
 test = batchify(corpus.test, 10, 1)
-train = train[:10000]
-valid = valid[:100]
 
 bptt = args.bptt
 bsz = args.batchsize
@@ -87,11 +86,11 @@ def name2module(name):
         else:
             error("Unknown RNN module: " + name)
 
-    def __init__(self, rnnType, ntoken, ninp, nhid):
+    def __init__(self, rnnType, ntoken, ninp, nhid, nlayers):
         rnnModule = RNNModel.name2module(rnnType)
         super(RNNModel, self).__init__(
             encoder = nn.sparse.Embedding(ntoken, ninp),
-            rnn = rnnModule(ninp, nhid),
+            rnn = StackedRNN(rnnModule, ninp, nhid, nlayers),
             decoder = nn.Linear(nhid, ntoken),
         )
 
@@ -110,7 +109,7 @@ def __call__(self, hidden, input):
     def initHidden(self, bsz):
        return self.rnn.initHidden(bsz)
 
-model = RNNModel(args.model, corpus.dic.ntokens(), args.emsize, args.nhid)
+model = RNNModel(args.model, corpus.dic.ntokens(), args.emsize, args.nhid, args.nlayers)
 if args.cuda:
     model.cuda()
 
@@ -122,7 +121,7 @@ def initHidden(self, bsz):
 
 lr = args.lr
 clip = args.clip
-reportinterval = args.reportint
+reportinterval = args.reportint * args.batchsize
 
 # Perform the forward pass only.
 def evaluate(model, data, criterion):
@@ -151,7 +150,7 @@ def repackageHidden(h):
     if type(h) == Variable:
         return Variable(h.data)
     else:
-        return tuple(repackageVariable(v) for v in h)
+        return tuple(repackageHidden(v) for v in h)
 
 # Loop over epochs.
 prev_loss = None
@@ -167,6 +166,9 @@ def repackageHidden(h):
 
     total_loss = 0
     start_time = epoch_start_time = time.time()
+    # import cProfile, pstats, StringIO
+    # pr = cProfile.Profile()
+    # pr.enable()
     while i < train.size(0) - 1:
         hidden, output = model(hidden, Variable(train[i], requires_grad=False))
         loss += criterion(output, Variable(train[i+1], requires_grad=False))
@@ -182,26 +184,31 @@ def repackageHidden(h):
 
         hidden = repackageHidden(hidden)
         model.zero_grad()
-        total_loss += loss.data[0]
+        total_loss += loss.data
         loss = 0
 
         if i % reportinterval == 0:
-            cur_loss = total_loss / reportinterval
+            cur_loss = total_loss[0] / reportinterval
             elapsed = time.time() - start_time
             print(
                 ('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.6f} | ms/batch {:5.2f} | '
                 + 'train loss {:5.2f} | train ppl {:8.2f}').format(
-                    epoch, i, train.size(0), lr, elapsed * 1000 / reportinterval * bsz,
+                    epoch, i / bptt, train.size(0) / bptt, lr, elapsed * 1000 / reportinterval,
                     cur_loss, math.exp(cur_loss)
                 ))
             total_loss = 0
             start_time = time.time()
 
-    val_loss = evaluate(model, valid, criterion)
+    # pr.disable()
+    # s = StringIO.StringIO()
+    # ps = pstats.Stats(pr, stream=s).sort_stats("time")
+    # ps.print_stats()
+    # print(s.getvalue())
+    # val_loss = evaluate(model, valid, criterion)
 
     print(
-        '| end of epoch {:3d} | ms/batch {:5.2f} | valid loss {:5.2f} | valid ppl {:8.2f}'.format(
-            epoch, (time.time() - epoch_start_time) * 1000 / train.size(0), val_loss, math.exp(val_loss)
+        '| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | valid ppl {:8.2f}'.format(
+            epoch, (time.time() - epoch_start_time), val_loss, math.exp(val_loss)
         ))
 
     # The annealing schedule.
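
Two fixes above are easy to miss in the diff noise: repackageHidden previously recursed via the undefined name repackageVariable, so detaching an LSTM's (h, c) tuple between BPTT chunks would have raised a NameError (single-Variable RNN and GRU states never reach that branch), and the per-epoch evaluate() call is commented out along with the profiling scaffolding even though the end-of-epoch print still reads val_loss. A minimal sketch of the same repackaging idea, assuming present-day PyTorch tensors rather than this era's Variable API:

import torch

def repackage_hidden(h):
    # Detach the hidden state from its autograd history so that
    # backprop stops at the boundary of each BPTT chunk.
    if isinstance(h, torch.Tensor):
        return h.detach()
    # An LSTM hidden state is an (h, c) tuple: recurse into it.
    return tuple(repackage_hidden(v) for v in h)

# Usage: carry the values across chunks without carrying the graph.
hidden = repackage_hidden((torch.zeros(20, 200), torch.zeros(20, 200)))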

word_language_model/rnn_modules.py

Lines changed: 18 additions & 0 deletions
@@ -87,4 +87,22 @@ def initHidden(self, bsz):
         return Variable(self.h2h.weight.data.new(bsz, self.nhid).zero_())
 
 
+class StackedRNN(nn.Container):
+    def __init__(self, rnnClass, ninp, nhid, nlayers):
+        super(StackedRNN, self).__init__()
+        self.nlayers = nlayers
+        self.layers = []
+        for i in range(nlayers):
+            layer = rnnClass(ninp, nhid)
+            self.layers += [layer]
+            self.add_module('layer' + str(i), layer)
 
+    def __call__(self, hidden, input):
+        output = input
+        new_hidden = [None] * self.nlayers
+        for i in range(self.nlayers):
+            new_hidden[i], output = self.layers[i](hidden[i], output)
+        return new_hidden, output
+
+    def initHidden(self, bsz):
+        return [m.initHidden(bsz) for m in self.layers]
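
For orientation, here is a sketch of the same stacking idea in present-day PyTorch, since nn.Container and the (hidden, input) call convention above come from a very early API. TanhCell is a hypothetical stand-in for the per-layer modules defined earlier in rnn_modules.py. Note that every layer is built as rnnClass(ninp, nhid) while layers past the first receive nhid-wide inputs, so the stack assumes ninp == nhid, which holds for the defaults here (emsize = nhid = 200).

import torch
import torch.nn as nn

class TanhCell(nn.Module):
    # Hypothetical stand-in for one recurrent layer, matching the repo's
    # calling convention: (hidden, input) -> (new_hidden, output).
    def __init__(self, ninp, nhid):
        super(TanhCell, self).__init__()
        self.nhid = nhid
        self.i2h = nn.Linear(ninp, nhid)
        self.h2h = nn.Linear(nhid, nhid)

    def forward(self, hidden, input):
        new_hidden = torch.tanh(self.i2h(input) + self.h2h(hidden))
        return new_hidden, new_hidden  # the output is the new hidden state

    def initHidden(self, bsz):
        return torch.zeros(bsz, self.nhid)

class StackedRNN(nn.Module):
    # Same wiring as the commit's StackedRNN, with nn.ModuleList in
    # place of the long-removed nn.Container.
    def __init__(self, rnnClass, ninp, nhid, nlayers):
        super(StackedRNN, self).__init__()
        # Layer i > 0 consumes the nhid-wide output of layer i - 1, so
        # constructing every layer as rnnClass(ninp, nhid) assumes ninp == nhid.
        self.layers = nn.ModuleList(rnnClass(ninp, nhid) for _ in range(nlayers))

    def forward(self, hidden, input):
        output, new_hidden = input, []
        for h, layer in zip(hidden, self.layers):
            h_next, output = layer(h, output)
            new_hidden.append(h_next)
        return new_hidden, output

    def initHidden(self, bsz):
        return [m.initHidden(bsz) for m in self.layers]

# Usage: one timestep for a batch of 20 through 2 layers of width 200.
rnn = StackedRNN(TanhCell, 200, 200, 2)
hidden = rnn.initHidden(20)
hidden, output = rnn(hidden, torch.randn(20, 200))  # output: (20, 200)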
