Faster R-CNN (WIP) #21
Closed
Changes from all commits (18 commits, all by fmassa):

10caef7  Changes from yesterday
55b2bb0  Seems to work
faa3b4e  Change generator
f2e9248  fast rcnn
3b3f1ae  Starting to prototype faster rcnn
c653216  rpn runs
3bee8e6  frcnn runs
22e7696  updating
5e71e6c  A bit of organization
9e65a2f  Organization
e119672  Rename
4058094  rename
a0061e8  Cleaning up a bit
cfb643f  Reduce default learning rate
e36a936  Fixes
e73ee53  Removing unnecessary files from tree
79c2402  Rename
d8d378c  minor changes
@@ -0,0 +1,14 @@
# Faster R-CNN code example

```sh
python main.py PATH_TO_DATASET
```

## Things to add/change/consider
* where to handle image scaling: the annotations need to be scaled as well, and the RPN currently filters the minimum box size w.r.t. the original image size, not the scaled one
* should image scaling be handled in the FasterRCNN class?
* properly supporting flipping
* best way to handle different parameters in RPN/FRCNN for train/eval modes
* make Variable handling uniform: Variables should be provided by the user, not created by the FasterRCNN/RPN classes
* general code cleanup, there is a lot of torch/numpy mixture
* should I use a general config file?
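
A minimal usage sketch of how these pieces compose, based on the `FasterRCNN` constructor and `forward` signature in the file below; the `build_*` factories, `im`, and `gt` are hypothetical placeholders for user-supplied modules and data:

```python
# Hypothetical wiring -- the build_* factories are placeholders, not part of
# this PR; FasterRCNN only assumes they return nn.Module-like callables.
features = build_features()      # convolutional feature extractor
pooler = build_pooler()          # RoI pooling over the feature map
classifier = build_classifier()  # per-RoI classification + bbox regression head
rpn = build_rpn()                # region proposal network

model = FasterRCNN(features, pooler, classifier, rpn)

# training: pass (image, ground truth); gt is a dict with 'boxes' and
# 'gt_classes' (see frcnn_targets), and forward returns the combined loss
loss, scores, boxes = model((im, gt))

# inference: pass the image batch (of size 1) alone
scores, boxes = model(im)
```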
@@ -0,0 +1,199 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import numpy.random as npr

from utils import \
    bbox_transform, bbox_transform_inv, clip_boxes, bbox_overlaps

from utils import to_var as _tovar


# should handle multiple scales, how?
class FasterRCNN(nn.Container):

    def __init__(self,
                 features, pooler,
                 classifier, rpn,
                 batch_size=128, fg_fraction=0.25,
                 fg_threshold=0.5, bg_threshold=None,
                 num_classes=21):
        super(FasterRCNN, self).__init__()
        self.features = features
        self.roi_pooling = pooler
        self.rpn = rpn
        self.classifier = classifier

        self.batch_size = batch_size
        self.fg_fraction = fg_fraction
        self.fg_threshold = fg_threshold
        if bg_threshold is None:
            bg_threshold = (0, 0.5)
        self.bg_threshold = bg_threshold
        self._num_classes = num_classes

    # should it support batched images ?
    def forward(self, x):
        #if self.training is True:
        if isinstance(x, tuple):
            im, gt = x
        else:
            im = x
            gt = None

        assert im.size(0) == 1, 'only single element batches supported'

        feats = self.features(_tovar(im))

        roi_boxes, rpn_prob, rpn_loss = self.rpn(im, feats, gt)

        #if self.training is True:
        if gt is not None:
            # append gt boxes and sample fg / bg boxes
            # proposal_target_layer.py
            all_rois, frcnn_labels, roi_boxes, frcnn_bbox_targets = \
                self.frcnn_targets(roi_boxes, im, gt)

        # r-cnn
        regions = self.roi_pooling(feats, roi_boxes)
        scores, bbox_pred = self.classifier(regions)

        boxes = self.bbox_reg(roi_boxes, bbox_pred, im)

        # apply cls + bbox reg loss here
        #if self.training is True:
        if gt is not None:
            frcnn_loss = self.frcnn_loss(scores, bbox_pred, frcnn_labels,
                                         frcnn_bbox_targets)
            loss = frcnn_loss + rpn_loss
            return loss, scores, boxes

        return scores, boxes
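
    # Note: when ground truth is passed in, forward returns
    # (loss, scores, boxes); in inference mode it returns (scores, boxes).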
    def frcnn_loss(self, scores, bbox_pred, labels, bbox_targets):
        cls_crit = nn.CrossEntropyLoss()
        cls_loss = cls_crit(scores, labels)

        reg_crit = nn.SmoothL1Loss()
        reg_loss = reg_crit(bbox_pred, bbox_targets)

        loss = cls_loss + reg_loss
        return loss
    def frcnn_targets(self, all_rois, im, gt):
        all_rois = all_rois.data.numpy()
        gt_boxes = gt['boxes'].numpy()
        gt_labels = np.array(gt['gt_classes'])
        #zeros = np.zeros((gt_boxes.shape[0], 1), dtype=gt_boxes.dtype)
        #all_rois = np.vstack(
        #    (all_rois, np.hstack((zeros, gt_boxes[:, :-1])))
        #)
        all_rois = np.vstack((all_rois, gt_boxes))
        zeros = np.zeros((all_rois.shape[0], 1), dtype=all_rois.dtype)
        all_rois = np.hstack((zeros, all_rois))

        num_images = 1
        # keep the counts integral so they can be used for indexing below
        rois_per_image = self.batch_size // num_images
        fg_rois_per_image = int(np.round(self.fg_fraction * rois_per_image))

        # Sample rois with classification labels and bounding box regression
        # targets
        labels, rois, bbox_targets = _sample_rois(self,
            all_rois, gt_boxes, gt_labels, fg_rois_per_image,
            rois_per_image, self._num_classes)

        return _tovar((all_rois, labels, rois, bbox_targets))
    def bbox_reg(self, boxes, box_deltas, im):
        boxes = boxes.data[:, 1:].numpy()
        box_deltas = box_deltas.data.numpy()
        pred_boxes = bbox_transform_inv(boxes, box_deltas)
        pred_boxes = clip_boxes(pred_boxes, im.size()[-2:])
        return _tovar(pred_boxes)

def _get_bbox_regression_labels(bbox_target_data, num_classes):
    """Bounding-box regression targets (bbox_target_data) are stored in a
    compact form N x (class, tx, ty, tw, th).
    This function expands those targets into the 4-of-4*K representation used
    by the network (i.e. only one class has non-zero targets).
    Returns:
        bbox_targets (ndarray): N x 4K blob of regression targets
    """

    clss = bbox_target_data[:, 0]
    bbox_targets = np.zeros((clss.size, 4 * num_classes), dtype=np.float32)
    inds = np.where(clss > 0)[0]
    for ind in inds:
        cls = int(clss[ind])  # labels are stored as floats; cast for indexing
        start = 4 * cls
        end = start + 4
        bbox_targets[ind, start:end] = bbox_target_data[ind, 1:]
    return bbox_targets
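
# Illustration: with num_classes=21, a compact row [3., tx, ty, tw, th] is
# written into columns 12:16 of the N x 84 bbox_targets array; all other
# columns stay zero, so only the ground-truth class receives a regression
# target.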


def _compute_targets(ex_rois, gt_rois, labels):
    """Compute bounding-box regression targets for an image."""

    assert ex_rois.shape[0] == gt_rois.shape[0]
    assert ex_rois.shape[1] == 4
    assert gt_rois.shape[1] == 4

    targets = bbox_transform(ex_rois, gt_rois)
    if False:  # cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED
        # Optionally normalize targets by a precomputed mean and stdev
        targets = ((targets - np.array(cfg.TRAIN.BBOX_NORMALIZE_MEANS))
                   / np.array(cfg.TRAIN.BBOX_NORMALIZE_STDS))
    return np.hstack(
        (labels[:, np.newaxis], targets)).astype(np.float32, copy=False)

def _sample_rois(self, all_rois, gt_boxes, gt_labels, fg_rois_per_image,
                 rois_per_image, num_classes):
    """Generate a random sample of RoIs comprising foreground and background
    examples.
    """
    # overlaps: (rois x gt_boxes)
    overlaps = bbox_overlaps(
        np.ascontiguousarray(all_rois[:, 1:5], dtype=np.float),
        np.ascontiguousarray(gt_boxes[:, :4], dtype=np.float))
    overlaps = overlaps.numpy()
    gt_assignment = overlaps.argmax(axis=1)
    max_overlaps = overlaps.max(axis=1)
    #labels = gt_boxes[gt_assignment, 4]
    labels = gt_labels[gt_assignment]

    # Select foreground RoIs as those with >= FG_THRESH overlap
    fg_inds = np.where(max_overlaps >= self.fg_threshold)[0]
    # Guard against the case when an image has fewer than fg_rois_per_image
    # foreground RoIs
    fg_rois_per_this_image = int(min(fg_rois_per_image, fg_inds.size))
    # Sample foreground regions without replacement
    if fg_inds.size > 0:
        fg_inds = npr.choice(fg_inds, size=fg_rois_per_this_image,
                             replace=False)

    # Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
    bg_inds = np.where((max_overlaps < self.bg_threshold[1]) &
                       (max_overlaps >= self.bg_threshold[0]))[0]
    # Compute number of background RoIs to take from this image (guarding
    # against there being fewer than desired)
    bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image
    bg_rois_per_this_image = int(min(bg_rois_per_this_image, bg_inds.size))
    # Sample background regions without replacement
    if bg_inds.size > 0:
        bg_inds = npr.choice(bg_inds, size=bg_rois_per_this_image,
                             replace=False)

    # The indices that we're selecting (both fg and bg)
    keep_inds = np.append(fg_inds, bg_inds)
    # Select sampled values from various arrays:
    labels = labels[keep_inds]
    # Clamp labels for the background RoIs to 0
    labels[fg_rois_per_this_image:] = 0
    rois = all_rois[keep_inds]

    bbox_target_data = _compute_targets(
        rois[:, 1:5], gt_boxes[gt_assignment[keep_inds], :4], labels)

    bbox_targets = \
        _get_bbox_regression_labels(bbox_target_data, num_classes)

    return labels, rois, bbox_targets
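
# With the defaults in FasterRCNN.__init__ (batch_size=128, fg_fraction=0.25),
# each image contributes at most 32 foreground RoIs (IoU >= 0.5 with a gt box)
# and the remaining slots are filled from the [0, 0.5) background IoU range.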
@@ -0,0 +1,105 @@
# --------------------------------------------------------
# Faster R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick and Sean Bell
# --------------------------------------------------------

import numpy as np

# Verify that we compute the same anchors as Shaoqing's matlab implementation:
#
#    >> load output/rpn_cachedir/faster_rcnn_VOC2007_ZF_stage1_rpn/anchors.mat
#    >> anchors
#
#    anchors =
#
#       -83   -39   100    56
#      -175   -87   192   104
#      -359  -183   376   200
#       -55   -55    72    72
#      -119  -119   136   136
#      -247  -247   264   264
#       -35   -79    52    96
#       -79  -167    96   184
#      -167  -343   184   360

#array([[ -83.,  -39.,  100.,   56.],
#       [-175.,  -87.,  192.,  104.],
#       [-359., -183.,  376.,  200.],
#       [ -55.,  -55.,   72.,   72.],
#       [-119., -119.,  136.,  136.],
#       [-247., -247.,  264.,  264.],
#       [ -35.,  -79.,   52.,   96.],
#       [ -79., -167.,   96.,  184.],
#       [-167., -343.,  184.,  360.]])


def generate_anchors(base_size=16, ratios=[0.5, 1, 2],
                     scales=2**np.arange(3, 6)):
    """
    Generate anchor (reference) windows by enumerating aspect ratios X
    scales wrt a reference (0, 0, 15, 15) window.
    """

    base_anchor = np.array([1, 1, base_size, base_size]) - 1
    ratio_anchors = _ratio_enum(base_anchor, ratios)
    anchors = np.vstack([_scale_enum(ratio_anchors[i, :], scales)
                         for i in range(ratio_anchors.shape[0])])
    return anchors

def _whctrs(anchor):
    """
    Return width, height, x center, and y center for an anchor (window).
    """

    w = anchor[2] - anchor[0] + 1
    h = anchor[3] - anchor[1] + 1
    x_ctr = anchor[0] + 0.5 * (w - 1)
    y_ctr = anchor[1] + 0.5 * (h - 1)
    return w, h, x_ctr, y_ctr


def _mkanchors(ws, hs, x_ctr, y_ctr):
    """
    Given a vector of widths (ws) and heights (hs) around a center
    (x_ctr, y_ctr), output a set of anchors (windows).
    """

    ws = ws[:, np.newaxis]
    hs = hs[:, np.newaxis]
    anchors = np.hstack((x_ctr - 0.5 * (ws - 1),
                         y_ctr - 0.5 * (hs - 1),
                         x_ctr + 0.5 * (ws - 1),
                         y_ctr + 0.5 * (hs - 1)))
    return anchors


def _ratio_enum(anchor, ratios):
    """
    Enumerate a set of anchors for each aspect ratio wrt an anchor.
    """

    w, h, x_ctr, y_ctr = _whctrs(anchor)
    size = w * h
    size_ratios = size / ratios
    ws = np.round(np.sqrt(size_ratios))
    hs = np.round(ws * ratios)
    anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
    return anchors


def _scale_enum(anchor, scales):
    """
    Enumerate a set of anchors for each scale wrt an anchor.
    """

    w, h, x_ctr, y_ctr = _whctrs(anchor)
    ws = w * scales
    hs = h * scales
    anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
    return anchors


if __name__ == '__main__':
    import time
    t = time.time()
    a = generate_anchors()
    print(time.time() - t)
    print(a)
    from IPython import embed; embed()
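
As a quick sanity check on the enumeration arithmetic (a standalone sketch, not part of the diff): `_ratio_enum` keeps the reference area fixed while solving for width and height, and `_scale_enum` then multiplies the resulting base shapes by the scales {8, 16, 32}:

```python
import numpy as np

# Reproduce _ratio_enum's arithmetic for the (0, 0, 15, 15) reference
# window: w = h = 16, area = 256.
size = 16.0 * 16.0
for ratio in [0.5, 1, 2]:
    w = np.round(np.sqrt(size / ratio))  # 23., 16., 11.
    h = np.round(w * ratio)              # 12., 16., 22.
    print(ratio, w, h)

# At scale 8 the ratio-0.5 shape becomes 23*8 x 12*8 = 184 x 96, which is
# exactly the first row of the table above: 100 - (-83) + 1 = 184 wide and
# 56 - (-39) + 1 = 96 tall.
```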