Skip to content

Commit 41b035f

Browse files
authored
Merge pull request #1003 from sudomaze/feat/add_siamese_network_example
Implemented a Siamese Network Example
2 parents 648c0bd + 332e138 commit 41b035f

File tree

7 files changed

+326
-1
lines changed

7 files changed

+326
-1
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,6 @@ docs/venv
1717

1818
# vi backups
1919
*~
20+
21+
# development
22+
.vscode

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ If you're new we encourage you to take a look at issues tagged with [good first
2626

2727
5. Verify that there are no issues in your doc build. You can check preview locally
2828
by installing [sphinx-serve](https://pypi.org/project/sphinx-serve/) and
29-
then running `sphinx-serve -d build`.
29+
then running `sphinx-serve -b build`.
3030

3131
5. Ensure your test passes locally
3232
6. If you haven't already, complete the Contributor License Agreement ("CLA").

docs/source/index.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,17 @@ experiment with PyTorch.
1717

1818
---
1919

20+
Measuring Similarity using Siamese Network
21+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
22+
23+
This example demonstrates how to measure similarity between two images
24+
using a `Siamese network <https://en.wikipedia.org/wiki/Siamese_neural_network>`__
25+
on the `MNIST <https://en.wikipedia.org/wiki/MNIST_database>`__ database.
26+
27+
`GO TO EXAMPLE <https://github.com/pytorch/examples/blob/main/siamese_network>`__ :opticon:`link-external`
28+
29+
---
30+
2031
Word-level Language Modeling using RNN and Transformer
2132
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2233

run_python_examples.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,11 @@ function regression() {
110110
python main.py --epochs 1 $CUDA_FLAG || error "regression failed"
111111
}
112112

113+
# Smoke-test the Siamese network example: a single epoch run with --dry-run.
# Any non-zero exit from the script is reported through `error`.
function siamese_network() {
  start
  if ! python main.py --epochs 1 --dry-run; then
    error "siamese network example failed"
  fi
}
117+
113118
function reinforcement_learning() {
114119
start
115120
python reinforce.py || error "reinforcement learning reinforce failed"
@@ -193,6 +198,7 @@ function run_all() {
193198
mnist_hogwild
194199
regression
195200
reinforcement_learning
201+
siamese_network
196202
super_resolution
197203
time_sequence_prediction
198204
vae

siamese_network/README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Siamese Network Example
2+
3+
```bash
4+
pip install -r requirements.txt
5+
python main.py
6+
# CUDA_VISIBLE_DEVICES=2 python main.py  # run on a specific GPU (here, GPU id 2)
7+
```

siamese_network/main.py

Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
from __future__ import print_function
2+
import argparse, random, copy
3+
import numpy as np
4+
5+
import torch
6+
import torch.nn as nn
7+
import torch.nn.functional as F
8+
import torch.optim as optim
9+
import torchvision
10+
from torch.utils.data import Dataset
11+
from torchvision import datasets
12+
from torchvision import transforms as T
13+
from torch.optim.lr_scheduler import StepLR
14+
15+
16+
class SiameseNetwork(nn.Module):
    """
    Siamese network for image similarity estimation.

    Both inputs are passed through the *same* feature extractor (shared weights),
    so the "two identical networks" are really one network applied twice.
    The two feature vectors are concatenated and passed to a small stack of
    linear layers, whose single output is passed through a sigmoid function to
    give a similarity score in (0, 1).

    `"FaceNet" <https://arxiv.org/pdf/1503.03832.pdf>`_ is a variant of the Siamese network.
    This implementation varies from FaceNet as we use the `ResNet-18` model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ as our feature extractor.
    In addition, we aren't using `TripletLoss` as the MNIST dataset is simple, so `BCELoss` can do the trick.
    """
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # get a randomly initialised resnet model. Calling the factory without
        # arguments means "no pretrained weights" on both old torchvision
        # (`pretrained=False` default) and new torchvision (`weights=None`
        # default), and avoids the deprecation warning for `pretrained=`.
        self.resnet = torchvision.models.resnet18()

        # over-write the first conv layer to be able to read MNIST images
        # as resnet18 reads (3,x,x) where 3 is RGB channels
        # whereas MNIST has (1,x,x) where 1 is a gray-scale channel
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

        # remember the feature width before the classifier head is dropped
        self.fc_in_features = self.resnet.fc.in_features

        # remove the last layer of resnet18 (the linear classifier, which comes
        # after the avgpool layer) so the backbone outputs pooled features
        self.resnet = torch.nn.Sequential(*(list(self.resnet.children())[:-1]))

        # add linear layers to compare between the features of the two images
        self.fc = nn.Sequential(
            nn.Linear(self.fc_in_features * 2, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),
        )

        self.sigmoid = nn.Sigmoid()

        # initialize the weights (only `nn.Linear` modules are touched,
        # see `init_weights`)
        self.resnet.apply(self.init_weights)
        self.fc.apply(self.init_weights)

    def init_weights(self, m):
        """Xavier-initialise the weight and set the bias to 0.01 for `nn.Linear` modules.

        Any other module type is left untouched. Intended to be used with
        `nn.Module.apply`.
        """
        if isinstance(m, nn.Linear):
            # `xavier_uniform_` is the in-place initializer; the un-suffixed
            # `xavier_uniform` is a deprecated alias that emits a warning.
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    def forward_once(self, x):
        """Run one batch of images through the shared backbone and flatten per sample."""
        output = self.resnet(x)
        output = output.view(output.size(0), -1)
        return output

    def forward(self, input1, input2):
        """Return the sigmoid similarity score, shape (batch, 1), for each image pair."""
        # get two images' features
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)

        # concatenate both images' features
        output = torch.cat((output1, output2), 1)

        # pass the concatenation to the linear layers
        output = self.fc(output)

        # pass the out of the linear layers to sigmoid layer
        output = self.sigmoid(output)

        return output
79+
80+
class APP_MATCHER(Dataset):
    """
    Dataset of MNIST image pairs labelled as "same class" (1) or "different class" (0).

    The requested index only decides whether a pair is positive (even index) or
    negative (odd index); the images themselves are drawn at random with
    Python's `random` module on every access.
    """

    def __init__(self, root, train, download=False):
        super(APP_MATCHER, self).__init__()

        # underlying MNIST split
        self.dataset = datasets.MNIST(root, train=train, download=download)

        # MNIST stores its images without a channel axis, while the network
        # expects one per example, so insert a singleton channel axis at dim 1
        # (gray-scale => one channel). The clone keeps our copy independent of
        # the wrapped dataset's tensor.
        self.data = self.dataset.data.unsqueeze(1).clone()

        self.group_examples()

    def group_examples(self):
        """
        Build `self.grouped_examples`, a dict mapping each digit class (0-9) to a
        numpy array holding the dataset indices of every example of that class.
        """
        labels = np.array(self.dataset.targets.clone())

        self.grouped_examples = {
            digit: np.where(labels == digit)[0] for digit in range(0, 10)
        }

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        """
        Return `(image_1, image_2, target)` for one randomly drawn pair.

        The first image comes from a random class. For an even `index` the second
        image is a *different* example of that same class (so the positive pair is
        never trivially identical) and `target` is 1. For an odd `index` the second
        image comes from a different class and `target` is 0. Both images are
        returned as float tensors.
        """
        # first image: random class, then random example within that class
        first_class = random.randint(0, 9)
        first_pool = self.grouped_examples[first_class]
        first_pos = random.randint(0, first_pool.shape[0] - 1)
        image_1 = self.data[first_pool[first_pos]].clone().float()

        is_positive = index % 2 == 0

        if is_positive:
            # same class: redraw until we land on a different example
            second_pool = first_pool
            second_pos = random.randint(0, second_pool.shape[0] - 1)
            while second_pos == first_pos:
                second_pos = random.randint(0, second_pool.shape[0] - 1)
        else:
            # different class: redraw until the class differs from the first one
            second_class = random.randint(0, 9)
            while second_class == first_class:
                second_class = random.randint(0, 9)

            second_pool = self.grouped_examples[second_class]
            second_pos = random.randint(0, second_pool.shape[0] - 1)

        image_2 = self.data[second_pool[second_pos]].clone().float()

        # positive pairs are labelled 1, negative pairs 0
        target = torch.tensor(1 if is_positive else 0, dtype=torch.float)

        return image_1, image_2, target
188+
189+
190+
def train(args, model, device, train_loader, optimizer, epoch):
    """
    Train `model` for one pass over `train_loader`.

    Args:
        args: namespace providing `log_interval` (log every N batches) and
            `dry_run` (stop after the first logged batch).
        model: callable taking two image batches and returning one
            probability per pair.
        device: device the batches are moved to before the forward pass.
        train_loader: iterable of `(images_1, images_2, targets)` batches that
            also exposes `.dataset` and `len()` for progress logging.
        optimizer: optimizer stepping the model's parameters.
        epoch: epoch number, used only in the log line.
    """
    model.train()

    # we aren't using `TripletLoss` as the MNIST dataset is simple, so `BCELoss` can do the trick.
    criterion = nn.BCELoss()

    for batch_idx, (images_1, images_2, targets) in enumerate(train_loader):
        images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device)
        optimizer.zero_grad()
        # Flatten to one score per pair. A bare `squeeze()` would also drop the
        # batch axis when the batch holds a single pair, leaving a 0-d tensor
        # that `BCELoss` rejects against targets of shape (1,).
        outputs = model(images_1, images_2).reshape(-1)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(images_1), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break
209+
210+
211+
def test(model, device, test_loader):
    """
    Evaluate `model` on `test_loader` and print the per-example average loss
    and the accuracy (a pair counts as "same" when the score exceeds 0.5).

    Args:
        model: callable taking two image batches and returning one
            probability per pair.
        device: device the batches are moved to before the forward pass.
        test_loader: iterable of `(images_1, images_2, targets)` batches that
            also exposes `.dataset` for the example count.
    """
    model.eval()
    test_loss = 0
    correct = 0

    # we aren't using `TripletLoss` as the MNIST dataset is simple, so `BCELoss` can do the trick.
    # `reduction='sum'` so that dividing the running total by the number of
    # examples below yields a true per-example average; with the default
    # 'mean' reduction the printed value was too small by about the batch size.
    criterion = nn.BCELoss(reduction='sum')

    with torch.no_grad():
        for (images_1, images_2, targets) in test_loader:
            images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device)
            # one score per pair; unlike a bare `squeeze()`, this keeps the
            # batch axis when a batch contains a single pair
            outputs = model(images_1, images_2).reshape(-1)
            test_loss += criterion(outputs, targets).item()  # sum up batch loss
            pred = (outputs > 0.5).to(targets.dtype)  # threshold the sigmoid score
            correct += pred.eq(targets.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    # NOTE: the original author reported an accuracy of 97-98% after the 1st
    # epoch and 99.5-100% after the 10th epoch using default settings. The
    # loss figures quoted alongside those numbers came from the earlier
    # mean-per-batch computation and are not comparable with the per-example
    # average printed here.
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


# This is an evaluation routine, not a pytest test case; opt out of collection
# in case this module is ever star-imported into a test file.
test.__test__ = False
235+
236+
237+
def main():
    """
    Command-line entry point: parse the training settings, build the MNIST pair
    datasets and loaders, train the Siamese network for `--epochs` epochs
    (evaluating after each one), and optionally save the weights to
    "siamese_network.pt".
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch Siamese network Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    # The dataset draws its image pairs with Python's `random` module, so it
    # has to be seeded as well for `--seed` to make pair sampling repeatable
    # when the loaders run in the main process (the CPU path below).
    random.seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        # NOTE: worker process, pinned memory and shuffling are only enabled
        # on the CUDA path, and are applied to both loaders.
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    # the training split triggers the download; the test split reuses the
    # files already placed in '../data'
    train_dataset = APP_MATCHER('../data', train=True, download=True)
    test_dataset = APP_MATCHER('../data', train=False)
    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)

    model = SiameseNetwork().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # multiply the learning rate by `gamma` after every epoch
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "siamese_network.pt")
293+
294+
295+
if __name__ == '__main__':
296+
main()

siamese_network/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
torch
2+
torchvision

0 commit comments

Comments
 (0)