-
Notifications
You must be signed in to change notification settings - Fork 9.7k
Faster R-CNN (WIP) #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This looks good. I think the most important thing is to get training (and evaluation) working and then worry more about design and cleaning up the code. Some thoughts:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't we have to include the license file if we're including the original code?
boxes[:, 2] = width - oldx1 - 1 | ||
return boxes | ||
|
||
class TransformVOCDetectionAnnotation(object): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
if isinstance(x, np.ndarray): | ||
return Variable(torch.from_numpy(x), requires_grad=False) | ||
elif torch.is_tensor(x): | ||
return Variable(x, requires_grad=True) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
cls_crit = nn.CrossEntropyLoss() | ||
cls_loss = cls_crit(scores, labels) | ||
|
||
reg_crit = nn.SmoothL1Loss() |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
from torch.autograd.function import Function | ||
from torch._thnn import type2backend | ||
|
||
class AdaptiveMaxPool2d(Function): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
||
# I need to know the original image size (or have the scaling factor) | ||
def get_roi_boxes(self, anchors, rpn_map, rpn_bbox_deltas, im): | ||
# TODO fix this!!! |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
class_to_ind = dict(zip(cls, range(len(cls)))) | ||
|
||
|
||
train = VOCDetection('/home/francisco/work/datasets/VOCdevkit/', 'train', |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
Also, there are a lot of commented statements that should be removed before merging |
@colesbury @apaszke thanks for your comments. I think the I initially wanted to keep the number of files small (as it is an example code), so I fused a number of things together, but that's probably a poor design choice. I will validate that the basic code is working as expected by performing a training/evaluation and then I'll focus on getting a refactoring of this PR. I'll get back to it on Monday, I've a trip to do in Rio tomorrow :) |
Ugh nvm, for some reason I haven't noticed that it's WIP... Have a nice trip! 😃 |
Have you tested https://github.com/longcw/faster_rcnn_pytorch? |
@bhack I haven't, and I didn't have the time to finish this properly. |
@KaimingHe In the plan of releasing Mask r-cnn there will be also a faster-rcnn pytorch implementation merged in this repository? |
A TF WIP Mask R-CNN effort was started at https://github.com/CharlesShang/FastMaskRCNN. Actually there is still no public reference implementation of the paper so we will see what kind of accurancy can be reproduced. |
This WIP PR implements Faster R-CNN based on https://github.com/rbgirshick/py-faster-rcnn .
There are lots of copy-paste from Girshick's repo, and a few files were copied as is.
I decided to fuse a number of operations together, but I'm not sure anymore it's the best way to go.
A few comments:
FasterRCNN
andRPN
. RPN implements the region proposal network, and faster r-cnn wraps almost all the detection logic (which might hurt flexibility at some point).cfg
dict containing all the parameters, and instead pass the parameters as the constructors of the classes. But different train/test parameters still need to be defined in the model.im_info[2]
), necessary for properly pruning small boxes in the RPN.I'm opening this PR to get some feed-back on the general structure of the code.
Overall, I'm starting to think that it might be better to use the same structure as the one in py-faster-rcnn from Girshick repo.
Code cleanup and organization is required.