Commit 49a8ea2

Update NeuralNet.py
1 parent 8e2020a commit 49a8ea2

File tree

1 file changed: +11 -27 lines changed


light_version/NeuralNet.py

Lines changed: 11 additions & 27 deletions
@@ -17,7 +17,7 @@ def __init__(self, embed_size):
 
         modules = list(resnet.children())[:-1] # remove last fc layer
         self.resnet = nn.Sequential(*modules)
-        self.linear = nn.Linear(resnet.fc.in_features, 1024)
+        self.linear = nn.Linear(resnet.fc.in_features, 50)
 
     def forward(self, images):
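The projection shrinks from 1024 to 50 units, presumably so the encoder's output matches the decoder's word-embedding size now that the image feature is fed to the LSTM as an ordinary timestep (see the forward hunk below). For context, a minimal sketch of the encoder this line likely sits in; the class name, ResNet variant, and forward body are assumptions, and only the changed line is confirmed by the diff:

```python
import torch
import torch.nn as nn
import torchvision.models as models

class EncoderCNN(nn.Module):  # class name assumed; not shown in the diff
    def __init__(self, embed_size):
        super().__init__()
        resnet = models.resnet152(pretrained=True)  # backbone choice is an assumption
        modules = list(resnet.children())[:-1]      # remove last fc layer
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, 50)

    def forward(self, images):
        with torch.no_grad():               # keep the pretrained backbone frozen (assumption)
            features = self.resnet(images)  # (batch, 2048, 1, 1) after the avgpool layer
        features = features.reshape(features.size(0), -1)
        return self.linear(features)        # (batch, 50)
```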

@@ -48,13 +48,7 @@ def __init__(self, hidden_size, padding_index, vocab_size, embeddings ):
 
         # The linear layer that maps the hidden state output dimension
         # to the number of words we want as output, vocab_size
-        self.linear_1 = nn.Linear(1024, vocab_size + 4096)
-        self.relu_1 = nn.ReLU()
-        self.linear_2 = nn.Linear(vocab_size + 4096, vocab_size + 2048)
-        self.relu_2 = nn.ReLU()
-        self.linear_3 = nn.Linear(vocab_size + 2048, vocab_size + 1024)
-        self.relu_3 = nn.ReLU()
-        self.linear_4 = nn.Linear(vocab_size + 1024, vocab_size)
+        self.linear_1 = nn.Linear(1024, vocab_size)
 
     def init_hidden_state(self, encoder_out):
         """
@@ -80,18 +74,14 @@ def forward(self, features, captions,caption_lengths):
 
         # Get the output and hidden state by passing the lstm over our word embeddings
         # the lstm takes in our embeddings and hidden state
-        h, c = self.init_hidden_state(features)
+        #h, c = self.init_hidden_state(features)
+        inputs = torch.cat((features.unsqueeze(1), inputs), dim=1)
+        lstm_out, self.hidden = self.lstm(inputs) # lstm_out shape : (batch_size, caption length, hidden_size), Defaults to zeros if (h_0, c_0) is not provided.
 
-        lstm_out, self.hidden = self.lstm(inputs, (h, c)) # lstm_out shape : (batch_size, caption length, hidden_size), Defaults to zeros if (h_0, c_0) is not provided.
-
+        lstm_out = lstm_out[:,1:,:]
         # Fully connected layers
         outputs = self.linear_1(lstm_out) # outputs shape : (batch_size, caption length, vocab_size)
-        outputs = self.relu_1(outputs)
-        outputs = self.linear_2(outputs)
-        outputs = self.relu_2(outputs)
-        outputs = self.linear_3(outputs)
-        outputs = self.relu_3(outputs)
-        outputs = self.linear_4(outputs)
+
         return outputs
 
     def sample(self, features):
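Instead of initializing the LSTM state from the image via `init_hidden_state` (now commented out), the feature vector is prepended to the embedded caption as timestep 0; `nn.LSTM` falls back to zero `(h_0, c_0)` when no state is passed, and the output for the image timestep is sliced off so the remaining outputs align one-to-one with caption tokens. This only works if the feature dimension equals the embedding dimension, which is presumably why the encoder now projects to 50. A self-contained sketch of the shape bookkeeping, with illustrative sizes:

```python
import torch
import torch.nn as nn

batch, embed, hidden, seq_len = 2, 50, 1024, 7  # illustrative sizes

lstm = nn.LSTM(embed, hidden, batch_first=True)
features = torch.randn(batch, embed)             # one encoder vector per image
embeddings = torch.randn(batch, seq_len, embed)  # embedded caption tokens

# Prepend the image feature as timestep 0; with no (h_0, c_0) argument,
# nn.LSTM starts from zero states.
inputs = torch.cat((features.unsqueeze(1), embeddings), dim=1)  # (2, 8, 50)
lstm_out, _ = lstm(inputs)                                      # (2, 8, 1024)

# Drop the output produced by the image timestep so positions line up
# with the caption tokens again.
lstm_out = lstm_out[:, 1:, :]
print(lstm_out.shape)  # torch.Size([2, 7, 1024])
```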
@@ -101,16 +91,10 @@ def sample(self, features):
         input = self.word_embeddings(torch.LongTensor([1]).to(torch.device(device))).reshape((1,1,-1))
         with torch.no_grad():
             print(features.shape)
-            state = self.init_hidden_state(features.reshape((1,-1)))
+            _ ,state = self.lstm(features.reshape(1,1,-1))
             for _ in range(15):
                 hiddens, state = self.lstm(input, state) # hiddens: (batch_size, 1, hidden_size)
                 outputs = self.linear_1(hiddens.squeeze(1)) # outputs: (batch_size, vocab_size)
-                outputs = self.relu_1(outputs)
-                outputs = self.linear_2(outputs)
-                outputs = self.relu_2(outputs)
-                outputs = self.linear_3(outputs)
-                outputs = self.relu_3(outputs)
-                outputs = self.linear_4(outputs)
                 _, predicted = F.softmax(outputs,dim=1).cuda.max(1) if device == "cuda" else F.softmax(outputs,dim=1).max(1) # predicted: (batch_size)
                 sampled_ids.append(predicted)
                 inputs = self.word_embeddings(predicted) # inputs: (batch_size, embed_size)
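`sample` follows suit: the state is warmed up by running the feature vector through the LSTM once rather than calling `init_hidden_state`, and the extra head layers are gone. Two context lines still look buggy, though: `F.softmax(outputs,dim=1).cuda.max(1)` references the `cuda` method without calling it (an `AttributeError` on GPU), and the loop reads `input` while the last line assigns `inputs`, so the predicted word is never fed back in. A sketch of the loop with both fixed; the function name and signature are mine, not the repo's:

```python
import torch

def greedy_sample(decoder, features, max_len=15, start_token=1):
    """Greedy decoding sketch. Argmax over raw logits replaces the
    softmax-then-max in the diff; softmax is monotonic, so the chosen
    word is identical."""
    sampled_ids = []
    inputs = decoder.word_embeddings(
        torch.LongTensor([start_token]).to(features.device)).reshape(1, 1, -1)
    with torch.no_grad():
        # Warm up the LSTM state with the image feature, as in the diff.
        _, state = decoder.lstm(features.reshape(1, 1, -1))
        for _ in range(max_len):
            hiddens, state = decoder.lstm(inputs, state)    # (1, 1, hidden_size)
            outputs = decoder.linear_1(hiddens.squeeze(1))  # (1, vocab_size)
            predicted = outputs.argmax(1)                   # (1,)
            sampled_ids.append(predicted)
            # Feed the predicted word back in as the next input.
            inputs = decoder.word_embeddings(predicted).unsqueeze(1)
    return torch.cat(sampled_ids)
```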
@@ -189,7 +173,7 @@ def train(train_set, validation_set, lr, epochs, vocabulary):
         decoder.eval()
         encoder.eval()
         features = encoder(images)
-        numb = random.randint(0,9)
+        numb = random.randint(0,2)
         caption = decoder.sample(features[numb])
         print(vocabulary.rev_translate(captions[numb]))
         print(vocabulary.rev_translate(caption[0]))
@@ -202,11 +186,11 @@ def train(train_set, validation_set, lr, epochs, vocabulary):
     from Dataset import MyDataset
     from torch.utils.data import DataLoader
     ds = MyDataset("./dataset", percentage=1)
-    ds = ds.get_fraction_of_dataset(percentage=12)
+    ds = ds.get_fraction_of_dataset(percentage=100)
     # use dataloader facilities which requires a preprocessed dataset
     v = Vocabulary(ds,reload=True)
 
-    dataloader = DataLoader(ds, batch_size=10,
+    dataloader = DataLoader(ds, batch_size=30,
                             shuffle=True, num_workers=4, collate_fn = lambda data: ds.pack_minibatch_training(data,v))
 
     train(dataloader, dataloader, 1e-3, 400, v)
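The driver now trains on the whole dataset (`percentage=100`, up from 12) with batches of 30 instead of 10. `pack_minibatch_training` is the project's own collate function and is not shown in this diff; as a hypothetical stand-in, a `collate_fn` for variable-length captions typically stacks the images and pads the captions into one tensor:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pack_minibatch(data, pad_index=0):
    """Hypothetical collate_fn; the repo's pack_minibatch_training is not
    shown in this diff. Stacks images, pads captions to a common length."""
    images, captions = zip(*data)  # data: list of (image, caption) pairs
    lengths = torch.tensor([len(c) for c in captions])
    images = torch.stack(images)   # (batch, 3, H, W)
    captions = pad_sequence(captions, batch_first=True,
                            padding_value=pad_index)  # (batch, max_len)
    return images, captions, lengths
```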
