from __future__ import print_function
import argparse, random, copy
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision import transforms as T
from torch.optim.lr_scheduler import StepLR


| 15 | + |
| 16 | +class SiameseNetwork(nn.Module): |
| 17 | + """ |
| 18 | + Siamese network for image similarity estimation. |
| 19 | + The network is composed of two identical networks, one for each input. |
| 20 | + The output of each network is concatenated and passed to a linear layer. |
| 21 | + The output of the linear layer passed through a sigmoid function. |
| 22 | + `"FaceNet" <https://arxiv.org/pdf/1503.03832.pdf>`_ is a variant of the Siamese network. |
| 23 | + This implementation varies from FaceNet as we use the `ResNet-18` model from |
| 24 | + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ as our feature extractor. |
| 25 | + In addition, we aren't using `TripletLoss` as the MNIST dataset is simple, so `BCELoss` can do the trick. |
| 26 | + """ |
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # get resnet model
        self.resnet = torchvision.models.resnet18(pretrained=False)

        # over-write the first conv layer so the network can read MNIST images:
        # resnet18 expects inputs of shape (3,x,x), where 3 is the number of RGB channels,
        # whereas MNIST images have shape (1,x,x), with a single gray-scale channel
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.fc_in_features = self.resnet.fc.in_features

        # remove the last layer of resnet18 (the linear layer that follows the avgpool layer)
        self.resnet = torch.nn.Sequential(*(list(self.resnet.children())[:-1]))

        # add linear layers to compare the features of the two images
        self.fc = nn.Sequential(
            nn.Linear(self.fc_in_features * 2, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),
        )

        self.sigmoid = nn.Sigmoid()

        # initialize the weights
        self.resnet.apply(self.init_weights)
        self.fc.apply(self.init_weights)

    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    def forward_once(self, x):
        output = self.resnet(x)
        output = output.view(output.size()[0], -1)
        return output

    def forward(self, input1, input2):
        # get the features of the two images
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)

        # concatenate both images' features
        output = torch.cat((output1, output2), 1)

        # pass the concatenation to the linear layers
        output = self.fc(output)

        # pass the output of the linear layers to the sigmoid layer
        output = self.sigmoid(output)

        return output
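
# A minimal shape sketch of a forward pass (illustration only, not executed
# as part of the script), assuming MNIST-sized inputs:
#
#   net = SiameseNetwork()
#   x1 = torch.randn(4, 1, 28, 28)   # batch of 4 gray-scale images
#   x2 = torch.randn(4, 1, 28, 28)
#   out = net(x1, x2)                # shape (4, 1), each value in (0, 1)
#
# Each image is mapped to a 512-dimensional ResNet-18 feature vector, so the
# concatenation fed to `self.fc` has shape (4, 1024).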


class APP_MATCHER(Dataset):
    def __init__(self, root, train, download=False):
        super(APP_MATCHER, self).__init__()

        # get MNIST dataset
        self.dataset = datasets.MNIST(root, train=train, download=download)

        # as `self.dataset.data`'s shape is (Nx28x28), where N is the number of
        # examples in the MNIST dataset, a single example has the dimensions of
        # (28x28) for (HxW), where H and W are the height and the width of the image.
        # However, every example should have (CxHxW) dimensions, where C is the number
        # of channels to be passed to the network. As MNIST contains gray-scale images,
        # we add an additional dimension that corresponds to the number of channels.
        self.data = self.dataset.data.unsqueeze(1).clone()

        self.group_examples()

    def group_examples(self):
        """
        To ease access to examples from a given class, `group_examples` groups
        the examples by class.

        Every key in `grouped_examples` corresponds to a class in the MNIST dataset.
        For every key, the value is an array holding all of the indices of the MNIST
        examples that belong to that class.
        """

        # get the targets from MNIST dataset
        np_arr = np.array(self.dataset.targets.clone())

        # group examples based on class
        self.grouped_examples = {}
        for i in range(0, 10):
            self.grouped_examples[i] = np.where((np_arr == i))[0]

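    # For illustration: after grouping, `self.grouped_examples[3]` is a numpy
    # array of every dataset index whose target is the digit 3, so a random
    # draw from that array always yields the index of an image of a 3.
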
    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        """
        For every example, we will select two images. There are two cases,
        positive and negative examples. For positive examples, we will have two
        images from the same class. For negative examples, we will have two images
        from different classes.

        Given an index, if the index is even, we will pick the second image from the same class,
        but it won't be the same image we chose for the first one. This ensures the positive
        example isn't trivial, as the network would easily learn that identical images are similar.
        Instead, given two different images from the same class, the network has to learn
        the similarity between two different images representing the same class. If the index is odd,
        we will pick the second image from a different class than the first image.
        """

        # pick some random class for the first image
        selected_class = random.randint(0, 9)

        # pick a random index for the first image in the grouped indices based on the label
        # of the class
        random_index_1 = random.randint(0, self.grouped_examples[selected_class].shape[0]-1)

        # pick the index to get the first image
        index_1 = self.grouped_examples[selected_class][random_index_1]

        # get the first image
        image_1 = self.data[index_1].clone().float()

        # same class
        if index % 2 == 0:
            # pick a random index for the second image
            random_index_2 = random.randint(0, self.grouped_examples[selected_class].shape[0]-1)

            # ensure that the index of the second image isn't the same as the first image
            while random_index_2 == random_index_1:
                random_index_2 = random.randint(0, self.grouped_examples[selected_class].shape[0]-1)

            # pick the index to get the second image
            index_2 = self.grouped_examples[selected_class][random_index_2]

            # get the second image
            image_2 = self.data[index_2].clone().float()

            # set the label for this example to be positive (1)
            target = torch.tensor(1, dtype=torch.float)

        # different class
        else:
            # pick a random class
            other_selected_class = random.randint(0, 9)

            # ensure that the class of the second image isn't the same as the first image
            while other_selected_class == selected_class:
                other_selected_class = random.randint(0, 9)

            # pick a random index for the second image in the grouped indices based on the label
            # of the class
            random_index_2 = random.randint(0, self.grouped_examples[other_selected_class].shape[0]-1)

            # pick the index to get the second image
            index_2 = self.grouped_examples[other_selected_class][random_index_2]

            # get the second image
            image_2 = self.data[index_2].clone().float()

            # set the label for this example to be negative (0)
            target = torch.tensor(0, dtype=torch.float)

        return image_1, image_2, target
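
# Illustration of the sampling scheme (hypothetical usage, not executed here):
#
#   ds = APP_MATCHER('../data', train=True, download=True)
#   img1, img2, target = ds[0]   # even index -> same-class pair, target 1.0
#   img1, img2, target = ds[1]   # odd index  -> different-class pair, target 0.0
#
# Both images are float tensors of shape (1, 28, 28). Note that `index` only
# decides the pair's polarity; the images themselves are drawn at random.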


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()

    # we aren't using `TripletLoss` as the MNIST dataset is simple, so `BCELoss` can do the trick.
    criterion = nn.BCELoss()

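    # Note (design choice): the model ends in a sigmoid, so `nn.BCELoss` on its
    # outputs is the natural fit. An alternative sketch would drop the sigmoid
    # from the model and use `nn.BCEWithLogitsLoss` on the raw logits, which is
    # numerically more stable.
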
    for batch_idx, (images_1, images_2, targets) in enumerate(train_loader):
        images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(images_1, images_2).squeeze()
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(images_1), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0

    # we aren't using `TripletLoss` as the MNIST dataset is simple, so `BCELoss` can do the trick.
    criterion = nn.BCELoss()

    with torch.no_grad():
        for (images_1, images_2, targets) in test_loader:
            images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device)
            outputs = model(images_1, images_2).squeeze()
            test_loss += criterion(outputs, targets).item()  # accumulate the mean loss of each batch
            pred = torch.where(outputs > 0.5, 1, 0)  # threshold the similarity score at 0.5
            correct += pred.eq(targets.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    # for the 1st epoch, the average loss is 0.0001 and the accuracy 97-98%
    # using default settings. After completing the 10th epoch, the average
    # loss is 0.0000 and the accuracy 99.5-100% using default settings.
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch Siamese network Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='for saving the current model')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    train_dataset = APP_MATCHER('../data', train=True, download=True)
    test_dataset = APP_MATCHER('../data', train=False)
    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)

    model = SiameseNetwork().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
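    # With step_size=1, StepLR multiplies the learning rate by `gamma` after
    # every epoch, e.g. lr = 1.0, 0.7, 0.49, ... with the defaults above.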
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "siamese_network.pt")


if __name__ == '__main__':
    main()
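
# Example invocation (assuming this script is saved as `main.py`):
#   python main.py --epochs 14 --batch-size 64 --save-model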