import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

device = "cuda:0"
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet50(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad_(False)  # freeze the pretrained backbone

        modules = list(resnet.children())[:-1]  # remove the last fc layer
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)

    def forward(self, images):
        features = self.resnet(images)                      # (batch_size, 2048, 1, 1)
        features = features.reshape(features.size(0), -1)   # (batch_size, 2048)
        features = self.linear(features)                    # (batch_size, embed_size)
        return features
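
# A minimal sanity check for the encoder: not part of the original pipeline,
# just a hedged sketch assuming 224x224 RGB inputs (the standard ResNet size).
def _demo_encoder():
    encoder = EncoderCNN(embed_size=50)
    images = torch.randn(2, 3, 224, 224)  # dummy mini-batch of 2 images
    features = encoder(images)
    assert features.shape == (2, 50)
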
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, padding_index, vocab_size, embeddings):
        """Set the hyper-parameters and build the layers."""
        super(DecoderRNN, self).__init__()
        # Keep track of hidden_size for initialization of the hidden state
        self.hidden_size = hidden_size

        # Embedding layer that turns word ids into pretrained, frozen word vectors
        self.word_embeddings = nn.Embedding.from_pretrained(embeddings,
                                                            freeze=True,
                                                            padding_idx=padding_index)

        # The LSTM takes embedded word vectors as input and
        # outputs hidden states of size hidden_size
        self.lstm = nn.LSTM(input_size=embeddings.shape[1],  # embedding dimension
                            hidden_size=hidden_size,         # LSTM hidden units
                            num_layers=1,                    # number of LSTM layers
                            batch_first=True,                # batch size is the 1st dimension
                            dropout=0,                       # no dropout
                            bidirectional=False)             # unidirectional LSTM

        # The linear layer that maps the hidden state output dimension
        # to the number of words we want as output, vocab_size
        self.linear_1 = nn.Linear(hidden_size, vocab_size)
    def init_hidden_state(self, encoder_out):
        """
        Create the initial hidden and cell states for the decoder's LSTM from the
        encoded images. Note: this is currently unused by forward(), which lets
        the LSTM default to zero states, and it only produces valid states when
        the encoder output size equals the LSTM hidden size.
        :param encoder_out: encoded images, a tensor of dimension (batch_size, encoder_dim)
        :return: hidden state and cell state, each of dimension (1, batch_size, encoder_dim)
        """
        h = encoder_out.reshape((1, encoder_out.shape[0], encoder_out.shape[1]))
        c = encoder_out.reshape((1, encoder_out.shape[0], encoder_out.shape[1]))
        return h, c
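
    # Hedged shape sketch (hypothetical sizes, assuming encoder_dim == hidden_size):
    #   decoder = DecoderRNN(1024, 0, vocab_size, embeddings)
    #   h, c = decoder.init_hidden_state(torch.randn(4, 1024))  # each (1, 4, 1024)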

    def forward(self, features, captions, caption_lengths):
        """Define the feedforward behavior of the model (teacher forcing)."""

        # features is of shape (batch_size, embed_size)
        batch_size = features.shape[0]

        # Create embedded word vectors for each word in the captions
        inputs = self.word_embeddings(captions)  # (batch_size, caption_length, embed_size)

        # Prepend the image features as the first "input word" of each sequence,
        # then run the LSTM over the full sequence; (h_0, c_0) defaults to zeros
        # when not provided.
        # h, c = self.init_hidden_state(features)
        inputs = torch.cat((features.unsqueeze(1), inputs), dim=1)
        lstm_out, self.hidden = self.lstm(inputs)  # (batch_size, caption_length + 1, hidden_size)

        # Drop the output of the image-feature step so outputs align with the captions
        lstm_out = lstm_out[:, 1:, :]

        # Fully connected layer
        outputs = self.linear_1(lstm_out)  # (batch_size, caption_length, vocab_size)

        return outputs

    def sample(self, features):
        """Generate a caption for given image features using greedy search."""

        sampled_ids = []
        # Token id 1 is assumed to be the <start> token, id 2 the <end> token
        inputs = self.word_embeddings(torch.LongTensor([1]).to(torch.device(device))).reshape((1, 1, -1))
        with torch.no_grad():
            # Feed the image features first to condition the LSTM state
            _, state = self.lstm(features.reshape(1, 1, -1))
            for _ in range(15):  # maximum caption length
                hiddens, state = self.lstm(inputs, state)        # hiddens: (batch_size, 1, hidden_size)
                outputs = self.linear_1(hiddens.squeeze(1))      # outputs: (batch_size, vocab_size)
                _, predicted = F.softmax(outputs, dim=1).max(1)  # predicted: (batch_size)
                sampled_ids.append(predicted)
                inputs = self.word_embeddings(predicted)               # (batch_size, embed_size)
                inputs = inputs.unsqueeze(1).to(torch.device(device))  # (batch_size, 1, embed_size)
                if predicted == 2:  # stop at the <end> token
                    break
        sampled_ids = torch.stack(sampled_ids, 1)  # sampled_ids: (batch_size, seq_length)
        return sampled_ids
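
    # Hedged usage sketch (assumes a trained encoder/decoder living on `device`
    # and the 1/2 ids for <start>/<end> above):
    #   features = encoder(images.to(device))  # (batch_size, embed_size)
    #   ids = decoder.sample(features[0])      # (1, seq_length) word ids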

    def save(self, file_name):
        """Save the decoder."""

        torch.save(self.state_dict(), file_name)

    def load(self, file_name):
        """Load the decoder."""

        # since the decoder is a nn.Module, we can load it using pytorch
        # facilities (mapping it to the right device)
        self.load_state_dict(torch.load(file_name, map_location=torch.device(device)))
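
# A CPU-only sanity check for the decoder's teacher-forcing forward pass: a
# hedged sketch with hypothetical sizes, not part of the original pipeline.
def _demo_decoder_forward():
    vocab_size, embed_dim = 100, 50
    embeddings = torch.randn(vocab_size, embed_dim)  # dummy pretrained embeddings
    decoder = DecoderRNN(hidden_size=1024, padding_index=0,
                         vocab_size=vocab_size, embeddings=embeddings)
    features = torch.randn(2, embed_dim)             # dummy encoder output
    captions = torch.randint(0, vocab_size, (2, 7))  # dummy captions of length 7
    outputs = decoder(features, captions, caption_lengths=[7, 7])
    assert outputs.shape == (2, 7, vocab_size)
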
def train(train_set, validation_set, lr, epochs, vocabulary):
    device_t = torch.device(device)
    criterion = nn.CrossEntropyLoss(ignore_index=0, reduction="sum").to(device_t)

    # initializing some elements
    best_val_acc = -1.  # the best accuracy computed on the validation data
    best_epoch = -1     # the epoch in which the best accuracy above was computed

    encoder = EncoderCNN(50)
    decoder = DecoderRNN(1024, 0, len(vocabulary.word2id.keys()), vocabulary.embeddings)

    encoder.to(device_t)
    decoder.to(device_t)

    # ensuring the decoder is in 'train' mode (pytorch)
    decoder.train()

    # creating the optimizer: the frozen ResNet is excluded, so only the decoder
    # and the encoder's projection layer are updated
    optimizer = torch.optim.Adam(list(decoder.parameters()) + list(encoder.linear.parameters()), lr)

    # loop on epochs!
    for e in range(0, epochs):

        # epoch-level stats (computed by accumulating mini-batch stats)
        epoch_train_acc = 0.
        epoch_train_loss = 0.
        epoch_num_train_examples = 0

        for images, captions, captions_length, captions_training in train_set:
            # zeroing the memory areas that were storing previously computed gradients
            optimizer.zero_grad()

            batch_num_train_examples = images.shape[0]  # mini-batch size (it might differ from 'batch_size')
            epoch_num_train_examples += batch_num_train_examples

            lengths = torch.tensor(captions_length, dtype=torch.long)

            lengths = lengths.to(device_t)
            images = images.to(device_t)
            captions = captions.to(device_t)                    # captions > (B, L)
            captions_training = captions_training.to(device_t)  # captions > (B, |L|-1) without end token

            # computing the network output on the current mini-batch
            features = encoder(images)
            outputs = decoder(features, captions, lengths)  # outputs > (B, L, |V|)

            # (B, L, |V|) -> (B * L, |V|) and captions > (B * L)
            loss = criterion(outputs.reshape((-1, outputs.shape[2])), captions.reshape(-1))

            # computing gradients and updating the network weights
            loss.backward()   # computing gradients
            optimizer.step()  # updating weights

            print(f"mini-batch:\tloss={loss.item():.4f}")

            # qualitative check: caption a random image of the current mini-batch
            with torch.no_grad():
                decoder.eval()
                encoder.eval()
                features = encoder(images)
                numb = random.randint(0, 2)
                caption = decoder.sample(features[numb])
                print(vocabulary.rev_translate(captions[numb]))
                print(vocabulary.rev_translate(caption[0]))
                decoder.train()
                encoder.train()
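
# Hedged mini-example of the loss computation used in train() (hypothetical
# sizes): CrossEntropyLoss expects (N, C) logits and (N,) targets, so the
# (B, L, |V|) outputs and the (B, L) captions are flattened first.
def _demo_loss_reshape():
    criterion = nn.CrossEntropyLoss(ignore_index=0, reduction="sum")
    outputs = torch.randn(2, 5, 100)          # (B, L, |V|) logits
    captions = torch.randint(0, 100, (2, 5))  # (B, L) target word ids
    return criterion(outputs.reshape((-1, outputs.shape[2])), captions.reshape(-1))
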
# Example of usage
if __name__ == "__main__":
    from Vocabulary import Vocabulary
    from Dataset import MyDataset
    from torch.utils.data import DataLoader

    ds = MyDataset("./dataset", percentage=1)
    ds = ds.get_fraction_of_dataset(percentage=100)
    # use dataloader facilities, which require a preprocessed dataset
    v = Vocabulary(ds, reload=True)

    dataloader = DataLoader(ds, batch_size=30, shuffle=True, num_workers=4,
                            collate_fn=lambda data: ds.pack_minibatch_training(data, v))

    train(dataloader, dataloader, 1e-3, 400, v)