@@ -17,7 +17,7 @@ def __init__(self, embed_size):
 
         modules = list(resnet.children())[:-1]  # remove last fc layer
         self.resnet = nn.Sequential(*modules)
-        self.linear = nn.Linear(resnet.fc.in_features, 1024)
+        self.linear = nn.Linear(resnet.fc.in_features, 50)
 
     def forward(self, images):
 
@@ -48,13 +48,7 @@ def __init__(self, hidden_size, padding_index, vocab_size, embeddings):
 
         # The linear layer that maps the hidden state output dimension
         # to the number of words we want as output, vocab_size
-        self.linear_1 = nn.Linear(1024, vocab_size + 4096)
-        self.relu_1 = nn.ReLU()
-        self.linear_2 = nn.Linear(vocab_size + 4096, vocab_size + 2048)
-        self.relu_2 = nn.ReLU()
-        self.linear_3 = nn.Linear(vocab_size + 2048, vocab_size + 1024)
-        self.relu_3 = nn.ReLU()
-        self.linear_4 = nn.Linear(vocab_size + 1024, vocab_size)
+        self.linear_1 = nn.Linear(1024, vocab_size)
 
     def init_hidden_state(self, encoder_out):
         """
@@ -80,18 +74,14 @@ def forward(self, features, captions,caption_lengths):
 
         # Get the output and hidden state by passing the lstm over our word embeddings
         # the lstm takes in our embeddings and hidden state
-        h, c = self.init_hidden_state(features)
+        #h, c = self.init_hidden_state(features)
+        inputs = torch.cat((features.unsqueeze(1), inputs), dim=1)
+        lstm_out, self.hidden = self.lstm(inputs)  # lstm_out shape: (batch_size, caption length, hidden_size), Defaults to zeros if (h_0, c_0) is not provided.
 
-        lstm_out, self.hidden = self.lstm(inputs, (h, c))  # lstm_out shape: (batch_size, caption length, hidden_size), Defaults to zeros if (h_0, c_0) is not provided.
-
+        lstm_out = lstm_out[:, 1:, :]
         # Fully connected layers
         outputs = self.linear_1(lstm_out)  # outputs shape: (batch_size, caption length, vocab_size)
-        outputs = self.relu_1(outputs)
-        outputs = self.linear_2(outputs)
-        outputs = self.relu_2(outputs)
-        outputs = self.linear_3(outputs)
-        outputs = self.relu_3(outputs)
-        outputs = self.linear_4(outputs)
+
         return outputs
 
     def sample(self, features):
@@ -101,16 +91,10 @@ def sample(self, features):
         input = self.word_embeddings(torch.LongTensor([1]).to(torch.device(device))).reshape((1, 1, -1))
         with torch.no_grad():
             print(features.shape)
-            state = self.init_hidden_state(features.reshape((1, -1)))
+            _, state = self.lstm(features.reshape(1, 1, -1))
             for _ in range(15):
                 hiddens, state = self.lstm(input, state)  # hiddens: (batch_size, 1, hidden_size)
                 outputs = self.linear_1(hiddens.squeeze(1))  # outputs: (batch_size, vocab_size)
-                outputs = self.relu_1(outputs)
-                outputs = self.linear_2(outputs)
-                outputs = self.relu_2(outputs)
-                outputs = self.linear_3(outputs)
-                outputs = self.relu_3(outputs)
-                outputs = self.linear_4(outputs)
                 _, predicted = F.softmax(outputs, dim=1).cuda.max(1) if device == "cuda" else F.softmax(outputs, dim=1).max(1)  # predicted: (batch_size)
                 sampled_ids.append(predicted)
                 inputs = self.word_embeddings(predicted)  # inputs: (batch_size, embed_size)
@@ -189,7 +173,7 @@ def train(train_set, validation_set, lr, epochs, vocabulary):
             decoder.eval()
             encoder.eval()
             features = encoder(images)
-            numb = random.randint(0, 9)
+            numb = random.randint(0, 2)
             caption = decoder.sample(features[numb])
             print(vocabulary.rev_translate(captions[numb]))
             print(vocabulary.rev_translate(caption[0]))
@@ -202,11 +186,11 @@ def train(train_set, validation_set, lr, epochs, vocabulary):
 from Dataset import MyDataset
 from torch.utils.data import DataLoader
 ds = MyDataset("./dataset", percentage=1)
-ds = ds.get_fraction_of_dataset(percentage=12)
+ds = ds.get_fraction_of_dataset(percentage=100)
 # use dataloader facilities which requires a preprocessed dataset
 v = Vocabulary(ds, reload=True)
 
-dataloader = DataLoader(ds, batch_size=10,
+dataloader = DataLoader(ds, batch_size=30,
                         shuffle=True, num_workers=4, collate_fn=lambda data: ds.pack_minibatch_training(data, v))
 
 train(dataloader, dataloader, 1e-3, 400, v)
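
For context on what this commit changes in the decoder: the encoded image is now fed as the first LSTM time step instead of being used to initialize the hidden state, the output of that first step is sliced off so the remaining steps align with the caption tokens, and the four-layer linear/ReLU head is collapsed into a single projection. Below is a minimal standalone sketch of that scheme; the sizes are illustrative, not the repository's actual configuration.

import torch
import torch.nn as nn

# Illustrative dimensions (assumptions, not taken from the repo).
embed_size, hidden_size, vocab_size = 1024, 1024, 5000
batch_size, caption_len = 4, 12

lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
linear = nn.Linear(hidden_size, vocab_size)

features = torch.randn(batch_size, embed_size)                   # stand-in for encoder output
embeddings = torch.randn(batch_size, caption_len, embed_size)    # stand-in for word embeddings

# Prepend the image feature as the first time step; hidden/cell states default to zeros.
inputs = torch.cat((features.unsqueeze(1), embeddings), dim=1)
lstm_out, _ = lstm(inputs)
lstm_out = lstm_out[:, 1:, :]          # drop the output produced by the image step
scores = linear(lstm_out)              # (batch_size, caption_len, vocab_size)
print(scores.shape)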