###############################################################################
# Language Modeling on Penn Tree Bank
###############################################################################

import argparse
import time
import math

import torch
import torch.nn as nn
from torch.autograd import Variable

from rnn_modules import *
import data

parser = argparse.ArgumentParser(description='PyTorch PTB Language Model')

# Data parameters.
parser.add_argument('-data', type=str, default='./data/penn', help='Location of the data corpus.')
# Model parameters.
parser.add_argument('-model', type=str, default='RNN', help='Type of recurrent net. RNN, LSTM, or GRU.')
parser.add_argument('-emsize', type=int, default=200, help='Size of word embeddings.')
parser.add_argument('-nhid', type=int, default=200, help='Number of hidden units per layer.')
# Optimization parameters.
parser.add_argument('-lr', type=float, default=20, help='Initial learning rate.')
parser.add_argument('-clip', type=float, default=0.5, help='Gradient clipping.')
parser.add_argument('-maxepoch', type=int, default=6, help='Upper epoch limit.')
parser.add_argument('-batchsize', type=int, default=20, help='Batch size.')
parser.add_argument('-bptt', type=int, default=20, help='Sequence length.')
# Device parameters.
parser.add_argument('-seed', type=int, default=1111, help='Random seed.')
parser.add_argument('-cuda', action='store_true', help='Use CUDA.')
# Misc parameters.
parser.add_argument('-reportint', type=int, default=1000, help='Report interval.')
parser.add_argument('-save', type=str, default='model.pt', help='Path to save the final model.')
args = parser.parse_args()

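# Example invocation (assuming this script is saved as main.py; the file name
# is just a placeholder):
#   python main.py -data ./data/penn -model LSTM -cuda
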
# Set the random seed manually for reproducibility.
torch.manual_seed(args.seed)

# If a CUDA device is present but unused, warn the user.
if torch.cuda.is_available() and not args.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with -cuda")

###############################################################################
# LOAD DATA
###############################################################################

corpus = data.Corpus(args.data)

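# batchify() lays the corpus out as a (nbatch * bptt) x bsz matrix: each column
# holds one contiguous chunk of the corpus, so row i+1 contains the prediction
# targets for row i. The training loop below walks down the rows one step at a
# time and backpropagates every bptt rows.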
def batchify(data, bsz, bptt):
    nbatch = int(math.floor(data.size(0) / bsz / bptt))
    data = data.narrow(0, 0, nbatch * bptt * bsz)
    data = data.view(bsz, -1).t().contiguous()
    if args.cuda:
        data = data.cuda()
    return data

| 58 | + |
| 59 | +train = batchify(corpus.train, args.batchsize, args.bptt) |
| 60 | +valid = batchify(corpus.valid, 10, 1) |
| 61 | +test = batchify(corpus.test, 10, 1) |
| 62 | +train = train[:10000] |
| 63 | +valid = valid[:100] |
| 64 | + |
| 65 | +bptt = args.bptt |
| 66 | +bsz = args.batchsize |
| 67 | + |
###############################################################################
# MAKE MODEL
###############################################################################

initrange = 0.1

class RNNModel(nn.Container):
    """A container module with an encoder, an RNN (one of several flavors),
    and a decoder. Runs one RNN step at a time.
    """

    @staticmethod
    def name2module(name):
        if name == 'RNN':
            return RNN
        elif name == 'LSTM':
            return LSTM
        elif name == 'GRU':
            return GRU
        else:
            raise ValueError("Unknown RNN module: " + name)

    def __init__(self, rnnType, ntoken, ninp, nhid):
        rnnModule = RNNModel.name2module(rnnType)
        super(RNNModel, self).__init__(
            encoder = nn.sparse.Embedding(ntoken, ninp),
            rnn = rnnModule(ninp, nhid),
            decoder = nn.Linear(nhid, ntoken),
        )

        # FIXME: is this better than the standard init? probably
        # FIXME: we need better reset_parameters methods in stdlib
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def __call__(self, hidden, input):
        emb = self.encoder(input)
        hidden, output = self.rnn(hidden, emb)
        decoded = self.decoder(output)
        return hidden, decoded

    def initHidden(self, bsz):
        return self.rnn.initHidden(bsz)

model = RNNModel(args.model, corpus.dic.ntokens(), args.emsize, args.nhid)
if args.cuda:
    model.cuda()

criterion = nn.CrossEntropyLoss()

###############################################################################
# TRAINING
###############################################################################

lr = args.lr
clip = args.clip
reportinterval = args.reportint

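# Training sketch: plain SGD with truncated backpropagation through time.
# Every bptt steps the accumulated loss is backpropagated, the update is
# rescaled so that the total gradient norm stays below `clip`, and the
# learning rate is divided by 4 at the end of any epoch where the validation
# loss stops improving.
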
# Perform the forward pass only.
def evaluate(model, data, criterion):
    loss = 0
    hidden = model.initHidden(data.size(1))
    # Loop over validation data.
    for i in range(0, data.size(0) - 1):
        hidden, output = model(hidden, Variable(data[i], requires_grad=False))
        loss += criterion(output, Variable(data[i+1], requires_grad=False)).data[0]

    # Average over the number of predicted positions.
    return loss / (data.size(0) - 1)

# Simple gradient clipping, using the total norm of the gradient.
# Returns the factor by which the learning rate should be scaled down.
def clipGradient(model, clip):
    totalnorm = 0
    for p in model.parameters():
        modulenorm = p.grad.norm()
        totalnorm += modulenorm ** 2
    totalnorm = math.sqrt(totalnorm)
    return min(1, clip / (totalnorm + 1e-6))

# Between bptt intervals, we want to maintain the hidden state data
# but don't want to backprop gradients across bptt intervals.
# So we have to rewrap the hidden state in a fresh Variable.
def repackageHidden(h):
    if type(h) == Variable:
        return Variable(h.data)
    else:
        return tuple(repackageHidden(v) for v in h)

# Loop over epochs.
prev_loss = None
for epoch in range(1, args.maxepoch+1):
    # Start with an initial hidden state.
    hidden = model.initHidden(bsz)
    # Loop over the training data.
    loss = 0
    i = 0
    model.zero_grad()

    total_loss = 0
    start_time = epoch_start_time = time.time()
    while i < train.size(0) - 1:
        hidden, output = model(hidden, Variable(train[i], requires_grad=False))
        loss += criterion(output, Variable(train[i+1], requires_grad=False))
        i += 1

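        # Every bptt steps: backprop through the accumulated loss, scale the
        # update so the gradient norm stays below the clipping threshold, take
        # an SGD step, then rewrap the hidden state so gradients do not flow
        # across bptt boundaries.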
        if i % bptt == 0:
            loss.backward()

            clipped_lr = lr * clipGradient(model, args.clip)

            for p in model.parameters():
                p.data.sub_(p.grad.mul(clipped_lr))

            hidden = repackageHidden(hidden)
            model.zero_grad()
            total_loss += loss.data[0]
            loss = 0

        if i % reportinterval == 0:
            cur_loss = total_loss / reportinterval
            elapsed = time.time() - start_time
            print(
                ('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.6f} | ms/batch {:5.2f} | '
                 + 'train loss {:5.2f} | train ppl {:8.2f}').format(
                    epoch, i, train.size(0), lr, elapsed * 1000 / reportinterval * bsz,
                    cur_loss, math.exp(cur_loss)
            ))
            total_loss = 0
            start_time = time.time()

    val_loss = evaluate(model, valid, criterion)

    print(
        '| end of epoch {:3d} | ms/batch {:5.2f} | valid loss {:5.2f} | valid ppl {:8.2f}'.format(
            epoch, (time.time() - epoch_start_time) * 1000 / train.size(0), val_loss, math.exp(val_loss)
    ))

    # The annealing schedule: shrink the learning rate if the validation loss stopped improving.
    if prev_loss and val_loss > prev_loss:
        lr = lr / 4

    prev_loss = val_loss

# Run on test data.
test_loss = evaluate(model, test, criterion)
print(
    '| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
        test_loss, math.exp(test_loss)
))

if args.save != '':
    with open(args.save, 'wb') as f:
        torch.save(model, f)